diff --git a/.github/workflows/build_base_and_push.yml b/.github/workflows/build_base_and_push.yml new file mode 100644 index 00000000..2b565bbf --- /dev/null +++ b/.github/workflows/build_base_and_push.yml @@ -0,0 +1,104 @@ +name: build-base-and-push + +run-name: 构建镜像并推送仓库 ${{ github.event.inputs.dockerImageTag }} (${{ github.event.inputs.registry }}) (${{ github.event.inputs.architecture }}) + +on: + workflow_dispatch: + inputs: + dockerImageTag: + description: 'Image Tag' + default: 'v0.9.0' + required: true + dockerImageTagWithLatest: + description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)' + default: false + required: true + type: boolean + architecture: + description: 'Architecture' + required: true + default: 'linux/amd64' + type: choice + options: + - linux/amd64 + - linux/arm64 + - linux/amd64,linux/arm64 + registry: + description: 'Push To Registry' + required: true + default: 'aliyun-registry' + type: choice + options: + - aliyun-registry + - dockerhub + - dockerhub, aliyun-registry + +jobs: + build-and-push-to-aliyun-registry: + if: ${{ contains(github.event.inputs.registry, 'aliyun') }} + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + - name: Prepare + id: prepare + run: | + DOCKER_IMAGE=${{ secrets.ALIYUN_REGISTRY_HOST }}/dataease/sqlbot-python-pg + DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} + TAG_NAME=${{ github.event.inputs.dockerImageTag }} + TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }} + if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*} --tag ${DOCKER_IMAGE}:latest" + else + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" + fi + echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \ + ${DOCKER_IMAGE_TAGS} . + - name: Set up Docker Buildx + uses: crazy-max/ghaction-docker-buildx@v3 + - name: Login to Aliyun Registry + uses: docker/login-action@v2 + with: + registry: ${{ secrets.ALIYUN_REGISTRY_HOST }} + username: ${{ secrets.ALIYUN_REGISTRY_USERNAME }} + password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }} + - name: Docker Buildx (build-and-push) + run: | + docker buildx build -f Dockerfile-base --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} + + build-and-push-to-dockerhub: + if: ${{ contains(github.event.inputs.registry, 'dockerhub') }} + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + - name: Prepare + id: prepare + run: | + DOCKER_IMAGE=dataease/sqlbot-python-pg + DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} + TAG_NAME=${{ github.event.inputs.dockerImageTag }} + TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }} + if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*} --tag ${DOCKER_IMAGE}:latest" + else + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" + fi + echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \ + ${DOCKER_IMAGE_TAGS} . + - name: Set up Docker Buildx + uses: crazy-max/ghaction-docker-buildx@v3 + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Docker Buildx (build-and-push) + run: | + docker buildx build -f Dockerfile-base --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} \ No newline at end of file diff --git a/.github/workflows/package_and_push.yml b/.github/workflows/package_and_push.yml index 3ec89e7c..0e364777 100644 --- a/.github/workflows/package_and_push.yml +++ b/.github/workflows/package_and_push.yml @@ -9,6 +9,11 @@ on: description: 'Image Tag' default: 'v0.9.5' required: true + packageEE: + description: '是否打包企业版' + default: false + required: true + type: boolean jobs: package-and-push-to-aliyun-oss: @@ -31,6 +36,7 @@ jobs: ALIYUN_OSS_BUCKET_ENDPOINT: ${{ secrets.ALIYUN_OSS_BUCKET_ENDPOINT }} ALIYUN_OSS_ACCESS_KEY: ${{ secrets.ALIYUN_OSS_ACCESS_KEY }} ALIYUN_OSS_ACCESS_SECRET: ${{ secrets.ALIYUN_OSS_ACCESS_SECRET }} + PACKAGE_EE: ${{ github.event.inputs.packageEE }} run: | DOCKER_IMAGE=${ALIYUN_REGISTRY_HOST}/dataease/sqlbot cd installer @@ -76,12 +82,27 @@ jobs: --exclude .git \ --exclude images \ --exclude docker \ - -czvf $package_online . + -czvf $package_online . + + if [[ ${PACKAGE_EE} == 'true' ]]; then + package_offline_ee="sqlbot-offline-installer-${TAG_NAME}${platform}-ee.tar.gz" + touch $package_offline_ee + tar --transform "s/^\./sqlbot-offline-installer-${TAG_NAME}${platform}-ee/" \ + --exclude $package_online \ + --exclude $package_offline \ + --exclude $package_offline_ee \ + --exclude .git \ + -czvf $package_offline_ee . + + ossutil cp -rf ${package_offline_ee} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline_ee} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} + fi #Sync files to OSS - ossutil cp -rf ${package_offline} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} + ossutil cp -rf ${package_offline} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} ossutil cp -rf ${package_online} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_online} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} + + diff --git a/.github/workflows/sync2gitee.yml b/.github/workflows/sync2gitee.yml new file mode 100644 index 00000000..78092f29 --- /dev/null +++ b/.github/workflows/sync2gitee.yml @@ -0,0 +1,15 @@ +name: Synchronize to Gitee +on: [push] +jobs: + repo-sync: + runs-on: ubuntu-latest + steps: + - name: Mirror the Github organization repos to Gitee. + uses: Yikun/hub-mirror-action@master + with: + src: 'github/dataease' + dst: 'gitee/fit2cloud-feizhiyun' + dst_key: ${{ secrets.GITEE_PRIVATE_KEY }} + dst_token: ${{ secrets.GITEE_TOKEN }} + static_list: "SQLBot" + force_update: true diff --git a/Dockerfile b/Dockerfile index cb6724de..654ca8e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,5 @@ # Build sqlbot +FROM ghcr.io/1panel-dev/maxkb-vector-model:v1.0.1 AS vector-model FROM registry.cn-qingdao.aliyuncs.com/dataease/sqlbot-base:latest AS sqlbot-builder # Set build environment variables @@ -32,7 +33,7 @@ COPY ./backend ${APP_HOME} # Final sync to ensure all dependencies are installed RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync + uv sync --extra cpu # Build g2-ssr FROM registry.cn-qingdao.aliyuncs.com/dataease/sqlbot-base:latest AS ssr-builder @@ -45,7 +46,10 @@ COPY g2-ssr/charts/* /app/charts/ RUN npm install # Runtime stage -FROM registry.cn-qingdao.aliyuncs.com/dataease/sqlbot-base:latest +FROM registry.cn-qingdao.aliyuncs.com/dataease/sqlbot-python-pg:latest + +RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + echo "Asia/Shanghai" > /etc/timezone # Set runtime environment variables ENV PYTHONUNBUFFERED=1 @@ -53,20 +57,25 @@ ENV SQLBOT_HOME=/opt/sqlbot ENV PYTHONPATH=${SQLBOT_HOME}/app ENV PATH="${SQLBOT_HOME}/app/.venv/bin:$PATH" +ENV POSTGRES_DB=sqlbot +ENV POSTGRES_USER=root +ENV POSTGRES_PASSWORD=Password123@pg + # Copy necessary files from builder COPY start.sh /opt/sqlbot/app/start.sh COPY g2-ssr/*.ttf /usr/share/fonts/truetype/liberation/ COPY --from=sqlbot-builder ${SQLBOT_HOME} ${SQLBOT_HOME} COPY --from=ssr-builder /app /opt/sqlbot/g2-ssr +COPY --from=vector-model /opt/maxkb/app/model /opt/sqlbot/models WORKDIR ${SQLBOT_HOME}/app RUN mkdir -p /opt/sqlbot/images /opt/sqlbot/g2-ssr -EXPOSE 3000 8000 +EXPOSE 3000 8000 8001 5432 # Add health check HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ CMD curl -f http://localhost:8000 || exit 1 -ENTRYPOINT ["sh", "start.sh"] \ No newline at end of file +ENTRYPOINT ["sh", "start.sh"] diff --git a/Dockerfile-base b/Dockerfile-base new file mode 100644 index 00000000..89d2b711 --- /dev/null +++ b/Dockerfile-base @@ -0,0 +1,30 @@ +FROM python:3.11-slim-bookworm AS python-builder +FROM registry.cn-qingdao.aliyuncs.com/dataease/postgres:17.6 + +# python environment +COPY --from=python-builder /usr/local /usr/local + +RUN python --version && pip --version + +# Install uv tool +COPY --from=ghcr.io/astral-sh/uv:0.7.8 /uv /uvx /bin/ + +RUN apt-get update && apt-get install -y --no-install-recommends \ + wait-for-it \ + build-essential \ + curl \ + gnupg \ + gcc \ + g++ \ + libcairo2-dev \ + libpango1.0-dev \ + libjpeg-dev \ + libgif-dev \ + librsvg2-dev \ + && curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \ + && curl -qL https://www.npmjs.com/install.sh | sh - \ + && apt-get install -y nodejs \ + && curl -L --connect-timeout 60 -m 1800 https://fit2cloud-support.oss-cn-beijing.aliyuncs.com/xpack-license/get-validator-linux | sh \ + && rm -rf /var/lib/apt/lists/* \ + && chmod g+xr /usr/bin/ld.so \ + && chmod g+x /usr/local/bin/python* \ No newline at end of file diff --git a/README.md b/README.md index eadda354..d3622ab9 100644 --- a/README.md +++ b/README.md @@ -14,26 +14,34 @@ SQLBot 是一款基于大模型和 RAG 的智能问数系统。SQLBot 的优势 - **易于集成**: 支持快速嵌入到第三方业务系统,也支持被 n8n、MaxKB、Dify、Coze 等 AI 应用开发平台集成调用,让各类应用快速拥有智能问数能力; - **安全可控**: 提供基于工作空间的资源隔离机制,能够实现细粒度的数据权限控制。 +## 工作原理 + +system-arch + ## 快速开始 ### 安装部署 -准备一台 Linux 服务器,执行以下一键安装脚本。 -在运行 SQLBot 前,请确保已安装好 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/)。 +准备一台 Linux 服务器,安装好 [Docker](https://docs.docker.com/get-docker/),执行以下一键安装脚本: ```bash -# 创建目录 -mkdir -p /opt/sqlbot -cd /opt/sqlbot - -# 下载 docker-compose.yaml -curl -o docker-compose.yaml https://raw.githubusercontent.com/dataease/SQLBot/main/docker-compose.yaml - -# 启动服务 -docker compose up -d +docker run -d \ + --name sqlbot \ + --restart unless-stopped \ + -p 8000:8000 \ + -p 8001:8001 \ + -v ./data/sqlbot/excel:/opt/sqlbot/data/excel \ + -v ./data/sqlbot/file:/opt/sqlbot/data/file \ + -v ./data/sqlbot/images:/opt/sqlbot/images \ + -v ./data/sqlbot/logs:/opt/sqlbot/app/logs \ + -v ./data/postgresql:/var/lib/postgresql/data \ + --privileged=true \ + dataease/sqlbot ``` -你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 SQLBot; +你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 SQLBot。 + +如果是内网环境,你可以通过 [离线安装包方式](https://community.fit2cloud.com/#/products/sqlbot/downloads) 部署 SQLBot。 ### 访问方式 @@ -45,8 +53,7 @@ docker compose up -d 如你有更多问题,可以加入我们的技术交流群与我们交流。 -contact_me_qr - +contact_me_qr ## UI 展示 @@ -64,9 +71,17 @@ docker compose up -d - [1Panel](https://github.com/1panel-dev/1panel/) - 现代化、开源的 Linux 服务器运维管理面板 - [MaxKB](https://github.com/1panel-dev/MaxKB/) - 强大易用的企业级智能体平台 - [JumpServer](https://github.com/jumpserver/jumpserver/) - 广受欢迎的开源堡垒机 +- [Cordys CRM](https://github.com/1Panel-dev/CordysCRM) - 新一代的开源 AI CRM 系统 - [Halo](https://github.com/halo-dev/halo/) - 强大易用的开源建站工具 - [MeterSphere](https://github.com/metersphere/metersphere/) - 新一代的开源持续测试工具 ## License 本仓库遵循 [FIT2CLOUD Open Source License](LICENSE) 开源协议,该许可证本质上是 GPLv3,但有一些额外的限制。 + +你可以基于 SQLBot 的源代码进行二次开发,但是需要遵守以下规定: + +- 不能替换和修改 SQLBot 的 Logo 和版权信息; +- 二次开发后的衍生作品必须遵守 GPL V3 的开源义务。 + +如需商业授权,请联系 support@fit2cloud.com 。 diff --git a/backend/alembic/env.py b/backend/alembic/env.py index db6e9413..16ef1c3e 100755 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -24,9 +24,13 @@ # from apps.system.models.user import SQLModel # noqa # from apps.settings.models.setting_models import SQLModel -from apps.chat.models.chat_model import SQLModel +# from apps.chat.models.chat_model import SQLModel +from apps.terminology.models.terminology_model import SQLModel +#from apps.custom_prompt.models.custom_prompt_model import SQLModel +# from apps.data_training.models.data_training_model import SQLModel # from apps.dashboard.models.dashboard_model import SQLModel from common.core.config import settings # noqa +#from apps.datasource.models.datasource import SQLModel target_metadata = SQLModel.metadata diff --git a/backend/alembic/versions/037_create_chat_log.py b/backend/alembic/versions/037_create_chat_log.py new file mode 100644 index 00000000..59a7f425 --- /dev/null +++ b/backend/alembic/versions/037_create_chat_log.py @@ -0,0 +1,131 @@ +"""035_create_chat_log + +Revision ID: 68a06302cf70 +Revises: 29559ee607af +Create Date: 2025-08-18 16:02:43.353110 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '68a06302cf70' +down_revision = '646e7ca28e0e' +branch_labels = None +depends_on = None + + +sql=''' +CREATE OR REPLACE FUNCTION safe_jsonb_cast(text) RETURNS jsonb AS +$$ +BEGIN + RETURN $1::jsonb; +EXCEPTION + WHEN others THEN + RETURN to_json($1::text)::jsonb; +END; +$$ LANGUAGE plpgsql; + +INSERT INTO chat_log(type, operate, pid, ai_modal_id, messages, start_time, finish_time, token_usage, reasoning_content) +SELECT '0', + '0', + id, + ai_modal_id, + safe_jsonb_cast(full_sql_message), + create_time, + finish_time, + safe_jsonb_cast(token_sql), + safe_jsonb_cast(sql_answer)->>'reasoning_content' +FROM chat_record +WHERE full_sql_message IS NOT NULL; +INSERT INTO chat_log(type, operate, pid, ai_modal_id, messages, start_time, finish_time, token_usage, reasoning_content) +SELECT '0', + '1', + id, + ai_modal_id, + safe_jsonb_cast(full_chart_message), + create_time, + finish_time, + safe_jsonb_cast(token_chart), + safe_jsonb_cast(chart_answer)->>'reasoning_content' +FROM chat_record +WHERE full_chart_message IS NOT NULL; +INSERT INTO chat_log(type, operate, pid, ai_modal_id, messages, start_time, finish_time, token_usage, reasoning_content) +SELECT '0', + '2', + id, + ai_modal_id, + safe_jsonb_cast(full_analysis_message), + create_time, + finish_time, + safe_jsonb_cast(token_analysis), + safe_jsonb_cast(analysis)->>'reasoning_content' +FROM chat_record +WHERE full_analysis_message IS NOT NULL; +INSERT INTO chat_log(type, operate, pid, ai_modal_id, messages, start_time, finish_time, token_usage, reasoning_content) +SELECT '0', + '3', + id, + ai_modal_id, + safe_jsonb_cast(full_predict_message), + create_time, + finish_time, + safe_jsonb_cast(token_predict), + safe_jsonb_cast(predict)->>'reasoning_content' +FROM chat_record +WHERE full_predict_message IS NOT NULL; +INSERT INTO chat_log(type, operate, pid, ai_modal_id, messages, start_time, finish_time, token_usage, reasoning_content) +SELECT '0', + '4', + id, + ai_modal_id, + safe_jsonb_cast(full_recommended_question_message), + create_time, + finish_time, + safe_jsonb_cast(token_recommended_question), + safe_jsonb_cast(recommended_question_answer)->>'reasoning_content' +FROM chat_record +WHERE full_recommended_question_message IS NOT NULL; +INSERT INTO chat_log(type, operate, pid, ai_modal_id, messages, start_time, finish_time, token_usage, reasoning_content) +SELECT '0', + '6', + id, + ai_modal_id, + safe_jsonb_cast(full_select_datasource_message), + create_time, + finish_time, + safe_jsonb_cast(token_select_datasource_question), + safe_jsonb_cast(datasource_select_answer)->>'reasoning_content' +FROM chat_record +WHERE full_select_datasource_message IS NOT NULL; + +''' + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('chat_log', + sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False), + sa.Column('type', sa.Enum('0', name='typeenum', native_enum=False, length=3), nullable=True), + sa.Column('operate', sa.Enum('0', '1', '2', '3', '4', '5', '6', name='operationenum', native_enum=False, length=3), nullable=True), + sa.Column('pid', sa.BigInteger(), nullable=True), + sa.Column('ai_modal_id', sa.BigInteger(), nullable=True), + sa.Column('base_modal', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('messages', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('reasoning_content', sa.Text(), nullable=True), + sa.Column('start_time', sa.DateTime(), nullable=True), + sa.Column('finish_time', sa.DateTime(), nullable=True), + sa.Column('token_usage', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + op.execute(sql) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_table('chat_log') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/038_remove_chat_record_cloumns.py b/backend/alembic/versions/038_remove_chat_record_cloumns.py new file mode 100644 index 00000000..870c5477 --- /dev/null +++ b/backend/alembic/versions/038_remove_chat_record_cloumns.py @@ -0,0 +1,51 @@ +"""038_remove_chat_record_cloumns + +Revision ID: fc23c4f3e755 +Revises: 68a06302cf70 +Create Date: 2025-08-21 14:34:59.149410 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'fc23c4f3e755' +down_revision = '68a06302cf70' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('chat_record', 'token_predict') + op.drop_column('chat_record', 'token_select_datasource_question') + op.drop_column('chat_record', 'token_sql') + op.drop_column('chat_record', 'full_analysis_message') + op.drop_column('chat_record', 'full_recommended_question_message') + op.drop_column('chat_record', 'token_chart') + op.drop_column('chat_record', 'full_predict_message') + op.drop_column('chat_record', 'full_chart_message') + op.drop_column('chat_record', 'full_sql_message') + op.drop_column('chat_record', 'full_select_datasource_message') + op.drop_column('chat_record', 'token_recommended_question') + op.drop_column('chat_record', 'token_analysis') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('chat_record', sa.Column('token_analysis', sa.VARCHAR(length=256), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('token_recommended_question', sa.VARCHAR(length=256), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('full_select_datasource_message', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('full_sql_message', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('full_chart_message', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('full_predict_message', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('token_chart', sa.VARCHAR(length=256), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('full_recommended_question_message', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('full_analysis_message', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('token_sql', sa.VARCHAR(length=256), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('token_select_datasource_question', sa.VARCHAR(length=256), autoincrement=False, nullable=True)) + op.add_column('chat_record', sa.Column('token_predict', sa.VARCHAR(length=256), autoincrement=False, nullable=True)) + # ### end Alembic commands ### diff --git a/backend/alembic/versions/039_create_terminology.py b/backend/alembic/versions/039_create_terminology.py new file mode 100644 index 00000000..42ab588b --- /dev/null +++ b/backend/alembic/versions/039_create_terminology.py @@ -0,0 +1,41 @@ +"""039_create_terminology + +Revision ID: 25cbc85766fd +Revises: fc23c4f3e755 +Create Date: 2025-08-25 11:38:32.990973 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +import pgvector +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '25cbc85766fd' +down_revision = 'fc23c4f3e755' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.execute("CREATE EXTENSION IF NOT EXISTS vector;") + + op.create_table('terminology', + sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False), + sa.Column('pid', sa.BigInteger(), nullable=True), + sa.Column('create_time', sa.DateTime(), nullable=True), + sa.Column('word', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('embedding', pgvector.sqlalchemy.vector.VECTOR(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('terminology') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/040_modify_ai_model.py b/backend/alembic/versions/040_modify_ai_model.py new file mode 100644 index 00000000..dc924881 --- /dev/null +++ b/backend/alembic/versions/040_modify_ai_model.py @@ -0,0 +1,51 @@ +"""040_modify_ai_model + +Revision ID: 0fc14c2cfe41 +Revises: 25cbc85766fd +Create Date: 2025-08-26 23:30:50.192799 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '0fc14c2cfe41' +down_revision = '25cbc85766fd' +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column( + 'ai_model', + 'api_key', + type_=sa.Text(), + existing_type=sa.String(length=255), + existing_nullable=True + ) + op.alter_column( + 'ai_model', + 'api_domain', + type_=sa.Text(), + existing_type=sa.String(length=255), + existing_nullable=False + ) + + +def downgrade(): + op.alter_column( + 'ai_model', + 'api_key', + type_=sa.String(), + existing_type=sa.Text(), + existing_nullable=True + ) + op.alter_column( + 'ai_model', + 'api_domain', + type_=sa.String(), + existing_type=sa.Text(), + existing_nullable=False + ) diff --git a/backend/alembic/versions/041_add_terminology_oid.py b/backend/alembic/versions/041_add_terminology_oid.py new file mode 100644 index 00000000..1b4d783c --- /dev/null +++ b/backend/alembic/versions/041_add_terminology_oid.py @@ -0,0 +1,33 @@ +"""041_add_terminology_oid + +Revision ID: c4c3c36b720d +Revises: 0fc14c2cfe41 +Create Date: 2025-08-28 16:41:33.977242 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c4c3c36b720d' +down_revision = '0fc14c2cfe41' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + op.add_column('terminology', sa.Column('oid', sa.BigInteger(), nullable=True)) + + op.execute('update terminology set oid=1 where oid is null') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('terminology', 'oid') + + # ### end Alembic commands ### diff --git a/backend/alembic/versions/042_data_training.py b/backend/alembic/versions/042_data_training.py new file mode 100644 index 00000000..ec44c89f --- /dev/null +++ b/backend/alembic/versions/042_data_training.py @@ -0,0 +1,41 @@ +"""042_data_training + +Revision ID: a487d9c69341 +Revises: c4c3c36b720d +Create Date: 2025-09-15 15:41:43.332771 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql +import pgvector + +# revision identifiers, used by Alembic. +revision = 'a487d9c69341' +down_revision = 'c4c3c36b720d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('data_training', + sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False), + sa.Column('oid', sa.BigInteger(), nullable=True), + sa.Column('datasource', sa.BigInteger(), nullable=True), + sa.Column('create_time', sa.DateTime(), nullable=True), + sa.Column('question', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('embedding', pgvector.sqlalchemy.vector.VECTOR(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_table('data_training') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/043_modify_ds_id_type.py b/backend/alembic/versions/043_modify_ds_id_type.py new file mode 100644 index 00000000..4ef303a0 --- /dev/null +++ b/backend/alembic/versions/043_modify_ds_id_type.py @@ -0,0 +1,55 @@ +"""043_modify_ds_id_type + +Revision ID: dac062c1f7b1 +Revises: a487d9c69341 +Create Date: 2025-09-22 17:20:44.465735 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'dac062c1f7b1' +down_revision = 'a487d9c69341' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('core_datasource', 'id', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_field', 'id', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_table', 'id', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + autoincrement=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('core_table', 'id', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_field', 'id', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_datasource', 'id', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + autoincrement=True) + # ### end Alembic commands ### diff --git a/backend/alembic/versions/044_table_relation.py b/backend/alembic/versions/044_table_relation.py new file mode 100644 index 00000000..9f655ed4 --- /dev/null +++ b/backend/alembic/versions/044_table_relation.py @@ -0,0 +1,29 @@ +"""044_table_relation + +Revision ID: 455b8ce69e80 +Revises: dac062c1f7b1 +Create Date: 2025-09-24 13:34:08.205659 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '455b8ce69e80' +down_revision = 'dac062c1f7b1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('core_datasource', sa.Column('table_relation', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('core_datasource', 'table_relation') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/045_modify_terminolog.py b/backend/alembic/versions/045_modify_terminolog.py new file mode 100644 index 00000000..a452bb6c --- /dev/null +++ b/backend/alembic/versions/045_modify_terminolog.py @@ -0,0 +1,31 @@ +"""045_modify_terminolog + +Revision ID: 45e7e52bf2b8 +Revises: 455b8ce69e80 +Create Date: 2025-09-25 14:49:24.521795 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '45e7e52bf2b8' +down_revision = '455b8ce69e80' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('terminology', sa.Column('specific_ds', sa.Boolean(), nullable=True)) + op.add_column('terminology', sa.Column('datasource_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('terminology', 'datasource_ids') + op.drop_column('terminology', 'specific_ds') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/046_add_custom_prompt.py b/backend/alembic/versions/046_add_custom_prompt.py new file mode 100644 index 00000000..692c0f3f --- /dev/null +++ b/backend/alembic/versions/046_add_custom_prompt.py @@ -0,0 +1,39 @@ +"""046_add_custom_prompt + +Revision ID: 8855aea2dd61 +Revises: 45e7e52bf2b8 +Create Date: 2025-09-28 13:57:01.509249 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8855aea2dd61' +down_revision = '45e7e52bf2b8' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('custom_prompt', + sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False), + sa.Column('oid', sa.BigInteger(), nullable=True), + sa.Column('type', sa.Enum('GENERATE_SQL', 'ANALYSIS', 'PREDICT_DATA', name='customprompttypeenum', native_enum=False, length=20), nullable=True), + sa.Column('create_time', sa.DateTime(), nullable=True), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('prompt', sa.Text(), nullable=True), + sa.Column('specific_ds', sa.Boolean(), nullable=True), + sa.Column('datasource_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('custom_prompt') + # ### end Alembic commands ### diff --git a/backend/apps/ai_model/embedding.py b/backend/apps/ai_model/embedding.py new file mode 100644 index 00000000..315c8461 --- /dev/null +++ b/backend/apps/ai_model/embedding.py @@ -0,0 +1,63 @@ +import os.path +import threading +from typing import Optional + +from langchain_core.embeddings import Embeddings +from langchain_huggingface import HuggingFaceEmbeddings +from pydantic import BaseModel + +from common.core.config import settings + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class EmbeddingModelInfo(BaseModel): + folder: str + name: str + device: str = 'cpu' + + +local_embedding_model = EmbeddingModelInfo(folder=settings.LOCAL_MODEL_PATH, + name=os.path.join(settings.LOCAL_MODEL_PATH, 'embedding', + "shibing624_text2vec-base-chinese")) + +_lock = threading.Lock() +locks = {} + +_embedding_model: dict[str, Optional[Embeddings]] = {} + + +class EmbeddingModelCache: + + @staticmethod + def _new_instance(config: EmbeddingModelInfo = local_embedding_model): + return HuggingFaceEmbeddings(model_name=config.name, cache_folder=config.folder, + model_kwargs={'device': config.device}, + encode_kwargs={'normalize_embeddings': True} + ) + + @staticmethod + def _get_lock(key: str = settings.DEFAULT_EMBEDDING_MODEL): + lock = locks.get(key) + if lock is None: + with _lock: + lock = locks.get(key) + if lock is None: + lock = threading.Lock() + locks[key] = lock + + return lock + + @staticmethod + def get_model(key: str = settings.DEFAULT_EMBEDDING_MODEL, + config: EmbeddingModelInfo = local_embedding_model) -> Embeddings: + model_instance = _embedding_model.get(key) + if model_instance is None: + lock = EmbeddingModelCache._get_lock(key) + with lock: + model_instance = _embedding_model.get(key) + if model_instance is None: + model_instance = EmbeddingModelCache._new_instance(config) + _embedding_model[key] = model_instance + + return model_instance diff --git a/backend/apps/ai_model/model_factory.py b/backend/apps/ai_model/model_factory.py index 243c91e0..03479fd8 100644 --- a/backend/apps/ai_model/model_factory.py +++ b/backend/apps/ai_model/model_factory.py @@ -10,9 +10,10 @@ from apps.ai_model.openai.llm import BaseChatOpenAI from apps.system.models.system_model import AiModelDetail from common.core.db import engine +from common.utils.crypto import sqlbot_decrypt from common.utils.utils import prepare_model_arg from langchain_community.llms import VLLMOpenAI - +from langchain_openai import AzureChatOpenAI # from langchain_community.llms import Tongyi, VLLM class LLMConfig(BaseModel): @@ -69,6 +70,24 @@ def _init_llm(self) -> VLLMOpenAI: streaming=True, **self.config.additional_params, ) + +class OpenAIAzureLLM(BaseLLM): + def _init_llm(self) -> AzureChatOpenAI: + api_version = self.config.additional_params.get("api_version") + deployment_name = self.config.additional_params.get("deployment_name") + if api_version: + self.config.additional_params.pop("api_version") + if deployment_name: + self.config.additional_params.pop("deployment_name") + return AzureChatOpenAI( + azure_endpoint=self.config.api_base_url, + api_key=self.config.api_key or 'Empty', + model_name=self.config.model_name, + api_version=api_version, + deployment_name=deployment_name, + streaming=True, + **self.config.additional_params, + ) class OpenAILLM(BaseLLM): def _init_llm(self) -> BaseChatModel: return BaseChatOpenAI( @@ -89,7 +108,8 @@ class LLMFactory: _llm_types: Dict[str, Type[BaseLLM]] = { "openai": OpenAILLM, "tongyi": OpenAILLM, - "vllm": OpenAIvLLM + "vllm": OpenAIvLLM, + "azure": OpenAIAzureLLM, } @classmethod @@ -118,7 +138,7 @@ def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]): return config """ -def get_default_config() -> LLMConfig: +async def get_default_config() -> LLMConfig: with Session(engine) as session: db_model = session.exec( select(AiModelDetail).where(AiModelDetail.default_model == True) @@ -133,6 +153,11 @@ def get_default_config() -> LLMConfig: additional_params = {item["key"]: prepare_model_arg(item.get('val')) for item in config_raw if "key" in item and "val" in item} except Exception: pass + if not db_model.api_domain.startswith("http"): + db_model.api_domain = await sqlbot_decrypt(db_model.api_domain) + if db_model.api_key: + db_model.api_key = await sqlbot_decrypt(db_model.api_key) + # 构造 LLMConfig return LLMConfig( diff --git a/backend/apps/api.py b/backend/apps/api.py index 4f573134..8b836c0d 100644 --- a/backend/apps/api.py +++ b/backend/apps/api.py @@ -2,10 +2,11 @@ from apps.chat.api import chat from apps.dashboard.api import dashboard_api -from apps.datasource.api import datasource -from apps.settings.api import terminology -from apps.system.api import login, user, aimodel, workspace, assistant +from apps.data_training.api import data_training +from apps.datasource.api import datasource, table_relation from apps.mcp import mcp +from apps.system.api import login, user, aimodel, workspace, assistant +from apps.terminology.api import terminology api_router = APIRouter() api_router.include_router(login.router) @@ -14,8 +15,9 @@ api_router.include_router(assistant.router) api_router.include_router(aimodel.router) api_router.include_router(terminology.router) +api_router.include_router(data_training.router) api_router.include_router(datasource.router) api_router.include_router(chat.router) api_router.include_router(dashboard_api.router) api_router.include_router(mcp.router) - +api_router.include_router(table_relation.router) diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 3cc59380..a28c03d7 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -3,6 +3,7 @@ import traceback import numpy as np +import orjson import pandas as pd from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse @@ -12,7 +13,7 @@ delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData from apps.chat.task.llm import LLMService -from common.core.deps import CurrentAssistant, SessionDep, CurrentUser +from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans router = APIRouter(tags=["Data Q&A"], prefix="/chat") @@ -104,25 +105,27 @@ async def start_chat(session: SessionDep, current_user: CurrentUser): @router.post("/recommend_questions/{chat_record_id}") async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, current_assistant: CurrentAssistant): + def _return_empty(): + yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n' + try: record = get_chat_record_by_id(session, chat_record_id) if not record: - raise HTTPException( - status_code=400, - detail=f"Chat record with id {chat_record_id} not found" - ) + return StreamingResponse(_return_empty(), media_type="text/event-stream") + request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '') - llm_service = LLMService(current_user, request_question, current_assistant, True) + llm_service = await LLMService.create(current_user, request_question, current_assistant, True) llm_service.set_record(record) llm_service.run_recommend_questions_task_async() except Exception as e: traceback.print_exc() - raise HTTPException( - status_code=500, - detail=str(e) - ) + + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") @@ -142,15 +145,16 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que """ try: - llm_service = LLMService(current_user, request_question, current_assistant) + llm_service = await LLMService.create(current_user, request_question, current_assistant, embedding=True) llm_service.init_record() llm_service.run_task_async() except Exception as e: traceback.print_exc() - raise HTTPException( - status_code=500, - detail=str(e) - ) + + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") @@ -158,54 +162,55 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que @router.post("/record/{chat_record_id}/{action_type}") async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str, current_assistant: CurrentAssistant): - if action_type != 'analysis' and action_type != 'predict': - raise HTTPException( - status_code=404, - detail="Not Found" - ) - record: ChatRecord | None = None - - stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, ChatRecord.engine_type, - ChatRecord.ai_modal_id, ChatRecord.create_by, ChatRecord.chart, ChatRecord.data).where( - and_(ChatRecord.id == chat_record_id)) - result = session.execute(stmt) - for r in result: - record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource, - engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by, chart=r.chart, - data=r.data) - - if not record: - raise HTTPException( - status_code=400, - detail=f"Chat record with id {chat_record_id} not found" - ) + try: + if action_type != 'analysis' and action_type != 'predict': + raise Exception(f"Type {action_type} Not Found") + record: ChatRecord | None = None + + stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, + ChatRecord.engine_type, + ChatRecord.ai_modal_id, ChatRecord.create_by, ChatRecord.chart, ChatRecord.data).where( + and_(ChatRecord.id == chat_record_id)) + result = session.execute(stmt) + for r in result: + record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource, + engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by, + chart=r.chart, + data=r.data) - if not record.chart: - raise HTTPException( - status_code=500, - detail=f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it" - ) + if not record: + raise Exception(f"Chat record with id {chat_record_id} not found") - request_question = ChatQuestion(chat_id=record.chat_id, question='') + if not record.chart: + raise Exception( + f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it") - try: - llm_service = LLMService(current_user, request_question, current_assistant) + request_question = ChatQuestion(chat_id=record.chat_id, question=record.question) + + llm_service = await LLMService.create(current_user, request_question, current_assistant) llm_service.run_analysis_or_predict_task_async(action_type, record) except Exception as e: traceback.print_exc() - raise HTTPException( - status_code=500, - detail=str(e) - ) + + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") @router.post("/excel/export") -async def export_excel(excel_data: ExcelData): +async def export_excel(excel_data: ExcelData, trans: Trans): def inner(): _fields_list = [] data = [] + if not excel_data.data: + raise HTTPException( + status_code=500, + detail=trans("i18n_excel_export.data_is_empty") + ) + for _data in excel_data.data: _row = [] for field in excel_data.axis: diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index b715b32e..01081553 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -4,8 +4,10 @@ import orjson import sqlparse from sqlalchemy import and_, select, update +from sqlalchemy.orm import aliased -from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion +from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \ + TypeEnum, OperationEnum, ChatRecordResult from apps.datasource.models.datasource import CoreDatasource from apps.system.crud.assistant import AssistantOutDsFactory from common.core.deps import CurrentAssistant, SessionDep, CurrentUser @@ -69,6 +71,21 @@ def get_chart_config(session: SessionDep, chart_record_id: int): return {} +def get_last_execute_sql_error(session: SessionDep, chart_id: int): + stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by( + ChatRecord.create_time.desc()).limit(1) + res = session.execute(stmt).scalar() + if res: + try: + obj = orjson.loads(res) + if obj.get('type') and obj.get('type') == 'exec-sql-err': + return obj.get('traceback') + except Exception: + pass + + return None + + def get_chat_chart_data(session: SessionDep, chart_record_id: int): stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chart_record_id)) res = session.execute(stmt) @@ -96,6 +113,9 @@ def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_ return get_chat_with_records(session, chart_id, current_user, current_assistant, True) +dynamic_ds_types = [1, 3] + + def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser, current_assistant: CurrentAssistant, with_data: bool = False) -> ChatInfo: chat = session.get(Chat, chart_id) @@ -104,7 +124,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr chat_info = ChatInfo(**chat.model_dump()) - if current_assistant and current_assistant.type == 1: + if current_assistant and current_assistant.type in dynamic_ds_types: out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant) ds = out_ds_instance.get_ds(chat.datasource) else: @@ -118,13 +138,36 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr chat_info.datasource_name = ds.name chat_info.ds_type = ds.type - stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time, - ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, - ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict, - ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id, - ChatRecord.recommended_question, ChatRecord.first_chat, - ChatRecord.finish, ChatRecord.error).where( - and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(ChatRecord.create_time) + sql_alias_log = aliased(ChatLog) + chart_alias_log = aliased(ChatLog) + analysis_alias_log = aliased(ChatLog) + predict_alias_log = aliased(ChatLog) + + stmt = (select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time, + ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, + ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict, + ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id, + ChatRecord.recommended_question, ChatRecord.first_chat, + ChatRecord.finish, ChatRecord.error, + sql_alias_log.reasoning_content.label('sql_reasoning_content'), + chart_alias_log.reasoning_content.label('chart_reasoning_content'), + analysis_alias_log.reasoning_content.label('analysis_reasoning_content'), + predict_alias_log.reasoning_content.label('predict_reasoning_content') + ) + .outerjoin(sql_alias_log, and_(sql_alias_log.pid == ChatRecord.id, + sql_alias_log.type == TypeEnum.CHAT, + sql_alias_log.operate == OperationEnum.GENERATE_SQL)) + .outerjoin(chart_alias_log, and_(chart_alias_log.pid == ChatRecord.id, + chart_alias_log.type == TypeEnum.CHAT, + chart_alias_log.operate == OperationEnum.GENERATE_CHART)) + .outerjoin(analysis_alias_log, and_(analysis_alias_log.pid == ChatRecord.id, + analysis_alias_log.type == TypeEnum.CHAT, + analysis_alias_log.operate == OperationEnum.ANALYSIS)) + .outerjoin(predict_alias_log, and_(predict_alias_log.pid == ChatRecord.id, + predict_alias_log.type == TypeEnum.CHAT, + predict_alias_log.operate == OperationEnum.PREDICT_DATA)) + .where(and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by( + ChatRecord.create_time)) if with_data: stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time, ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, @@ -136,28 +179,35 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr ChatRecord.create_time) result = session.execute(stmt).all() - record_list: list[ChatRecord] = [] + record_list: list[ChatRecordResult] = [] for row in result: if not with_data: record_list.append( - ChatRecord(id=row.id, chat_id=row.chat_id, create_time=row.create_time, finish_time=row.finish_time, - question=row.question, sql_answer=row.sql_answer, sql=row.sql, - chart_answer=row.chart_answer, chart=row.chart, - analysis=row.analysis, predict=row.predict, - datasource_select_answer=row.datasource_select_answer, - analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id, - recommended_question=row.recommended_question, first_chat=row.first_chat, - finish=row.finish, error=row.error)) + ChatRecordResult(id=row.id, chat_id=row.chat_id, create_time=row.create_time, + finish_time=row.finish_time, + question=row.question, sql_answer=row.sql_answer, sql=row.sql, + chart_answer=row.chart_answer, chart=row.chart, + analysis=row.analysis, predict=row.predict, + datasource_select_answer=row.datasource_select_answer, + analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id, + recommended_question=row.recommended_question, first_chat=row.first_chat, + finish=row.finish, error=row.error, + sql_reasoning_content=row.sql_reasoning_content, + chart_reasoning_content=row.chart_reasoning_content, + analysis_reasoning_content=row.analysis_reasoning_content, + predict_reasoning_content=row.predict_reasoning_content, + )) else: record_list.append( - ChatRecord(id=row.id, chat_id=row.chat_id, create_time=row.create_time, finish_time=row.finish_time, - question=row.question, sql_answer=row.sql_answer, sql=row.sql, - chart_answer=row.chart_answer, chart=row.chart, - analysis=row.analysis, predict=row.predict, - datasource_select_answer=row.datasource_select_answer, - analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id, - recommended_question=row.recommended_question, first_chat=row.first_chat, - finish=row.finish, error=row.error, data=row.data, predict_data=row.predict_data)) + ChatRecordResult(id=row.id, chat_id=row.chat_id, create_time=row.create_time, + finish_time=row.finish_time, + question=row.question, sql_answer=row.sql_answer, sql=row.sql, + chart_answer=row.chart_answer, chart=row.chart, + analysis=row.analysis, predict=row.predict, + datasource_select_answer=row.datasource_select_answer, + analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id, + recommended_question=row.recommended_question, first_chat=row.first_chat, + finish=row.finish, error=row.error, data=row.data, predict_data=row.predict_data)) result = list(map(format_record, record_list)) @@ -166,27 +216,35 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr return chat_info -def format_record(record: ChatRecord): +def format_record(record: ChatRecordResult): _dict = record.model_dump() if record.sql_answer and record.sql_answer.strip() != '' and record.sql_answer.strip()[0] == '{' and \ record.sql_answer.strip()[-1] == '}': _obj = orjson.loads(record.sql_answer) _dict['sql_answer'] = _obj.get('reasoning_content') + if record.sql_reasoning_content and record.sql_reasoning_content.strip() != '': + _dict['sql_answer'] = record.sql_reasoning_content if record.chart_answer and record.chart_answer.strip() != '' and record.chart_answer.strip()[0] == '{' and \ record.chart_answer.strip()[-1] == '}': _obj = orjson.loads(record.chart_answer) _dict['chart_answer'] = _obj.get('reasoning_content') + if record.chart_reasoning_content and record.chart_reasoning_content.strip() != '': + _dict['chart_answer'] = record.chart_reasoning_content if record.analysis and record.analysis.strip() != '' and record.analysis.strip()[0] == '{' and \ record.analysis.strip()[-1] == '}': _obj = orjson.loads(record.analysis) _dict['analysis_thinking'] = _obj.get('reasoning_content') _dict['analysis'] = _obj.get('content') + if record.analysis_reasoning_content and record.analysis_reasoning_content.strip() != '': + _dict['analysis_thinking'] = record.analysis_reasoning_content if record.predict and record.predict.strip() != '' and record.predict.strip()[0] == '{' and record.predict.strip()[ -1] == '}': _obj = orjson.loads(record.predict) _dict['predict'] = _obj.get('reasoning_content') _dict['predict_content'] = _obj.get('content') + if record.predict_reasoning_content and record.predict_reasoning_content.strip() != '': + _dict['predict'] = record.predict_reasoning_content if record.data and record.data.strip() != '': try: _obj = orjson.loads(record.data) @@ -208,19 +266,30 @@ def format_record(record: ChatRecord): return _dict -def list_base_records(session: SessionDep, chart_id: int, current_user: CurrentUser) -> List[ChatRecord]: - stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.full_sql_message, ChatRecord.full_chart_message, - ChatRecord.first_chat, ChatRecord.create_time).where( - and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id, - ChatRecord.analysis_record_id.is_(None), ChatRecord.predict_record_id.is_(None))).order_by( - ChatRecord.create_time) +def list_generate_sql_logs(session: SessionDep, chart_id: int) -> List[ChatLog]: + stmt = select(ChatLog).where( + and_(ChatLog.pid.in_(select(ChatRecord.id).where(and_(ChatRecord.chat_id == chart_id))), + ChatLog.type == TypeEnum.CHAT, ChatLog.operate == OperationEnum.GENERATE_SQL)).order_by( + ChatLog.start_time) result = session.execute(stmt).all() - record_list: List[ChatRecord] = [] - for r in result: - record_list.append( - ChatRecord(id=r.id, chat_id=r.chat_id, create_time=r.create_time, full_sql_message=r.full_sql_message, - full_chart_message=r.full_chart_message, first_chat=r.first_chat)) - return record_list + _list = [] + for row in result: + for r in row: + _list.append(ChatLog(**r.model_dump())) + return _list + + +def list_generate_chart_logs(session: SessionDep, chart_id: int) -> List[ChatLog]: + stmt = select(ChatLog).where( + and_(ChatLog.pid.in_(select(ChatRecord.id).where(and_(ChatRecord.chat_id == chart_id))), + ChatLog.type == TypeEnum.CHAT, ChatLog.operate == OperationEnum.GENERATE_CHART)).order_by( + ChatLog.start_time) + result = session.execute(stmt).all() + _list = [] + for row in result: + for r in row: + _list.append(ChatLog(**r.model_dump())) + return _list def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat, @@ -344,125 +413,117 @@ def save_analysis_predict_record(session: SessionDep, base_record: ChatRecord, a return result -def save_full_sql_message(session: SessionDep, record_id: int, full_message: str) -> ChatRecord: - return save_full_sql_message_and_answer(session=session, record_id=record_id, full_message=full_message, answer='') +def start_log(session: SessionDep, ai_modal_id: int, ai_modal_name: str, operate: OperationEnum, record_id: int, + full_message: list[dict]) -> ChatLog: + log = ChatLog(type=TypeEnum.CHAT, operate=operate, pid=record_id, ai_modal_id=ai_modal_id, base_modal=ai_modal_name, + messages=full_message, start_time=datetime.datetime.now()) + result = ChatLog(**log.model_dump()) -def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer: str, full_message: str, - token_usage: dict = None) -> ChatRecord: - if not record_id: - raise Exception("Record id cannot be None") - record = get_chat_record_by_id(session, record_id) - - record.full_sql_message = full_message - record.sql_answer = answer + session.add(log) + session.flush() + session.refresh(log) + result.id = log.id + session.commit() - if token_usage: - record.token_sql = orjson.dumps(token_usage).decode() + return result - result = ChatRecord(**record.model_dump()) - stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_sql_message=record.full_sql_message, - sql_answer=record.sql_answer, - token_sql=record.token_sql, +def end_log(session: SessionDep, log: ChatLog, full_message: list[dict], reasoning_content: str = None, + token_usage=None) -> ChatLog: + if token_usage is None: + token_usage = {} + log.messages = full_message + log.token_usage = token_usage + log.finish_time = datetime.datetime.now() + log.reasoning_content = reasoning_content if reasoning_content and len(reasoning_content.strip()) > 0 else None + + stmt = update(ChatLog).where(and_(ChatLog.id == log.id)).values( + messages=log.messages, + token_usage=log.token_usage, + finish_time=log.finish_time, + reasoning_content=log.reasoning_content ) - session.execute(stmt) - session.commit() - return result + return log -def save_full_analysis_message_and_answer(session: SessionDep, record_id: int, answer: str, - full_message: str, token_usage: dict = None) -> ChatRecord: +def save_sql_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = get_chat_record_by_id(session, record_id) - - record.full_analysis_message = full_message - record.analysis = answer - - if token_usage: - record.token_analysis = orjson.dumps(token_usage).decode() - result = ChatRecord(**record.model_dump()) - - stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_analysis_message=record.full_analysis_message, - analysis=record.analysis, - token_analysis=record.token_analysis, + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + sql_answer=answer, ) session.execute(stmt) session.commit() - return result + record = get_chat_record_by_id(session, record_id) + + return record -def save_full_predict_message_and_answer(session: SessionDep, record_id: int, answer: str, - full_message: str, data: str, token_usage: dict = None) -> ChatRecord: +def save_analysis_answer(session: SessionDep, record_id: int, answer: str = '') -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + analysis=answer, + ) + + session.execute(stmt) + + session.commit() + record = get_chat_record_by_id(session, record_id) - record.full_predict_message = full_message - record.predict = answer - record.predict_data = data + return record - if token_usage: - record.token_predict = orjson.dumps(token_usage).decode() - result = ChatRecord(**record.model_dump()) +def save_predict_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") - stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_predict_message=record.full_predict_message, - predict=record.predict, - predict_data=record.predict_data, - token_predict=record.token_predict + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + predict=answer, ) session.execute(stmt) session.commit() - return result + record = get_chat_record_by_id(session, record_id) + + return record -def save_full_select_datasource_message_and_answer(session: SessionDep, record_id: int, answer: str, - full_message: str, datasource: int = None, - engine_type: str = None, token_usage: dict = None) -> ChatRecord: +def save_select_datasource_answer(session: SessionDep, record_id: int, answer: str, + datasource: int = None, engine_type: str = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") record = get_chat_record_by_id(session, record_id) - record.full_select_datasource_message = full_message record.datasource_select_answer = answer if datasource: record.datasource = datasource record.engine_type = engine_type - if token_usage: - record.token_select_datasource_question = orjson.dumps(token_usage).decode() - result = ChatRecord(**record.model_dump()) if datasource: stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_select_datasource_message=record.full_select_datasource_message, datasource_select_answer=record.datasource_select_answer, - token_select_datasource_question=record.token_select_datasource_question, datasource=record.datasource, engine_type=record.engine_type, ) else: stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_select_datasource_message=record.full_select_datasource_message, datasource_select_answer=record.datasource_select_answer, - token_select_datasource_question=record.token_select_datasource_question, ) session.execute(stmt) @@ -472,16 +533,12 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i return result -def save_full_recommend_question_message_and_answer(session: SessionDep, record_id: int, answer: dict = None, - full_message: str = '[]', token_usage: dict = None) -> ChatRecord: +def save_recommend_question_answer(session: SessionDep, record_id: int, + answer: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = get_chat_record_by_id(session, record_id) - record.full_recommended_question_message = full_message - - if answer: - record.recommended_question_answer = orjson.dumps(answer).decode() + recommended_question_answer = orjson.dumps(answer).decode() json_str = '[]' if answer and answer.get('content') and answer.get('content') != '': @@ -492,25 +549,22 @@ def save_full_recommend_question_message_and_answer(session: SessionDep, record_ json_str = '[]' except Exception as e: pass - record.recommended_question = json_str - - if token_usage: - record.token_recommended_question = orjson.dumps(token_usage).decode() + recommended_question = json_str - result = ChatRecord(**record.model_dump()) - - stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_recommended_question_message=record.full_recommended_question_message, - recommended_question_answer=record.recommended_question_answer, - recommended_question=record.recommended_question, - token_recommended_question=record.token_recommended_question + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + recommended_question_answer=recommended_question_answer, + recommended_question=recommended_question, ) session.execute(stmt) session.commit() - return result + record = get_chat_record_by_id(session, record_id) + record.recommended_question_answer = recommended_question_answer + record.recommended_question = recommended_question + + return record def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord: @@ -534,36 +588,21 @@ def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord: return result -def save_full_chart_message(session: SessionDep, record_id: int, full_message: str) -> ChatRecord: - return save_full_chart_message_and_answer(session=session, record_id=record_id, full_message=full_message, - answer='') - - -def save_full_chart_message_and_answer(session: SessionDep, record_id: int, answer: str, - full_message: str, token_usage: dict = None) -> ChatRecord: +def save_chart_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = get_chat_record_by_id(session, record_id) - record.full_chart_message = full_message - record.chart_answer = answer - - if token_usage: - record.token_chart = orjson.dumps(token_usage).decode() - - result = ChatRecord(**record.model_dump()) - - stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( - full_chart_message=record.full_chart_message, - chart_answer=record.chart_answer, - token_chart=record.token_chart + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + chart_answer=answer, ) session.execute(stmt) session.commit() - return result + record = get_chat_record_by_id(session, record_id) + + return record def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord: @@ -677,7 +716,8 @@ def get_old_questions(session: SessionDep, datasource: int): if not datasource: return records stmt = select(ChatRecord.question).where( - and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None))).order_by( + and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None), + ChatRecord.error.is_(None))).order_by( ChatRecord.create_time.desc()).limit(20) result = session.execute(stmt) for r in result: diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 07dae5e2..d24e19e1 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -1,9 +1,12 @@ from datetime import datetime +from enum import Enum from typing import List, Optional from fastapi import Body from pydantic import BaseModel from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean +from sqlalchemy import Enum as SQLAlchemyEnum +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import SQLModel, Field from apps.template.filter.generator import get_permissions_template @@ -16,6 +19,53 @@ from apps.template.select_datasource.generator import get_datasource_template +def enum_values(enum_class: type[Enum]) -> list: + """Get values for enum.""" + return [status.value for status in enum_class] + + +class TypeEnum(Enum): + CHAT = "0" + + +# TODO other usage + +class OperationEnum(Enum): + GENERATE_SQL = '0' + GENERATE_CHART = '1' + ANALYSIS = '2' + PREDICT_DATA = '3' + GENERATE_RECOMMENDED_QUESTIONS = '4' + GENERATE_SQL_WITH_PERMISSIONS = '5' + CHOOSE_DATASOURCE = '6' + GENERATE_DYNAMIC_SQL = '7' + + +class ChatFinishStep(Enum): + GENERATE_SQL = 1 + QUERY_DATA = 2 + GENERATE_CHART = 3 + + +# TODO choose table / check connection / generate description + +class ChatLog(SQLModel, table=True): + __tablename__ = "chat_log" + id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) + type: TypeEnum = Field( + sa_column=Column(SQLAlchemyEnum(TypeEnum, native_enum=False, values_callable=enum_values, length=3))) + operate: OperationEnum = Field( + sa_column=Column(SQLAlchemyEnum(OperationEnum, native_enum=False, values_callable=enum_values, length=3))) + pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True)) + ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger)) + base_modal: Optional[str] = Field(max_length=255) + messages: Optional[list[dict]] = Field(sa_column=Column(JSONB)) + reasoning_content: Optional[str | None] = Field(sa_column=Column(Text, nullable=True)) + start_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) + finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) + token_usage: Optional[dict | None | int] = Field(sa_column=Column(JSONB)) + + class Chat(SQLModel, table=True): __tablename__ = "chat" id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) @@ -54,29 +104,45 @@ class ChatRecord(SQLModel, table=True): recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True)) recommended_question: str = Field(sa_column=Column(Text, nullable=True)) datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True)) - full_sql_message: str = Field(sa_column=Column(Text, nullable=True)) - token_sql: str = Field(max_length=256, nullable=True) - full_chart_message: str = Field(sa_column=Column(Text, nullable=True)) - token_chart: str = Field(max_length=256, nullable=True) - full_analysis_message: str = Field(sa_column=Column(Text, nullable=True)) - token_analysis: str = Field(max_length=256, nullable=True) - full_predict_message: str = Field(sa_column=Column(Text, nullable=True)) - token_predict: str = Field(max_length=256, nullable=True) - full_recommended_question_message: str = Field(sa_column=Column(Text, nullable=True)) - token_recommended_question: str = Field(max_length=256, nullable=True) - full_select_datasource_message: str = Field(sa_column=Column(Text, nullable=True)) - token_select_datasource_question: str = Field(max_length=256, nullable=True) finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False)) error: str = Field(sa_column=Column(Text, nullable=True)) analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True)) predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True)) +class ChatRecordResult(BaseModel): + id: Optional[int] = None + chat_id: Optional[int] = None + ai_modal_id: Optional[int] = None + first_chat: bool = False + create_time: Optional[datetime] = None + finish_time: Optional[datetime] = None + question: Optional[str] = None + sql_answer: Optional[str] = None + sql: Optional[str] = None + data: Optional[str] = None + chart_answer: Optional[str] = None + chart: Optional[str] = None + analysis: Optional[str] = None + predict: Optional[str] = None + predict_data: Optional[str] = None + recommended_question: Optional[str] = None + datasource_select_answer: Optional[str] = None + finish: Optional[bool] = None + error: Optional[str] = None + analysis_record_id: Optional[int] = None + predict_record_id: Optional[int] = None + sql_reasoning_content: Optional[str] = None + chart_reasoning_content: Optional[str] = None + analysis_reasoning_content: Optional[str] = None + predict_reasoning_content: Optional[str] = None + + class CreateChat(BaseModel): id: int = None question: str = None datasource: int = None - origin: Optional[int] = 0 + origin: Optional[int] = 0 # 0是页面上,mcp是1,小助手是2 class RenameChat(BaseModel): @@ -99,7 +165,9 @@ class ChatInfo(BaseModel): class AiModelQuestion(BaseModel): + question: str = None ai_modal_id: int = None + ai_modal_name: str = None # Specific model name engine: str = "" db_schema: str = "" sql: str = "" @@ -109,29 +177,36 @@ class AiModelQuestion(BaseModel): lang: str = "简体中文" filter: str = [] sub_query: Optional[list[dict]] = None + terminologies: str = "" + data_training: str = "" + custom_prompt: str = "" + error_msg: str = "" def sql_sys_question(self): return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, - lang=self.lang) + lang=self.lang, terminologies=self.terminologies, + data_training=self.data_training, custom_prompt=self.custom_prompt) - def sql_user_question(self): + def sql_user_question(self, current_time: str): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, - rule=self.rule) + rule=self.rule, current_time=current_time, error_msg=self.error_msg) def chart_sys_question(self): return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang) - def chart_user_question(self): - return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule) + def chart_user_question(self, chart_type: Optional[str] = None): + return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule, + chart_type=chart_type) def analysis_sys_question(self): - return get_analysis_template()['system'].format(lang=self.lang) + return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies, + custom_prompt=self.custom_prompt) def analysis_user_question(self): return get_analysis_template()['user'].format(fields=self.fields, data=self.data) def predict_sys_question(self): - return get_predict_template()['system'].format(lang=self.lang) + return get_predict_template()['system'].format(lang=self.lang, custom_prompt=self.custom_prompt) def predict_user_question(self): return get_predict_template()['user'].format(fields=self.fields, data=self.data) @@ -163,7 +238,6 @@ def dynamic_user_question(self): class ChatQuestion(AiModelQuestion): - question: str chat_id: int @@ -180,6 +254,7 @@ class McpQuestion(BaseModel): question: str = Body(description='用户提问') chat_id: int = Body(description='会话ID') token: str = Body(description='token') + stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) class AxisObj(BaseModel): @@ -192,3 +267,10 @@ class ExcelData(BaseModel): axis: list[AxisObj] = [] data: list[dict] = [] name: str = 'Excel' + + +class McpAssistant(BaseModel): + question: str = Body(description='用户提问') + url: str = Body(description='第三方数据接口') + authorization: str = Body(description='第三方接口凭证') + stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index b0b00bc7..7b46c83b 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -1,9 +1,12 @@ import concurrent import json +import os import traceback +import urllib.parse import warnings from concurrent.futures import ThreadPoolExecutor, Future -from typing import Any, List, Optional, Union, Dict +from datetime import datetime +from typing import Any, List, Optional, Union, Dict, Iterator import numpy as np import orjson @@ -13,35 +16,50 @@ from langchain.chat_models.base import BaseChatModel from langchain_community.utilities import SQLDatabase from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk -from sqlalchemy import select +from sqlalchemy import and_, select from sqlalchemy.orm import sessionmaker -from sqlmodel import create_engine, Session +from sqlmodel import Session from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config -from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \ - save_error_message, save_sql_exec_data, save_full_chart_message, save_full_chart_message_and_answer, save_chart, \ - finish_record, save_full_analysis_message_and_answer, save_full_predict_message_and_answer, save_predict_data, \ - save_full_select_datasource_message_and_answer, save_full_recommend_question_message_and_answer, \ - get_old_questions, save_analysis_predict_record, list_base_records, rename_chat, get_chart_config, \ - get_chat_chart_data -from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat +from apps.chat.curd.chat import save_question, save_sql_answer, save_sql, \ + save_error_message, save_sql_exec_data, save_chart_answer, save_chart, \ + finish_record, save_analysis_answer, save_predict_answer, save_predict_data, \ + save_select_datasource_answer, save_recommend_question_answer, \ + get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \ + get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \ + get_last_execute_sql_error +from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ + ChatFinishStep +from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil +from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts +from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum +from apps.data_training.curd.data_training import get_training_template from apps.datasource.crud.datasource import get_table_schema from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user +from apps.datasource.embedding.ds_embedding import get_ds_embedding from apps.datasource.models.datasource import CoreDatasource -from apps.db.db import exec_sql +from apps.db.db import exec_sql, get_version, check_connection from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds from apps.system.schemas.system_schema import AssistantOutDsSchema +from apps.terminology.curd.terminology import get_terminology_template from common.core.config import settings +from common.core.db import engine from common.core.deps import CurrentAssistant, CurrentUser -from common.error import SingleMessageError +from common.error import SingleMessageError, SQLBotDBError, ParseSQLResultError, SQLBotDBConnectionError from common.utils.utils import SQLBotLogUtil, extract_nested_json, prepare_for_orjson warnings.filterwarnings("ignore") -base_message_count_limit = 5 +base_message_count_limit = 6 executor = ThreadPoolExecutor(max_workers=200) +dynamic_ds_types = [1, 3] +dynamic_subsql_prefix = 'select * from sqlbot_dynamic_temp_table_' + +session_maker = sessionmaker(bind=engine) +db_session = session_maker() + class LLMService: ds: CoreDatasource @@ -51,22 +69,30 @@ class LLMService: llm: BaseChatModel sql_message: List[Union[BaseMessage, dict[str, Any]]] = [] chart_message: List[Union[BaseMessage, dict[str, Any]]] = [] - history_records: List[ChatRecord] = [] - session: Session + + session: Session = db_session current_user: CurrentUser current_assistant: Optional[CurrentAssistant] = None out_ds_instance: Optional[AssistantOutDs] = None change_title: bool = False + generate_sql_logs: List[ChatLog] = [] + generate_chart_logs: List[ChatLog] = [] + + current_logs: dict[OperationEnum, ChatLog] = {} + chunk_list: List[str] = [] future: Future + last_execute_sql_error: str = None + def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, - current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False): + current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False, + embedding: bool = False, config: LLMConfig = None): self.chunk_list = [] - engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) - session_maker = sessionmaker(bind=engine) - self.session = session_maker() + # engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) + # session_maker = sessionmaker(bind=engine) + # self.session = session_maker() self.session.exec = self.session.exec if hasattr(self.session, "exec") else self.session.execute self.current_user = current_user self.current_assistant = current_assistant @@ -79,32 +105,31 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, if chat.datasource: # Get available datasource # ds = self.session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first() - if current_assistant and current_assistant.type == 1: + if current_assistant and current_assistant.type in dynamic_ds_types: self.out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant) ds = self.out_ds_instance.get_ds(chat.datasource) if not ds: raise SingleMessageError("No available datasource configuration found") - chat_question.engine = ds.type + chat_question.engine = ds.type + get_version(ds) chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id) else: ds = self.session.get(CoreDatasource, chat.datasource) if not ds: raise SingleMessageError("No available datasource configuration found") - chat_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL' - chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds) + chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds) + chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds, + question=chat_question.question, embedding=embedding) - history_records: List[ChatRecord] = list( - map(lambda x: ChatRecord(**x.model_dump()), filter(lambda r: True if r.first_chat != True else False, - list_base_records(session=self.session, - current_user=current_user, - chart_id=chat_id)))) - self.change_title = len(history_records) == 0 + self.generate_sql_logs = list_generate_sql_logs(session=self.session, chart_id=chat_id) + self.generate_chart_logs = list_generate_chart_logs(session=self.session, chart_id=chat_id) + + self.change_title = len(self.generate_sql_logs) == 0 chat_question.lang = get_lang_name(current_user.language) self.ds = (ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None self.chat_question = chat_question - self.config = get_default_config() + self.config = config if no_reasoning: # only work while using qwen if self.config.additional_params: @@ -113,14 +138,26 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, del self.config.additional_params['extra_body']['enable_thinking'] self.chat_question.ai_modal_id = self.config.model_id + self.chat_question.ai_modal_name = self.config.model_name # Create LLM instance through factory llm_instance = LLMFactory.create_llm(self.config) self.llm = llm_instance.llm - self.history_records = history_records + # get last_execute_sql_error + last_execute_sql_error = get_last_execute_sql_error(self.session, self.chat_question.chat_id) + if last_execute_sql_error: + self.chat_question.error_msg = f''' +{last_execute_sql_error} +''' + else: + self.chat_question.error_msg = '' - self.init_messages() + @classmethod + async def create(cls, *args, **kwargs): + config: LLMConfig = await get_default_config() + instance = cls(*args, **kwargs, config=config) + return instance def is_running(self, timeout=0.5): try: @@ -133,21 +170,8 @@ def is_running(self, timeout=0.5): return True def init_messages(self): - # self.agent_executor = create_react_agent(self.llm) - last_sql_messages = list( - filter(lambda r: True if r.full_sql_message is not None and r.full_sql_message.strip() != '' else False, - self.history_records)) - last_sql_message_str = "[]" if last_sql_messages is None or len(last_sql_messages) == 0 else last_sql_messages[ - -1].full_sql_message - - last_chart_messages = list( - filter( - lambda r: True if r.full_chart_message is not None and r.full_chart_message.strip() != '' else False, - self.history_records)) - last_chart_message_str = "[]" if last_chart_messages is None or len(last_chart_messages) == 0 else \ - last_chart_messages[-1].full_chart_message - - last_sql_messages: List[dict[str, Any]] = orjson.loads(last_sql_message_str) + last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len( + self.generate_sql_logs) > 0 else [] # todo maybe can configure count_limit = 0 - base_message_count_limit @@ -166,7 +190,8 @@ def init_messages(self): _msg = AIMessage(content=last_sql_message['content']) self.sql_message.append(_msg) - last_chart_messages: List[dict[str, Any]] = orjson.loads(last_chart_message_str) + last_chart_messages: List[dict[str, Any]] = self.generate_chart_logs[-1].messages if len( + self.generate_chart_logs) > 0 else [] self.chart_message = [] # add sys prompt @@ -176,11 +201,11 @@ def init_messages(self): # limit count for last_chart_message in last_chart_messages: _msg: BaseMessage - if last_chart_message['type'] == 'human': - _msg = HumanMessage(content=last_chart_message['content']) + if last_chart_message.get('type') == 'human': + _msg = HumanMessage(content=last_chart_message.get('content')) self.chart_message.append(_msg) - elif last_chart_message['type'] == 'ai': - _msg = AIMessage(content=last_chart_message['content']) + elif last_chart_message.get('type') == 'ai': + _msg = AIMessage(content=last_chart_message.get('content')) self.chart_message.append(_msg) def init_record(self) -> ChatRecord: @@ -218,87 +243,100 @@ def generate_analysis(self): data = get_chat_chart_data(self.session, self.record.id) self.chat_question.data = orjson.dumps(data.get('data')).decode() analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = [] + + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None + self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, + self.current_user.oid, ds_id) + if SQLBotLicenseUtil.valid(): + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS, + self.current_user.oid, ds_id) + analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) - self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, answer='', - full_message=orjson.dumps([{'type': msg.type, - 'content': msg.content} for msg - in - analysis_msg]).decode()) + self.current_logs[OperationEnum.ANALYSIS] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.ANALYSIS, + record_id=self.record.id, + full_message=[ + {'type': msg.type, + 'content': msg.content} for + msg + in analysis_msg]) full_thinking_text = '' full_analysis_text = '' - res = self.llm.stream(analysis_msg) token_usage = {} + res = process_stream(self.llm.stream(analysis_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_analysis_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_analysis_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk analysis_msg.append(AIMessage(full_analysis_text)) - self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, - token_usage=token_usage, - answer=orjson.dumps({'content': full_analysis_text, - 'reasoning_content': full_thinking_text}).decode(), - full_message=orjson.dumps([{'type': msg.type, - 'content': msg.content} for msg - in - analysis_msg]).decode()) + + self.current_logs[OperationEnum.ANALYSIS] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.ANALYSIS], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in analysis_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) + self.record = save_analysis_answer(session=self.session, record_id=self.record.id, + answer=orjson.dumps({'content': full_analysis_text}).decode()) def generate_predict(self): fields = self.get_fields_from_chart() self.chat_question.fields = orjson.dumps(fields).decode() data = get_chat_chart_data(self.session, self.record.id) self.chat_question.data = orjson.dumps(data.get('data')).decode() + + if SQLBotLicenseUtil.valid(): + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA, + self.current_user.oid, ds_id) + predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question())) - self.record = save_full_predict_message_and_answer(session=self.session, record_id=self.record.id, answer='', - data='', - full_message=orjson.dumps([{'type': msg.type, - 'content': msg.content} for msg - in - predict_msg]).decode()) + self.current_logs[OperationEnum.PREDICT_DATA] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.PREDICT_DATA, + record_id=self.record.id, + full_message=[ + {'type': msg.type, + 'content': msg.content} for + msg + in predict_msg]) full_thinking_text = '' full_predict_text = '' - res = self.llm.stream(predict_msg) token_usage = {} + res = process_stream(self.llm.stream(predict_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_predict_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_predict_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk predict_msg.append(AIMessage(full_predict_text)) - self.record = save_full_predict_message_and_answer(session=self.session, record_id=self.record.id, - token_usage=token_usage, - answer=orjson.dumps({'content': full_predict_text, - 'reasoning_content': full_thinking_text}).decode(), - data='', - full_message=orjson.dumps([{'type': msg.type, - 'content': msg.content} for msg - in - predict_msg]).decode()) + self.record = save_predict_answer(session=self.session, record_id=self.record.id, + answer=orjson.dumps({'content': full_predict_text}).decode()) + self.current_logs[OperationEnum.PREDICT_DATA] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.PREDICT_DATA], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in predict_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) def generate_recommend_questions_task(self): @@ -306,7 +344,9 @@ def generate_recommend_questions_task(self): if self.ds and not self.chat_question.db_schema: self.chat_question.db_schema = self.out_ds_instance.get_db_schema( self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, - current_user=self.current_user, ds=self.ds) + current_user=self.current_user, ds=self.ds, + question=self.chat_question.question, + embedding=False) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) @@ -315,52 +355,51 @@ def generate_recommend_questions_task(self): guess_msg.append( HumanMessage(content=self.chat_question.guess_user_question(orjson.dumps(old_questions).decode()))) - self.record = save_full_recommend_question_message_and_answer(session=self.session, record_id=self.record.id, - full_message=orjson.dumps([{'type': msg.type, - 'content': msg.content} - for msg - in - guess_msg]).decode()) + self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.GENERATE_RECOMMENDED_QUESTIONS, + record_id=self.record.id, + full_message=[ + {'type': msg.type, + 'content': msg.content} for + msg + in guess_msg]) full_thinking_text = '' full_guess_text = '' token_usage = {} - res = self.llm.stream(guess_msg) + res = process_stream(self.llm.stream(guess_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_guess_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_guess_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk guess_msg.append(AIMessage(full_guess_text)) - self.record = save_full_recommend_question_message_and_answer(session=self.session, record_id=self.record.id, - token_usage=token_usage, - answer={'content': full_guess_text, - 'reasoning_content': full_thinking_text}, - full_message=orjson.dumps([{'type': msg.type, - 'content': msg.content} - for msg - in - guess_msg]).decode()) + + self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.GENERATE_RECOMMENDED_QUESTIONS], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in guess_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) + self.record = save_recommend_question_answer(session=self.session, record_id=self.record.id, + answer={'content': full_guess_text}) + yield {'recommended_question': self.record.recommended_question} def select_datasource(self): datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = [] datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question())) - if self.current_assistant: + if self.current_assistant and self.current_assistant.type != 4: _ds_list = get_assistant_ds(session=self.session, llm_service=self) else: - oid: str = self.current_user.oid stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where( - CoreDatasource.oid == oid) + and_(CoreDatasource.oid == self.current_user.oid)) _ds_list = [ { "id": ds.id, @@ -371,65 +410,74 @@ def select_datasource(self): ] """ _ds_list = self.session.exec(select(CoreDatasource).options( load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all() """ - + if not _ds_list: + raise SingleMessageError('No available datasource configuration found') ignore_auto_select = _ds_list and len(_ds_list) == 1 # ignore auto select ds + full_thinking_text = '' + full_text = '' if not ignore_auto_select: - _ds_list_dict = [] - for _ds in _ds_list: - _ds_list_dict.append(_ds) - datasource_msg.append( - HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) - - history_msg = [] - if self.record.full_select_datasource_message and self.record.full_select_datasource_message.strip() != '': - history_msg = orjson.loads(self.record.full_select_datasource_message) - - self.record = save_full_select_datasource_message_and_answer(session=self.session, record_id=self.record.id, - answer='', - full_message=orjson.dumps(history_msg + - [{'type': msg.type, - 'content': msg.content} - for msg - in - datasource_msg]).decode()) - full_thinking_text = '' - full_text = '' - token_usage = {} - res = self.llm.stream(datasource_msg) - for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) - datasource_msg.append(AIMessage(full_text)) - - json_str = extract_nested_json(full_text) + if settings.TABLE_EMBEDDING_ENABLED: + ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.out_ds_instance, + self.chat_question.question, self.current_assistant) + yield {'content': '{"id":' + str(ds.get('id')) + '}'} + else: + _ds_list_dict = [] + for _ds in _ds_list: + _ds_list_dict.append(_ds) + datasource_msg.append( + HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.CHOOSE_DATASOURCE, + record_id=self.record.id, + full_message=[{'type': msg.type, + 'content': msg.content} + for + msg in datasource_msg]) + + token_usage = {} + res = process_stream(self.llm.stream(datasource_msg), token_usage) + for chunk in res: + if chunk.get('content'): + full_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk + datasource_msg.append(AIMessage(full_text)) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.CHOOSE_DATASOURCE], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in datasource_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) + + json_str = extract_nested_json(full_text) + if json_str is None: + raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}') + ds = orjson.loads(json_str) _error: Exception | None = None _datasource: int | None = None _engine_type: str | None = None try: - data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str) + data: dict = _ds_list[0] if ignore_auto_select else ds if data.get('id') and data.get('id') != 0: _datasource = data['id'] _chat = self.session.get(Chat, self.record.chat_id) _chat.datasource = _datasource - if self.current_assistant and self.current_assistant.type == 1: + if self.current_assistant and self.current_assistant.type in dynamic_ds_types: _ds = self.out_ds_instance.get_ds(data['id']) self.ds = _ds - self.chat_question.engine = _ds.type + self.chat_question.engine = _ds.type + get_version(self.ds) self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id) _engine_type = self.chat_question.engine _chat.engine_type = _ds.type @@ -439,16 +487,24 @@ def select_datasource(self): _datasource = None raise SingleMessageError(f"Datasource configuration with id {_datasource} not found") self.ds = CoreDatasource(**_ds.model_dump()) - self.chat_question.engine = _ds.type_name if _ds.type != 'excel' else 'PostgreSQL' + self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version( + self.ds) self.chat_question.db_schema = get_table_schema(session=self.session, - current_user=self.current_user, ds=self.ds) + current_user=self.current_user, ds=self.ds, + question=self.chat_question.question) _engine_type = self.chat_question.engine _chat.engine_type = _ds.type_name # save chat - self.session.add(_chat) - self.session.flush() - self.session.refresh(_chat) - self.session.commit() + with self.session.begin_nested(): + # 为了能继续记日志,先单独处理下事务 + try: + self.session.add(_chat) + self.session.flush() + self.session.refresh(_chat) + self.session.commit() + except Exception as e: + self.session.rollback() + raise e elif data['fail']: raise SingleMessageError(data['fail']) @@ -458,79 +514,102 @@ def select_datasource(self): except Exception as e: _error = e - if not ignore_auto_select: - self.record = save_full_select_datasource_message_and_answer(session=self.session, record_id=self.record.id, - answer=orjson.dumps({'content': full_text, - 'reasoning_content': full_thinking_text}).decode(), - datasource=_datasource, - engine_type=_engine_type, - full_message=orjson.dumps(history_msg + - [{'type': msg.type, - 'content': msg.content} - for msg - in - datasource_msg]).decode()) - self.init_messages() + if not ignore_auto_select and not settings.TABLE_EMBEDDING_ENABLED: + self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id, + answer=orjson.dumps({'content': full_text}).decode(), + datasource=_datasource, + engine_type=_engine_type) + if self.ds: + oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None + + self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid, + ds_id) + self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id, + oid) + if SQLBotLicenseUtil.valid(): + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, + oid, ds_id) + + self.init_messages() if _error: raise _error def generate_sql(self): # append current question - self.sql_message.append(HumanMessage(self.chat_question.sql_user_question())) - self.record = save_full_sql_message(session=self.session, record_id=self.record.id, - full_message=orjson.dumps( - [{'type': msg.type, 'content': msg.content} for msg in - self.sql_message]).decode()) + self.sql_message.append(HumanMessage( + self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')))) + + self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.GENERATE_SQL, + record_id=self.record.id, + full_message=[ + {'type': msg.type, 'content': msg.content} for msg + in self.sql_message]) full_thinking_text = '' full_sql_text = '' token_usage = {} - res = self.llm.stream(self.sql_message) + res = process_stream(self.llm.stream(self.sql_message), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_sql_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_sql_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk self.sql_message.append(AIMessage(full_sql_text)) - self.record = save_full_sql_message_and_answer(session=self.session, record_id=self.record.id, - token_usage=token_usage, - answer=orjson.dumps({'content': full_sql_text, - 'reasoning_content': full_thinking_text}).decode(), - full_message=orjson.dumps( - [{'type': msg.type, 'content': msg.content} for msg in - self.sql_message]).decode()) + + self.current_logs[OperationEnum.GENERATE_SQL] = end_log(session=self.session, + log=self.current_logs[OperationEnum.GENERATE_SQL], + full_message=[{'type': msg.type, 'content': msg.content} + for msg in self.sql_message], + reasoning_content=full_thinking_text, + token_usage=token_usage) + self.record = save_sql_answer(session=self.session, record_id=self.record.id, + answer=orjson.dumps({'content': full_sql_text}).decode()) def generate_with_sub_sql(self, sql, sub_mappings: list): sub_query = json.dumps(sub_mappings, ensure_ascii=False) self.chat_question.sql = sql self.chat_question.sub_query = sub_query - msg: List[Union[BaseMessage, dict[str, Any]]] = [] - msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question())) - msg.append(HumanMessage(content=self.chat_question.dynamic_user_question())) + dynamic_sql_msg: List[Union[BaseMessage, dict[str, Any]]] = [] + dynamic_sql_msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question())) + dynamic_sql_msg.append(HumanMessage(content=self.chat_question.dynamic_user_question())) + + self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.GENERATE_DYNAMIC_SQL, + record_id=self.record.id, + full_message=[{'type': msg.type, + 'content': msg.content} + for + msg in dynamic_sql_msg]) + full_thinking_text = '' full_dynamic_text = '' - res = self.llm.stream(msg) token_usage = {} + res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - full_dynamic_text += chunk.content - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_dynamic_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + + dynamic_sql_msg.append(AIMessage(full_dynamic_text)) + + self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.GENERATE_DYNAMIC_SQL], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in dynamic_sql_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) SQLBotLogUtil.info(full_dynamic_text) return full_dynamic_text @@ -538,60 +617,58 @@ def generate_with_sub_sql(self, sql, sub_mappings: list): def generate_assistant_dynamic_sql(self, sql, tables: List): ds: AssistantOutDsSchema = self.ds sub_query = [] + result_dict = {} for table in ds.tables: if table.name in tables and table.sql: - sub_query.append({"table": table.name, "query": table.sql}) + # sub_query.append({"table": table.name, "query": table.sql}) + result_dict[table.name] = table.sql + sub_query.append({"table": table.name, "query": f'{dynamic_subsql_prefix}{table.name}'}) if not sub_query: return None - return self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query) + temp_sql_text = self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query) + result_dict['sqlbot_temp_sql_text'] = temp_sql_text + return result_dict def build_table_filter(self, sql: str, filters: list): filter = json.dumps(filters, ensure_ascii=False) self.chat_question.sql = sql self.chat_question.filter = filter - msg: List[Union[BaseMessage, dict[str, Any]]] = [] - msg.append(SystemMessage(content=self.chat_question.filter_sys_question())) - msg.append(HumanMessage(content=self.chat_question.filter_user_question())) - - history_msg = [] - # if self.record.full_analysis_message and self.record.full_analysis_message.strip() != '': - # history_msg = orjson.loads(self.record.full_analysis_message) - - # self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, answer='', - # full_message=orjson.dumps(history_msg + - # [{'type': msg.type, - # 'content': msg.content} for msg - # in - # msg]).decode()) + permission_sql_msg: List[Union[BaseMessage, dict[str, Any]]] = [] + permission_sql_msg.append(SystemMessage(content=self.chat_question.filter_sys_question())) + permission_sql_msg.append(HumanMessage(content=self.chat_question.filter_user_question())) + + self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.GENERATE_SQL_WITH_PERMISSIONS, + record_id=self.record.id, + full_message=[ + {'type': msg.type, + 'content': msg.content} for + msg + in permission_sql_msg]) full_thinking_text = '' full_filter_text = '' - res = self.llm.stream(msg) token_usage = {} + res = process_stream(self.llm.stream(permission_sql_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_filter_text += chunk.content - # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_filter_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + + permission_sql_msg.append(AIMessage(full_filter_text)) + + self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.GENERATE_SQL_WITH_PERMISSIONS], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in permission_sql_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) - msg.append(AIMessage(full_filter_text)) - # self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, - # token_usage=token_usage, - # answer=orjson.dumps({'content': full_analysis_text, - # 'reasoning_content': full_thinking_text}).decode(), - # full_message=orjson.dumps(history_msg + - # [{'type': msg.type, - # 'content': msg.content} for msg - # in - # analysis_msg]).decode()) SQLBotLogUtil.info(full_filter_text) return full_filter_text @@ -612,42 +689,44 @@ def generate_assistant_filter(self, sql, tables: List): return None return self.build_table_filter(sql=sql, filters=filters) - def generate_chart(self): + def generate_chart(self, chart_type: Optional[str] = ''): # append current question - self.chart_message.append(HumanMessage(self.chat_question.chart_user_question())) - self.record = save_full_chart_message(session=self.session, record_id=self.record.id, - full_message=orjson.dumps( - [{'type': msg.type, 'content': msg.content} for msg in - self.chart_message]).decode()) + self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type))) + + self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.GENERATE_CHART, + record_id=self.record.id, + full_message=[ + {'type': msg.type, 'content': msg.content} for + msg + in self.chart_message]) full_thinking_text = '' full_chart_text = '' token_usage = {} - res = self.llm.stream(self.chart_message) + res = process_stream(self.llm.stream(self.chart_message), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_chart_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_chart_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk self.chart_message.append(AIMessage(full_chart_text)) - self.record = save_full_chart_message_and_answer(session=self.session, record_id=self.record.id, - token_usage=token_usage, - answer=orjson.dumps({'content': full_chart_text, - 'reasoning_content': full_thinking_text}).decode(), - full_message=orjson.dumps( - [{'type': msg.type, 'content': msg.content} for msg in - self.chart_message]).decode()) - - def check_sql(self, res: str) -> tuple[any]: + + self.record = save_chart_answer(session=self.session, record_id=self.record.id, + answer=orjson.dumps({'content': full_chart_text}).decode()) + self.current_logs[OperationEnum.GENERATE_CHART] = end_log(session=self.session, + log=self.current_logs[OperationEnum.GENERATE_CHART], + full_message=[ + {'type': msg.type, 'content': msg.content} + for msg in self.chart_message], + reasoning_content=full_thinking_text, + token_usage=token_usage) + + @staticmethod + def check_sql(res: str) -> tuple[str, Optional[list]]: json_str = extract_nested_json(res) if json_str is None: raise SingleMessageError(orjson.dumps({'message': 'Cannot parse sql from answer', @@ -672,6 +751,26 @@ def check_sql(self, res: str) -> tuple[any]: raise SingleMessageError("SQL query is empty") return sql, data.get('tables') + @staticmethod + def get_chart_type_from_sql_answer(res: str) -> Optional[str]: + json_str = extract_nested_json(res) + if json_str is None: + return None + + chart_type: Optional[str] + data: dict + try: + data = orjson.loads(json_str) + + if data['success']: + chart_type = data['chart-type'] + else: + return None + except Exception: + return None + + return chart_type + def check_save_sql(self, res: str) -> str: sql, *_ = self.check_sql(res=res) save_sql(session=self.session, sql=sql, record_id=self.record.id) @@ -771,7 +870,14 @@ def execute_sql(self, sql: str): Query results """ SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}") - return exec_sql(self.ds, sql) + try: + return exec_sql(ds=self.ds, sql=sql, origin_column=False) + except Exception as e: + if isinstance(e, ParseSQLResultError): + raise e + else: + err = traceback.format_exc(limit=1, chain=True) + raise SQLBotDBError(err) def pop_chunk(self): try: @@ -794,18 +900,39 @@ def await_result(self): break yield chunk - def run_task_async(self, in_chat: bool = True): - self.future = executor.submit(self.run_task_cache, in_chat) + def run_task_async(self, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): + if in_chat: + stream = True + self.future = executor.submit(self.run_task_cache, in_chat, stream, finish_step) - def run_task_cache(self, in_chat: bool = True): - for chunk in self.run_task(in_chat): + def run_task_cache(self, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): + for chunk in self.run_task(in_chat, stream, finish_step): self.chunk_list.append(chunk) - def run_task(self, in_chat: bool = True): + def run_task(self, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): + json_result: Dict[str, Any] = {'success': True} try: + if self.ds: + oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None + self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, + oid, ds_id) + self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, + ds_id, oid) + if SQLBotLicenseUtil.valid(): + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, + oid, ds_id) + + self.init_messages() + # return id if in_chat: yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n' + if not stream: + json_result['record_id'] = self.get_record().id # return title if self.change_title: @@ -815,8 +942,10 @@ def run_task(self, in_chat: bool = True): brief=self.chat_question.question.strip()[:20])) if in_chat: yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' + if not stream: + json_result['title'] = brief - # select datasource if datasource is none + # select datasource if datasource is none if not self.ds: ds_res = self.select_datasource() @@ -828,15 +957,22 @@ def run_task(self, in_chat: bool = True): 'type': 'datasource-result'}).decode() + '\n\n' if in_chat: yield 'data:' + orjson.dumps({'id': self.ds.id, 'datasource_name': self.ds.name, - 'engine_type': self.ds.type_name or self.ds.type, - 'type': 'datasource'}).decode() + '\n\n' + 'engine_type': self.ds.type_name or self.ds.type, + 'type': 'datasource'}).decode() + '\n\n' self.chat_question.db_schema = self.out_ds_instance.get_db_schema( self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, current_user=self.current_user, - ds=self.ds) + ds=self.ds, + question=self.chat_question.question) else: self.validate_history_ds() + + # check connection + connected = check_connection(ds=self.ds, trans=None) + if not connected: + raise SQLBotDBConnectionError('Connect DB failed') + # generate sql sql_res = self.generate_sql() full_sql_text = '' @@ -848,51 +984,104 @@ def run_task(self, in_chat: bool = True): 'type': 'sql-result'}).decode() + '\n\n' if in_chat: yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n' - # filter sql SQLBotLogUtil.info(full_sql_text) - use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type == 1 + chart_type = self.get_chart_type_from_sql_answer(full_sql_text) + + use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types + is_page_embedded: bool = self.current_assistant and self.current_assistant.type == 4 + dynamic_sql_result = None + sqlbot_temp_sql_text = None + assistant_dynamic_sql = None # todo row permission - if (not self.current_assistant and is_normal_user(self.current_user)) or use_dynamic_ds: + if ((not self.current_assistant or is_page_embedded) and is_normal_user( + self.current_user)) or use_dynamic_ds: sql, tables = self.check_sql(res=full_sql_text) sql_result = None - dynamic_sql_result = None - if self.current_assistant: - dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables) - if dynamic_sql_result: - SQLBotLogUtil.info(dynamic_sql_result) - sql, *_ = self.check_sql(res=dynamic_sql_result) - sql_result = self.generate_assistant_filter(sql, tables) + if use_dynamic_ds: + dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables) + sqlbot_temp_sql_text = dynamic_sql_result.get( + 'sqlbot_temp_sql_text') if dynamic_sql_result else None + # sql_result = self.generate_assistant_filter(sql, tables) else: sql_result = self.generate_filter(sql, tables) # maybe no sql and tables if sql_result: SQLBotLogUtil.info(sql_result) sql = self.check_save_sql(res=sql_result) - elif dynamic_sql_result: - sql = self.check_save_sql(res=dynamic_sql_result) + elif dynamic_sql_result and sqlbot_temp_sql_text: + assistant_dynamic_sql = self.check_save_sql(res=sqlbot_temp_sql_text) else: sql = self.check_save_sql(res=full_sql_text) else: sql = self.check_save_sql(res=full_sql_text) - SQLBotLogUtil.info(sql) + SQLBotLogUtil.info('sql: ' + sql) + + if not stream: + json_result['sql'] = sql + format_sql = sqlparse.format(sql, reindent=True) if in_chat: yield 'data:' + orjson.dumps({'content': format_sql, 'type': 'sql'}).decode() + '\n\n' else: - yield f'```sql\n{format_sql}\n```\n\n' + if stream: + yield f'```sql\n{format_sql}\n```\n\n' # execute sql - result = self.execute_sql(sql=sql) + real_execute_sql = sql + if sqlbot_temp_sql_text and assistant_dynamic_sql: + dynamic_sql_result.pop('sqlbot_temp_sql_text') + for origin_table, subsql in dynamic_sql_result.items(): + assistant_dynamic_sql = assistant_dynamic_sql.replace(f'{dynamic_subsql_prefix}{origin_table}', + subsql) + real_execute_sql = assistant_dynamic_sql + + if finish_step.value <= ChatFinishStep.GENERATE_SQL.value: + if in_chat: + yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' + if not stream: + yield json_result + return + + result = self.execute_sql(sql=real_execute_sql) self.save_sql_data(data_obj=result) if in_chat: yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n' + if not stream: + json_result['data'] = result.get('data') + + if finish_step.value <= ChatFinishStep.QUERY_DATA.value: + if stream: + if in_chat: + yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' + else: + data = [] + _fields_list = [] + _fields_skip = False + for _data in result.get('data'): + _row = [] + for field in result.get('fields'): + _row.append(_data.get(field)) + if not _fields_skip: + _fields_list.append(field) + data.append(_row) + _fields_skip = True + + if not data or not _fields_list: + yield 'The SQL execution result is empty.\n\n' + else: + df = pd.DataFrame(np.array(data), columns=_fields_list) + markdown_table = df.to_markdown(index=False) + yield markdown_table + '\n\n' + else: + yield json_result + return # generate chart - chart_res = self.generate_chart() + chart_res = self.generate_chart(chart_type) full_chart_text = '' for chunk in chart_res: full_chart_text += chunk.get('content') @@ -907,38 +1096,47 @@ def run_task(self, in_chat: bool = True): SQLBotLogUtil.info(full_chart_text) chart = self.check_save_chart(res=full_chart_text) SQLBotLogUtil.info(chart) + + if not stream: + json_result['chart'] = chart + if in_chat: - yield 'data:' + orjson.dumps({'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n' + yield 'data:' + orjson.dumps( + {'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n' else: - data = [] - _fields = {} - if chart.get('columns'): - for _column in chart.get('columns'): - if _column: - _fields[_column.get('value')] = _column.get('name') - if chart.get('axis'): - if chart.get('axis').get('x'): - _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name') - if chart.get('axis').get('y'): - _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name') - if chart.get('axis').get('series'): - _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get( - 'name') - _fields_list = [] - _fields_skip = False - for _data in result.get('data'): - _row = [] - for field in result.get('fields'): - _row.append(_data.get(field)) - if not _fields_skip: - _fields_list.append(field if not _fields.get(field) else _fields.get(field)) - data.append(_row) - _fields_skip = True - df = pd.DataFrame(np.array(data), columns=_fields_list) - markdown_table = df.to_markdown(index=False) - yield markdown_table + '\n\n' - - record = self.finish() + if stream: + data = [] + _fields = {} + if chart.get('columns'): + for _column in chart.get('columns'): + if _column: + _fields[_column.get('value')] = _column.get('name') + if chart.get('axis'): + if chart.get('axis').get('x'): + _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name') + if chart.get('axis').get('y'): + _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name') + if chart.get('axis').get('series'): + _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get( + 'name') + _fields_list = [] + _fields_skip = False + for _data in result.get('data'): + _row = [] + for field in result.get('fields'): + _row.append(_data.get(field)) + if not _fields_skip: + _fields_list.append(field if not _fields.get(field) else _fields.get(field)) + data.append(_row) + _fields_skip = True + + if not data or not _fields_list: + yield 'The SQL execution result is empty.\n\n' + else: + df = pd.DataFrame(np.array(data), columns=_fields_list) + markdown_table = df.to_markdown(index=False) + yield markdown_table + '\n\n' + if in_chat: yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' else: @@ -947,18 +1145,39 @@ def run_task(self, in_chat: bool = True): yield '### generated chart picture\n\n' image_url = request_picture(self.record.chat_id, self.record.id, chart, result) SQLBotLogUtil.info(image_url) - yield f'![{chart["type"]}]({image_url})' + if stream: + yield f'![{chart["type"]}]({image_url})' + else: + json_result['image_url'] = image_url + + if not stream: + yield json_result + except Exception as e: + traceback.print_exc() error_msg: str if isinstance(e, SingleMessageError): error_msg = str(e) + elif isinstance(e, SQLBotDBConnectionError): + error_msg = orjson.dumps( + {'message': str(e), 'type': 'db-connection-err'}).decode() + elif isinstance(e, SQLBotDBError): + error_msg = orjson.dumps( + {'message': 'Execute SQL Failed', 'traceback': str(e), 'type': 'exec-sql-err'}).decode() else: error_msg = orjson.dumps({'message': str(e), 'traceback': traceback.format_exc(limit=1)}).decode() self.save_error(message=error_msg) if in_chat: yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' else: - yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + if stream: + yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + else: + json_result['success'] = False + json_result['message'] = error_msg + yield json_result + finally: + self.finish() def run_recommend_questions_task_async(self): self.future = executor.submit(self.run_recommend_questions_task_cache) @@ -972,10 +1191,10 @@ def run_recommend_questions_task(self): for chunk in res: if chunk.get('recommended_question'): - yield orjson.dumps( + yield 'data:' + orjson.dumps( {'content': chunk.get('recommended_question'), 'type': 'recommended_question'}).decode() + '\n\n' else: - yield orjson.dumps( + yield 'data:' + orjson.dumps( {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), 'type': 'recommended_question_result'}).decode() + '\n\n' @@ -1037,7 +1256,7 @@ def run_analysis_or_predict_task(self, action_type: str): def validate_history_ds(self): _ds = self.ds - if not self.current_assistant: + if not self.current_assistant or self.current_assistant.type == 4: try: current_ds = self.session.get(CoreDatasource, _ds.id) if not current_ds: @@ -1050,7 +1269,7 @@ def validate_history_ds(self): match_ds = any(item.get("id") == _ds.id for item in _ds_list) if not match_ds: type = self.current_assistant.type - msg = f"ds is invalid [please check ds list and public ds list]" if type == 0 else f"ds is invalid [please check ds api]" + msg = f"[please check ds list and public ds list]" if type == 0 else f"[please check ds api]" raise SingleMessageError(msg) except Exception as e: raise SingleMessageError(f"ds is invalid [{str(e)}]") @@ -1105,8 +1324,7 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict): axis.append({'name': series.get('name'), 'value': series.get('value'), 'type': 'series'}) request_obj = { - "path": (settings.MCP_IMAGE_PATH if settings.MCP_IMAGE_PATH[-1] == '/' else ( - settings.MCP_IMAGE_PATH + '/')) + file_name, + "path": os.path.join(settings.MCP_IMAGE_PATH, file_name), "type": chart['type'], "data": orjson.dumps(data.get('data') if data.get('data') else []).decode(), "axis": orjson.dumps(axis).decode(), @@ -1114,12 +1332,16 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict): requests.post(url=settings.MCP_IMAGE_HOST, json=request_obj) - return f'{(settings.SERVER_IMAGE_HOST if settings.SERVER_IMAGE_HOST[-1] == "/" else (settings.SERVER_IMAGE_HOST + "/"))}{file_name}.png' + request_path = urllib.parse.urljoin(settings.SERVER_IMAGE_HOST, f"{file_name}.png") + return request_path -def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}): + +def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = None): try: if chunk.usage_metadata: + if token_usage is None: + token_usage = {} token_usage['input_tokens'] = chunk.usage_metadata.get('input_tokens') token_usage['output_tokens'] = chunk.usage_metadata.get('output_tokens') token_usage['total_tokens'] = chunk.usage_metadata.get('total_tokens') @@ -1127,7 +1349,110 @@ def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}): pass +def process_stream(res: Iterator[BaseMessageChunk], + token_usage: Dict[str, Any] = None, + enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED, + start_tag: str = settings.DEFAULT_REASONING_CONTENT_START, + end_tag: str = settings.DEFAULT_REASONING_CONTENT_END + ): + if token_usage is None: + token_usage = {} + in_thinking_block = False # 标记是否在思考过程块中 + current_thinking = '' # 当前收集的思考过程内容 + pending_start_tag = '' # 用于缓存可能被截断的开始标签部分 + + for chunk in res: + SQLBotLogUtil.info(chunk) + reasoning_content_chunk = '' + content = chunk.content + output_content = '' # 实际要输出的内容 + + # 检查additional_kwargs中的reasoning_content + if 'reasoning_content' in chunk.additional_kwargs: + reasoning_content = chunk.additional_kwargs.get('reasoning_content', '') + if reasoning_content is None: + reasoning_content = '' + + # 累积additional_kwargs中的思考内容到current_thinking + current_thinking += reasoning_content + reasoning_content_chunk = reasoning_content + + # 只有当current_thinking不是空字符串时才跳过标签解析 + if not in_thinking_block and current_thinking.strip() != '': + output_content = content # 正常输出content + yield { + 'content': output_content, + 'reasoning_content': reasoning_content_chunk + } + get_token_usage(chunk, token_usage) + continue # 跳过后续的标签解析逻辑 + + # 如果没有有效的思考内容,并且启用了标签解析,才执行标签解析逻辑 + # 如果有缓存的开始标签部分,先拼接当前内容 + if pending_start_tag: + content = pending_start_tag + content + pending_start_tag = '' + + # 检查是否开始思考过程块(处理可能被截断的开始标签) + if enable_tag_parsing and not in_thinking_block and start_tag: + if start_tag in content: + start_idx = content.index(start_tag) + # 只有当开始标签前面没有其他文本时才认为是真正的思考块开始 + if start_idx == 0 or content[:start_idx].strip() == '': + # 完整标签存在且前面没有其他文本 + output_content += content[:start_idx] # 输出开始标签之前的内容 + content = content[start_idx + len(start_tag):] # 移除开始标签 + in_thinking_block = True + else: + # 开始标签前面有其他文本,不认为是思考块开始 + output_content += content + content = '' + else: + # 检查是否可能有部分开始标签 + for i in range(1, len(start_tag)): + if content.endswith(start_tag[:i]): + # 只有当当前内容全是空白时才缓存部分标签 + if content[:-i].strip() == '': + pending_start_tag = start_tag[:i] + content = content[:-i] # 移除可能的部分标签 + output_content += content + content = '' + break + + # 处理思考块内容 + if enable_tag_parsing and in_thinking_block and end_tag: + if end_tag in content: + # 找到结束标签 + end_idx = content.index(end_tag) + current_thinking += content[:end_idx] # 收集思考内容 + reasoning_content_chunk += current_thinking # 添加到当前块的思考内容 + content = content[end_idx + len(end_tag):] # 移除结束标签后的内容 + current_thinking = '' # 重置当前思考内容 + in_thinking_block = False + output_content += content # 输出结束标签之后的内容 + else: + # 在遇到结束标签前,持续收集思考内容 + current_thinking += content + reasoning_content_chunk += content + content = '' + + else: + # 不在思考块中或标签解析未启用,正常输出 + output_content += content + + yield { + 'content': output_content, + 'reasoning_content': reasoning_content_chunk + } + get_token_usage(chunk, token_usage) + + def get_lang_name(lang: str): - if lang and lang == 'en': + if not lang: + return '简体中文' + normalized = lang.lower() + if normalized.startswith('en'): return '英文' + if normalized.startswith('ko'): + return '韩语' return '简体中文' diff --git a/backend/apps/data_training/__init__.py b/backend/apps/data_training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/api/__init__.py b/backend/apps/data_training/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py new file mode 100644 index 00000000..25422c2b --- /dev/null +++ b/backend/apps/data_training/api/data_training.py @@ -0,0 +1,39 @@ +from typing import Optional + +from fastapi import APIRouter, Query + +from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training +from apps.data_training.models.data_training_model import DataTrainingInfo +from common.core.deps import SessionDep, CurrentUser, Trans + +router = APIRouter(tags=["DataTraining"], prefix="/system/data-training") + + +@router.get("/page/{current_page}/{page_size}") +async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int, + question: Optional[str] = Query(None, description="搜索问题(可选)")): + current_page, page_size, total_count, total_pages, _list = page_data_training(session, current_page, page_size, + question, + current_user.oid) + + return { + "current_page": current_page, + "page_size": page_size, + "total_count": total_count, + "total_pages": total_pages, + "data": _list + } + + +@router.put("") +async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: DataTrainingInfo): + oid = current_user.oid + if info.id: + return update_training(session, info, oid, trans) + else: + return create_training(session, info, oid, trans) + + +@router.delete("") +async def delete(session: SessionDep, id_list: list[int]): + delete_training(session, id_list) diff --git a/backend/apps/data_training/curd/__init__.py b/backend/apps/data_training/curd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py new file mode 100644 index 00000000..15260a9a --- /dev/null +++ b/backend/apps/data_training/curd/data_training.py @@ -0,0 +1,317 @@ +import datetime +import logging +import traceback +from typing import List, Optional +from xml.dom.minidom import parseString + +import dicttoxml +from sqlalchemy import and_, select, func, delete, update, or_ +from sqlalchemy import text +from sqlalchemy.orm.session import Session + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining +from apps.datasource.models.datasource import CoreDatasource +from apps.template.generate_chart.generator import get_base_data_training_template +from common.core.config import settings +from common.core.deps import SessionDep, Trans +from common.utils.embedding_threads import run_save_data_training_embeddings + + +def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None, + oid: Optional[int] = 1): + _list: List[DataTrainingInfo] = [] + + current_page = max(1, current_page) + page_size = max(10, page_size) + + total_count = 0 + total_pages = 0 + + if name and name.strip() != "": + keyword_pattern = f"%{name.strip()}%" + parent_ids_subquery = ( + select(DataTraining.id) + .where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid)) # LIKE查询条件 + ) + else: + parent_ids_subquery = ( + select(DataTraining.id).where(and_(DataTraining.oid == oid)) + ) + + count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery()) + total_count = session.execute(count_stmt).scalar() + total_pages = (total_count + page_size - 1) // page_size + + if current_page > total_pages: + current_page = 1 + + paginated_parent_ids = ( + parent_ids_subquery + .order_by(DataTraining.create_time.desc()) + .offset((current_page - 1) * page_size) + .limit(page_size) + .subquery() + ) + + stmt = ( + select( + DataTraining.id, + DataTraining.oid, + DataTraining.datasource, + CoreDatasource.name, + DataTraining.question, + DataTraining.create_time, + DataTraining.description, + ) + .outerjoin(CoreDatasource, and_(DataTraining.datasource == CoreDatasource.id)) + .where(and_(DataTraining.id.in_(paginated_parent_ids))) + .order_by(DataTraining.create_time.desc()) + ) + + result = session.execute(stmt) + + for row in result: + _list.append(DataTrainingInfo( + id=row.id, + oid=row.oid, + datasource=row.datasource, + datasource_name=row.name, + question=row.question, + create_time=row.create_time, + description=row.description, + )) + + return current_page, page_size, total_count, total_pages, _list + + +def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): + create_time = datetime.datetime.now() + if info.datasource is None: + raise Exception(trans("i18n_data_training.datasource_cannot_be_none")) + parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid, + datasource=info.datasource) + + exists = session.query( + session.query(DataTraining).filter( + and_(DataTraining.question == info.question, DataTraining.oid == oid, + DataTraining.datasource == info.datasource)).exists()).scalar() + if exists: + raise Exception(trans("i18n_data_training.exists_in_db")) + + result = DataTraining(**parent.model_dump()) + + session.add(parent) + session.flush() + session.refresh(parent) + + result.id = parent.id + session.commit() + + # embedding + run_save_data_training_embeddings([result.id]) + + return result.id + + +def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): + if info.datasource is None: + raise Exception(trans("i18n_data_training.datasource_cannot_be_none")) + + count = session.query(DataTraining).filter( + DataTraining.id == info.id + ).count() + if count == 0: + raise Exception(trans('i18n_data_training.data_training_not_exists')) + + exists = session.query( + session.query(DataTraining).filter( + and_(DataTraining.question == info.question, DataTraining.oid == oid, + DataTraining.datasource == info.datasource, + DataTraining.id != info.id)).exists()).scalar() + if exists: + raise Exception(trans("i18n_data_training.exists_in_db")) + + stmt = update(DataTraining).where(and_(DataTraining.id == info.id)).values( + question=info.question, + description=info.description, + datasource=info.datasource, + ) + session.execute(stmt) + session.commit() + + # embedding + run_save_data_training_embeddings([info.id]) + + return info.id + + +def delete_training(session: SessionDep, ids: list[int]): + stmt = delete(DataTraining).where(and_(DataTraining.id.in_(ids))) + session.execute(stmt) + session.commit() + + +# def run_save_embeddings(ids: List[int]): +# executor.submit(save_embeddings, ids) +# +# +# def fill_empty_embeddings(): +# executor.submit(run_fill_empty_embeddings) + + +def run_fill_empty_embeddings(session: Session): + if not settings.EMBEDDING_ENABLED: + return + + stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None))) + results = session.execute(stmt).scalars().all() + + save_embeddings(session, results) + + +def save_embeddings(session: Session, ids: List[int]): + if not settings.EMBEDDING_ENABLED: + return + + if not ids or len(ids) == 0: + return + try: + + _list = session.query(DataTraining).filter(and_(DataTraining.id.in_(ids))).all() + + _question_list = [item.question for item in _list] + + model = EmbeddingModelCache.get_model() + + results = model.embed_documents(_question_list) + + for index in range(len(results)): + item = results[index] + stmt = update(DataTraining).where(and_(DataTraining.id == _list[index].id)).values(embedding=item) + session.execute(stmt) + session.commit() + + except Exception: + traceback.print_exc() + + +embedding_sql = f""" +SELECT id, datasource, question, similarity +FROM +(SELECT id, datasource, question, oid, +( 1 - (embedding <=> :embedding_array) ) AS similarity +FROM data_training AS child +) TEMP +WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and datasource = :datasource +ORDER BY similarity DESC +LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT} +""" + + +def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: int): + if question.strip() == "": + return [] + + _list: List[DataTraining] = [] + + # maybe use label later? + stmt = ( + select( + DataTraining.id, + DataTraining.question, + ) + .where( + and_(or_(text(":sentence ILIKE '%' || question || '%'"), text("question ILIKE '%' || :sentence || '%'")), + DataTraining.oid == oid, + DataTraining.datasource == datasource) + ) + ) + + results = session.execute(stmt, {'sentence': question}).fetchall() + + for row in results: + _list.append(DataTraining(id=row.id, question=row.question)) + + if settings.EMBEDDING_ENABLED: + try: + model = EmbeddingModelCache.get_model() + + embedding = model.embed_query(question) + + results = session.execute(text(embedding_sql), + {'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource}) + + for row in results: + _list.append(DataTraining(id=row.id, question=row.question)) + + except Exception: + traceback.print_exc() + + _map: dict = {} + _ids: list[int] = [] + for row in _list: + if row.id in _ids: + continue + else: + _ids.append(row.id) + + if len(_ids) == 0: + return [] + + t_list = session.query(DataTraining.id, DataTraining.datasource, DataTraining.question, + DataTraining.description).filter( + and_(DataTraining.id.in_(_ids))).all() + + for row in t_list: + _map[row.id] = {'question': row.question, 'suggestion-answer': row.description} + + _results: list[dict] = [] + for key in _map.keys(): + _results.append(_map.get(key)) + + return _results + + +def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str: + item_name_func = lambda x: 'sql-example' if x == 'sql-examples' else 'item' + dicttoxml.LOG.setLevel(logging.ERROR) + xml = dicttoxml.dicttoxml(_dict, + cdata=['question', 'suggestion-answer'], + custom_root=root, + item_func=item_name_func, + xml_declaration=False, + encoding='utf-8', + attr_type=False).decode('utf-8') + pretty_xml = parseString(xml).toprettyxml() + + if pretty_xml.startswith('') + 1 + pretty_xml = pretty_xml[end_index:].lstrip() + + # 替换所有 XML 转义字符 + escape_map = { + '<': '<', + '>': '>', + '&': '&', + '"': '"', + ''': "'" + } + for escaped, original in escape_map.items(): + pretty_xml = pretty_xml.replace(escaped, original) + + return pretty_xml + + +def get_training_template(session: SessionDep, question: str, datasource: int, oid: Optional[int] = 1) -> str: + if not oid: + oid = 1 + if not datasource: + return '' + _results = select_training_by_question(session, question, oid, datasource) + if _results and len(_results) > 0: + data_training = to_xml_string(_results) + template = get_base_data_training_template().format(data_training=data_training) + return template + else: + return '' diff --git a/backend/apps/data_training/models/__init__.py b/backend/apps/data_training/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/models/data_training_model.py b/backend/apps/data_training/models/data_training_model.py new file mode 100644 index 00000000..b2806400 --- /dev/null +++ b/backend/apps/data_training/models/data_training_model.py @@ -0,0 +1,28 @@ +from datetime import datetime +from typing import List, Optional + +from pgvector.sqlalchemy import VECTOR +from pydantic import BaseModel +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlmodel import SQLModel, Field + + +class DataTraining(SQLModel, table=True): + __tablename__ = "data_training" + id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) + oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1)) + datasource: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True)) + create_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) + question: Optional[str] = Field(max_length=255) + description: Optional[str] = Field(sa_column=Column(Text, nullable=True)) + embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) + + +class DataTrainingInfo(BaseModel): + id: Optional[int] = None + oid: Optional[int] = None + datasource: Optional[int] = None + datasource_name: Optional[str] = None + create_time: Optional[datetime] = None + question: Optional[str] = None + description: Optional[str] = None diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py index 83d26ff0..9c182e34 100644 --- a/backend/apps/datasource/api/datasource.py +++ b/backend/apps/datasource/api/datasource.py @@ -10,6 +10,7 @@ import pandas as pd from fastapi import APIRouter, File, UploadFile, HTTPException +from apps.db.db import get_schema from apps.db.engine import get_engine_conn from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans @@ -20,7 +21,6 @@ from ..crud.field import get_fields_by_table_id from ..crud.table import get_tables_by_ds_id from ..models.datasource import CoreDatasource, CreateDatasource, TableObj, CoreTable, CoreField -from apps.db.db import get_schema router = APIRouter(tags=["datasource"], prefix="/datasource") path = settings.EXCEL_PATH @@ -135,6 +135,7 @@ class TestObj(BaseModel): sql: str = None +# not used, just do test @router.post("/execSql/{id}") async def exec_sql(session: SessionDep, id: int, obj: TestObj): def inner(): @@ -192,6 +193,7 @@ def inner(): return await asyncio.to_thread(inner) +# not used @router.post("/fieldEnum/{id}") async def field_enum(session: SessionDep, id: int): def inner(): @@ -301,6 +303,11 @@ def inner(): def insert_pg(df, tableName, engine): + # fix field type + for i in range(len(df.dtypes)): + if str(df.dtypes[i]) == 'uint64': + df[str(df.columns[i])] = df[str(df.columns[i])].astype('string') + conn = engine.raw_connection() cursor = conn.cursor() try: @@ -322,7 +329,8 @@ def insert_pg(df, tableName, engine): ) conn.commit() except Exception as e: - pass + traceback.print_exc() + raise HTTPException(400, str(e)) finally: cursor.close() conn.close() diff --git a/backend/apps/datasource/api/table_relation.py b/backend/apps/datasource/api/table_relation.py new file mode 100644 index 00000000..6a65060a --- /dev/null +++ b/backend/apps/datasource/api/table_relation.py @@ -0,0 +1,29 @@ +# Author: Junjun +# Date: 2025/9/24 +from typing import List + +from fastapi import APIRouter + +from apps.datasource.models.datasource import CoreDatasource +from common.core.deps import SessionDep + +router = APIRouter(tags=["table_relation"], prefix="/table_relation") + + +@router.post("/save/{ds_id}") +async def save_relation(session: SessionDep, ds_id: int, relation: List[dict]): + ds = session.get(CoreDatasource, ds_id) + if ds: + ds.table_relation = relation + session.commit() + else: + raise Exception("no datasource") + return True + + +@router.post("/get/{ds_id}") +async def save_relation(session: SessionDep, ds_id: int): + ds = session.get(CoreDatasource, ds_id) + if ds: + return ds.table_relation if ds.table_relation else [] + return [] diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 782a48ba..c541a152 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -1,23 +1,21 @@ import datetime import json -import platform from typing import List, Optional -if platform.system() != "Darwin": - import dmPython from fastapi import HTTPException from sqlalchemy import and_, text +from sqlbot_xpack.permissions.models.ds_rules import DsRules from sqlmodel import select from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user +from apps.datasource.embedding.table_embedding import get_table_embedding from apps.datasource.utils.utils import aes_decrypt -from apps.db.constant import ConnectType from apps.db.constant import DB -from apps.db.db import get_engine, get_tables, get_fields, exec_sql +from apps.db.db import get_tables, get_fields, exec_sql, check_connection from apps.db.engine import get_engine_config, get_engine_conn -from apps.db.type import db_type_relation +from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans -from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra +from common.utils.utils import deepcopy_ignore_extra from .table import get_tables_by_ds_id from ..crud.field import delete_field_by_ds_id, update_field from ..crud.table import delete_table_by_ds_id, update_table @@ -49,32 +47,7 @@ def check_status_by_id(session: SessionDep, trans: Trans, ds_id: int, is_raise: def check_status(session: SessionDep, trans: Trans, ds: CoreDatasource, is_raise: bool = False): - db = DB.get_db(ds.type) - if db.connect_type == ConnectType.sqlalchemy: - conn = get_engine(ds, 10) - try: - with conn.connect() as connection: - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False - else: - conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) - if ds.type == 'dm': - with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - try: - cursor.execute('select 1').fetchall() - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False + return check_connection(trans, ds, is_raise) def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource): @@ -99,7 +72,7 @@ def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: C ds.create_by = user.id ds.oid = user.oid if user.oid is not None else 1 ds.status = "Success" - ds.type_name = db_type_relation()[ds.type] + ds.type_name = DB.get_db(ds.type).db_name record = CoreDatasource(**ds.model_dump()) session.add(record) session.flush() @@ -123,8 +96,8 @@ def chooseTables(session: SessionDep, trans: Trans, id: int, tables: List[CoreTa def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource): ds.id = int(ds.id) check_name(session, trans, user, ds) - status = check_status(session, trans, ds) - ds.status = "Success" if status is True else "Fail" + # status = check_status(session, trans, ds) + ds.status = "Success" record = session.exec(select(CoreDatasource).where(CoreDatasource.id == ds.id)).first() update_data = ds.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -234,6 +207,7 @@ def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, field record.field_comment = item.fieldComment record.field_index = index + record.field_type = item.fieldType session.add(record) session.commit() else: @@ -278,8 +252,9 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table f_list = [f for f in data.fields if f.checked] if is_normal_user(current_user): # column is checked, and, column permission for data.fields + contain_rules = session.query(DsRules).all() f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table, - fields=f_list) or f_list + fields=f_list, contain_rules=contain_rules) # row permission tree where_str = '' @@ -296,7 +271,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() sql: str = "" - if ds.type == "mysql": + if ds.type == "mysql" or ds.type == "doris": sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}` {where} LIMIT 100""" @@ -304,7 +279,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table sql = f"""SELECT TOP 100 [{"], [".join(fields)}] FROM [{conf.dbSchema}].[{data.table.table_name}] {where} """ - elif ds.type == "pg" or ds.type == "excel": + elif ds.type == "pg" or ds.type == "excel" or ds.type == "redshift" or ds.type == "kingbase": sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}" {where} LIMIT 100""" @@ -321,6 +296,10 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}" {where} LIMIT 100""" + elif ds.type == "es": + sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}" + {where} + LIMIT 100""" return exec_sql(ds, sql, True) @@ -360,43 +339,122 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all() conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() schema = conf.dbSchema if conf.dbSchema is not None and conf.dbSchema != "" else conf.database + + # get all field + table_ids = [table.id for table in tables] + all_fields = session.query(CoreField).filter( + and_(CoreField.table_id.in_(table_ids), CoreField.checked == True)).all() + # build dict + fields_dict = {} + for field in all_fields: + if fields_dict.get(field.table_id): + fields_dict.get(field.table_id).append(field) + else: + fields_dict[field.table_id] = [field] + + contain_rules = session.query(DsRules).all() for table in tables: - fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all() + # fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all() + fields = fields_dict.get(table.id) # do column permissions, filter fields - fields = get_column_permission_fields(session=session, current_user=current_user, table=table, - fields=fields) or fields + fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields, + contain_rules=contain_rules) _list.append(TableAndFields(schema=schema, table=table, fields=fields)) return _list -def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str: +def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str, + embedding: bool = True) -> str: schema_str = "" table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) if len(table_objs) == 0: return schema_str db_name = table_objs[0].schema schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" + tables = [] + all_tables = [] # temp save all tables for obj in table_objs: - schema_str += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" else f"# Table: {obj.table.table_name}" + schema_table = '' + schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}" table_comment = '' if obj.table.custom_comment: table_comment = obj.table.custom_comment.strip() if table_comment == '': - schema_str += '\n[\n' + schema_table += '\n[\n' else: - schema_str += f", {table_comment}\n[\n" - - field_list = [] - for field in obj.fields: - field_comment = '' - if field.custom_comment: - field_comment = field.custom_comment.strip() - if field_comment == '': - field_list.append(f"({field.field_name}:{field.field_type})") - else: - field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})") - schema_str += ",\n".join(field_list) - schema_str += '\n]\n' - # todo 外键 + schema_table += f", {table_comment}\n[\n" + + if obj.fields: + field_list = [] + for field in obj.fields: + field_comment = '' + if field.custom_comment: + field_comment = field.custom_comment.strip() + if field_comment == '': + field_list.append(f"({field.field_name}:{field.field_type})") + else: + field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})") + schema_table += ",\n".join(field_list) + schema_table += '\n]\n' + + t_obj = {"id": obj.table.id, "schema_table": schema_table} + tables.append(t_obj) + all_tables.append(t_obj) + + # do table embedding + if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: + tables = get_table_embedding(session, current_user, tables, question) + # splice schema + if tables: + for s in tables: + schema_str += s.get('schema_table') + + # field relation + if tables and ds.table_relation: + relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation)) + if relations: + # Complete the missing table + # get tables in relation, remove irrelevant relation + embedding_table_ids = [s.get('id') for s in tables] + all_relations = list( + filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get( + 'cell') in embedding_table_ids, relations)) + + # get relation table ids, sub embedding table ids + relation_table_ids = [] + for r in all_relations: + relation_table_ids.append(r.get('source').get('cell')) + relation_table_ids.append(r.get('target').get('cell')) + relation_table_ids = list(set(relation_table_ids)) + # get table dict + table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all() + table_dict = {} + for ele in table_records: + table_dict[ele.id] = ele.table_name + + # get lost table ids + lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids)) + # get lost table schema and splice it + lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables)) + if lost_tables: + for s in lost_tables: + schema_str += s.get('schema_table') + + # get field dict + relation_field_ids = [] + for relation in all_relations: + relation_field_ids.append(relation.get('source').get('port')) + relation_field_ids.append(relation.get('target').get('port')) + relation_field_ids = list(set(relation_field_ids)) + field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all() + field_dict = {} + for ele in field_records: + field_dict[ele.id] = ele.field_name + + if all_relations: + schema_str += '【Foreign keys】\n' + for ele in all_relations: + schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" + return schema_str diff --git a/backend/apps/datasource/crud/permission.py b/backend/apps/datasource/crud/permission.py index 80c32efd..87071812 100644 --- a/backend/apps/datasource/crud/permission.py +++ b/backend/apps/datasource/crud/permission.py @@ -2,13 +2,15 @@ from typing import List, Optional from sqlalchemy import and_ -from apps.datasource.crud.row_permission import transFilterTree -from apps.datasource.models.datasource import CoreDatasource, CoreField, CoreTable -from common.core.deps import CurrentUser, SessionDep from sqlbot_xpack.permissions.api.permission import transRecord2DTO from sqlbot_xpack.permissions.models.ds_permission import DsPermission, PermissionDTO from sqlbot_xpack.permissions.models.ds_rules import DsRules +from apps.datasource.crud.row_permission import transFilterTree +from apps.datasource.models.datasource import CoreDatasource, CoreField, CoreTable +from common.core.deps import CurrentUser, SessionDep + + def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, tables: Optional[list] = None, single_table: Optional[CoreTable] = None): if single_table: @@ -20,10 +22,10 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d filters = [] if is_normal_user(current_user): + contain_rules = session.query(DsRules).all() for table in table_list: row_permissions = session.query(DsPermission).filter( and_(DsPermission.table_id == table.id, DsPermission.type == 'row')).all() - contain_rules = session.query(DsRules).all() res: List[PermissionDTO] = [] if row_permissions is not None: for permission in row_permissions: @@ -35,6 +37,7 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d if p_list is not None and u_list is not None and permission.id in p_list and ( current_user.id in u_list or f'{current_user.id}' in u_list): flag = True + break if flag: res.append(transRecord2DTO(session, permission)) where_str = transFilterTree(session, res, ds) @@ -43,11 +46,10 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d def get_column_permission_fields(session: SessionDep, current_user: CurrentUser, table: CoreTable, - fields: list[CoreField]): + fields: list[CoreField], contain_rules: list[DsRules]): if is_normal_user(current_user): column_permissions = session.query(DsPermission).filter( and_(DsPermission.table_id == table.id, DsPermission.type == 'column')).all() - contain_rules = session.query(DsRules).all() if column_permissions is not None: for permission in column_permissions: # check permission and user in same rules @@ -63,6 +65,7 @@ def get_column_permission_fields(session: SessionDep, current_user: CurrentUser, if p_list is not None and u_list is not None and permission.id in p_list and ( current_user.id in u_list or f'{current_user.id}' in u_list): flag = True + break if flag: permission_list = json.loads(permission.permissions) fields = filter_list(fields, permission_list) diff --git a/backend/apps/datasource/embedding/__init__.py b/backend/apps/datasource/embedding/__init__.py new file mode 100644 index 00000000..87bb6f5d --- /dev/null +++ b/backend/apps/datasource/embedding/__init__.py @@ -0,0 +1,2 @@ +# Author: Junjun +# Date: 2025/9/18 diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py new file mode 100644 index 00000000..7aad7d29 --- /dev/null +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -0,0 +1,59 @@ +# Author: Junjun +# Date: 2025/9/18 +import json +import traceback +from typing import Optional + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.crud.datasource import get_table_schema +from apps.datasource.embedding.utils import cosine_similarity +from apps.datasource.models.datasource import CoreDatasource +from apps.system.crud.assistant import AssistantOutDs +from common.core.deps import CurrentAssistant +from common.core.deps import SessionDep, CurrentUser +from common.utils.utils import SQLBotLogUtil + + +def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, out_ds: AssistantOutDs, + question: str, + current_assistant: Optional[CurrentAssistant] = None): + _list = [] + if current_assistant and current_assistant.type != 4: + if out_ds.ds_list: + for _ds in out_ds.ds_list: + ds = out_ds.get_ds(_ds.id) + table_schema = out_ds.get_db_schema(_ds.id) + ds_info = f"{ds.name}, {ds.description}\n" + ds_schema = ds_info + table_schema + _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) + else: + for _ds in _ds_list: + if _ds.get('id'): + ds = session.get(CoreDatasource, _ds.get('id')) + table_schema = get_table_schema(session, current_user, ds, question, embedding=False) + ds_info = f"{ds.name}, {ds.description}\n" + ds_schema = ds_info + table_schema + _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) + + if _list: + try: + text = [s.get('ds_schema') for s in _list] + + model = EmbeddingModelCache.get_model() + results = model.embed_documents(text) + + q_embedding = model.embed_query(question) + for index in range(len(results)): + item = results[index] + _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item) + + _list.sort(key=lambda x: x['cosine_similarity'], reverse=True) + # print(len(_list)) + SQLBotLogUtil.info(json.dumps( + [{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")} + for ele in _list])) + ds = _list[0].get('ds') + return {"id": ds.id, "name": ds.name, "description": ds.description} + except Exception: + traceback.print_exc() + return _list diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py new file mode 100644 index 00000000..1a3fe896 --- /dev/null +++ b/backend/apps/datasource/embedding/table_embedding.py @@ -0,0 +1,41 @@ +# Author: Junjun +# Date: 2025/9/23 +import json +import time +import traceback + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.embedding.utils import cosine_similarity +from common.core.config import settings +from common.core.deps import SessionDep, CurrentUser +from common.utils.utils import SQLBotLogUtil + + +def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str): + _list = [] + for table in tables: + _list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0}) + + if _list: + try: + text = [s.get('schema_table') for s in _list] + + model = EmbeddingModelCache.get_model() + start_time = time.time() + results = model.embed_documents(text) + end_time = time.time() + SQLBotLogUtil.info(str(end_time - start_time)) + + q_embedding = model.embed_query(question) + for index in range(len(results)): + item = results[index] + _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item) + + _list.sort(key=lambda x: x['cosine_similarity'], reverse=True) + _list = _list[:settings.TABLE_EMBEDDING_COUNT] + # print(len(_list)) + SQLBotLogUtil.info(json.dumps(_list)) + return _list + except Exception: + traceback.print_exc() + return _list diff --git a/backend/apps/datasource/embedding/utils.py b/backend/apps/datasource/embedding/utils.py new file mode 100644 index 00000000..3f6ddced --- /dev/null +++ b/backend/apps/datasource/embedding/utils.py @@ -0,0 +1,18 @@ +# Author: Junjun +# Date: 2025/9/23 +import math + + +def cosine_similarity(vec_a, vec_b): + if len(vec_a) != len(vec_b): + raise ValueError("The vector dimension must be the same") + + dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) + + norm_a = math.sqrt(sum(a * a for a in vec_a)) + norm_b = math.sqrt(sum(b * b for b in vec_b)) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return dot_product / (norm_a * norm_b) diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py index 6496a857..78f23916 100644 --- a/backend/apps/datasource/models/datasource.py +++ b/backend/apps/datasource/models/datasource.py @@ -2,13 +2,14 @@ from typing import List, Optional from pydantic import BaseModel -from sqlalchemy import Column, Text, BigInteger, DateTime, Integer, Identity +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import SQLModel, Field class CoreDatasource(SQLModel, table=True): __tablename__ = "core_datasource" - id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True)) + id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True)) name: str = Field(max_length=128, nullable=False) description: str = Field(max_length=512, nullable=True) type: str = Field(max_length=64) @@ -19,11 +20,12 @@ class CoreDatasource(SQLModel, table=True): status: str = Field(max_length=64, nullable=True) num: str = Field(max_length=256, nullable=True) oid: int = Field(sa_column=Column(BigInteger())) + table_relation: List = Field(sa_column=Column(JSONB, nullable=True)) class CoreTable(SQLModel, table=True): __tablename__ = "core_table" - id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True)) + id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True)) ds_id: int = Field(sa_column=Column(BigInteger())) checked: bool = Field(default=True) table_name: str = Field(sa_column=Column(Text)) @@ -33,7 +35,7 @@ class CoreTable(SQLModel, table=True): class CoreField(SQLModel, table=True): __tablename__ = "core_field" - id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True)) + id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True)) ds_id: int = Field(sa_column=Column(BigInteger())) table_id: int = Field(sa_column=Column(BigInteger())) checked: bool = Field(default=True) diff --git a/backend/apps/db/constant.py b/backend/apps/db/constant.py index dc4b5807..28d1ad0b 100644 --- a/backend/apps/db/constant.py +++ b/backend/apps/db/constant.py @@ -13,16 +13,21 @@ def __init__(self, type_name): class DB(Enum): - mysql = ('mysql', '`', '`', ConnectType.sqlalchemy) - sqlServer = ('sqlServer', '[', ']', ConnectType.sqlalchemy) - pg = ('pg', '"', '"', ConnectType.sqlalchemy) - excel = ('excel', '"', '"', ConnectType.sqlalchemy) - oracle = ('oracle', '"', '"', ConnectType.sqlalchemy) - ck = ('ck', '"', '"', ConnectType.sqlalchemy) - dm = ('dm', '"', '"', ConnectType.py_driver) - - def __init__(self, type, prefix, suffix, connect_type: ConnectType): + mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy) + sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy) + pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy) + excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy) + oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy) + ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy) + dm = ('dm', '达梦', '"', '"', ConnectType.py_driver) + doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver) + redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver) + es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver) + kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver) + + def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType): self.type = type + self.db_name = db_name self.prefix = prefix self.suffix = suffix self.connect_type = connect_type diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index 67d0c448..c428a576 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -3,9 +3,17 @@ import platform import urllib.parse from decimal import Decimal +from typing import Optional + +import psycopg2 +import pymssql +from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql +from common.error import ParseSQLResultError if platform.system() != "Darwin": import dmPython +import pymysql +import redshift_connector from sqlalchemy import create_engine, text, Engine from sqlalchemy.orm import sessionmaker @@ -15,6 +23,10 @@ from apps.db.engine import get_engine_config from apps.system.crud.assistant import get_ds_engine from apps.system.schemas.system_schema import AssistantOutDsSchema +from common.core.deps import Trans +from common.utils.utils import SQLBotLogUtil +from fastapi import HTTPException +from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http def get_uri(ds: CoreDatasource) -> str: @@ -60,6 +72,35 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str: return db_url +def get_extra_config(conf: DatasourceConf): + config_dict = {} + if conf.extraJdbc: + config_arr = conf.extraJdbc.split("&") + for config in config_arr: + kv = config.split("=") + if len(kv) == 2 and kv[0] and kv[1]: + config_dict[kv[0]] = kv[1] + else: + raise Exception(f'param: {config} is error') + return config_dict + + +def get_origin_connect(type: str, conf: DatasourceConf): + extra_config_dict = get_extra_config(conf) + if type == "sqlServer": + return pymssql.connect( + server=conf.host, + port=str(conf.port), + user=conf.username, + password=conf.password, + database=conf.database, + timeout=conf.timeout, + tds_version='7.0', # options: '4.2', '7.0', '8.0' ..., + **extra_config_dict + ) + + +# use sqlalchemy def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() if conf.timeout is None: @@ -77,7 +118,8 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout) elif ds.type == 'sqlServer': - engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) + engine = create_engine('mssql+pymssql://', creator=lambda: get_origin_connect(ds.type, conf), + pool_timeout=conf.timeout) elif ds.type == 'oracle': engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) @@ -93,6 +135,145 @@ def get_session(ds: CoreDatasource | AssistantOutDsSchema): return session +def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDsSchema, is_raise: bool = False): + if isinstance(ds, CoreDatasource): + db = DB.get_db(ds.type) + if db.connect_type == ConnectType.sqlalchemy: + conn = get_engine(ds, 10) + try: + with conn.connect() as connection: + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + else: + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + extra_config_dict = get_extra_config(conf) + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1', timeout=10).fetchall() + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=10, + read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, + user=conf.username, + password=conf.password, + timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif ds.type == 'kingbase': + with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, + user=conf.username, + password=conf.password, + connect_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif ds.type == 'es': + es_conn = get_es_connect(conf) + if es_conn.ping(): + SQLBotLogUtil.info("success") + return True + else: + SQLBotLogUtil.info("failed") + return False + else: + conn = get_ds_engine(ds) + try: + with conn.connect() as connection: + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + + return False + + +def get_version(ds: CoreDatasource | AssistantOutDsSchema): + version = '' + conf = None + if isinstance(ds, CoreDatasource): + conf = DatasourceConf( + **json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + if isinstance(ds, AssistantOutDsSchema): + conf = DatasourceConf() + conf.host = ds.host + conf.port = ds.port + conf.username = ds.user + conf.password = ds.password + conf.database = ds.dataBase + conf.dbSchema = ds.db_schema + conf.timeout = 10 + db = DB.get_db(ds.type) + sql = get_version_sql(ds, conf) + try: + if db.connect_type == ConnectType.sqlalchemy: + with get_session(ds) as session: + with session.execute(text(sql)) as result: + res = result.fetchall() + version = res[0][0] + else: + extra_config_dict = get_extra_config(conf) + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + cursor.execute(sql, timeout=10, **extra_config_dict) + res = cursor.fetchall() + version = res[0][0] + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=10, + read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + version = res[0][0] + elif ds.type == 'redshift' or ds.type == 'es': + version = '' + except Exception as e: + print(e) + version = '' + return version.decode() if isinstance(version, bytes) else version + + def get_schema(ds: CoreDatasource): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) @@ -100,20 +281,38 @@ def get_schema(ds: CoreDatasource): with get_session(ds) as session: sql: str = '' if ds.type == "sqlServer": - sql = f"""select name from sys.schemas""" + sql = """select name from sys.schemas""" elif ds.type == "pg" or ds.type == "excel": sql = """SELECT nspname FROM pg_namespace""" elif ds.type == "oracle": - sql = f"""select * from all_users""" + sql = """select * from all_users""" with session.execute(text(sql)) as result: res = result.fetchall() res_list = [item[0] for item in res] return res_list else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - cursor.execute(f"""select OBJECT_NAME from dba_objects where object_type='SCH'""") + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute("""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout) + res = cursor.fetchall() + res_list = [item[0] for item in res] + return res_list + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute("""SELECT nspname FROM pg_namespace""") + res = cursor.fetchall() + res_list = [item[0] for item in res] + return res_list + elif ds.type == 'kingbase': + with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + options=f"-c statement_timeout={conf.timeout * 1000}", + **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute("""SELECT nspname FROM pg_namespace""") res = cursor.fetchall() res_list = [item[0] for item in res] return res_list @@ -122,214 +321,101 @@ def get_schema(ds: CoreDatasource): def get_tables(ds: CoreDatasource): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) + sql, sql_param = get_table_sql(ds, conf, get_version(ds)) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: - sql: str = '' - if ds.type == "mysql": - sql = f""" - SELECT - TABLE_NAME, - TABLE_COMMENT - FROM - information_schema.TABLES - WHERE - TABLE_SCHEMA = '{conf.database}' - """ - elif ds.type == "sqlServer": - sql = f""" - SELECT - TABLE_NAME AS [TABLE_NAME], - ISNULL(ep.value, '') AS [TABLE_COMMENT] - FROM - INFORMATION_SCHEMA.TABLES t - LEFT JOIN - sys.extended_properties ep - ON ep.major_id = OBJECT_ID(t.TABLE_SCHEMA + '.' + t.TABLE_NAME) - AND ep.minor_id = 0 - AND ep.name = 'MS_Description' - WHERE - t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') - AND t.TABLE_SCHEMA = '{conf.dbSchema}' - """ - elif ds.type == "pg" or ds.type == "excel": - sql = """ - SELECT c.relname AS TABLE_NAME, - COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT - FROM pg_class c - LEFT JOIN - pg_namespace n ON n.oid = c.relnamespace - LEFT JOIN - pg_description d ON d.objoid = c.oid AND d.objsubid = 0 - WHERE n.nspname = current_schema() - AND c.relkind IN ('r', 'v', 'p', 'm') - AND c.relname NOT LIKE 'pg_%' - AND c.relname NOT LIKE 'sql_%' - ORDER BY c.relname \ - """ - elif ds.type == "oracle": - sql = f""" - SELECT - t.TABLE_NAME AS "TABLE_NAME", - NVL(c.COMMENTS, '') AS "TABLE_COMMENT" - FROM ( - SELECT TABLE_NAME, 'TABLE' AS OBJECT_TYPE - FROM DBA_TABLES - WHERE OWNER = '{conf.dbSchema}' - UNION ALL - SELECT VIEW_NAME AS TABLE_NAME, 'VIEW' AS OBJECT_TYPE - FROM DBA_VIEWS - WHERE OWNER = '{conf.dbSchema}' - ) t - LEFT JOIN DBA_TAB_COMMENTS c - ON t.TABLE_NAME = c.TABLE_NAME - AND c.TABLE_TYPE = t.OBJECT_TYPE - AND c.OWNER = '{conf.dbSchema}' - ORDER BY t.TABLE_NAME - """ - elif ds.type == "ck": - sql = f""" - SELECT name, comment - FROM system.tables - WHERE database = '{conf.database}' - AND engine NOT IN ('Dictionary') - ORDER BY name - """ - with session.execute(text(sql)) as result: + with session.execute(text(sql), {"param": sql_param}) as result: res = result.fetchall() res_list = [TableSchema(*item) for item in res] return res_list else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - cursor.execute(f"""select table_name, comments - from all_tab_comments - where owner='{conf.dbSchema}' - AND (table_type = 'TABLE' or table_type = 'VIEW') - """) + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, {"param": sql_param}, timeout=conf.timeout) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=conf.timeout, + read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (sql_param,)) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (sql_param,)) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list + elif ds.type == 'kingbase': + with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + options=f"-c statement_timeout={conf.timeout * 1000}", + **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql.format(sql_param)) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + elif ds.type == 'es': + res = get_es_index(conf) + res_list = [TableSchema(*item) for item in res] + return res_list def get_fields(ds: CoreDatasource, table_name: str = None): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) + sql, p1, p2 = get_field_sql(ds, conf, table_name) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: - sql: str = '' - if ds.type == "mysql": - sql1 = f""" - SELECT - COLUMN_NAME, - DATA_TYPE, - COLUMN_COMMENT - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA = '{conf.database}' - """ - sql2 = f" AND TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - sql = sql1 + sql2 - elif ds.type == "sqlServer": - sql1 = f""" - SELECT - COLUMN_NAME AS [COLUMN_NAME], - DATA_TYPE AS [DATA_TYPE], - ISNULL(EP.value, '') AS [COLUMN_COMMENT] - FROM - INFORMATION_SCHEMA.COLUMNS C - LEFT JOIN - sys.extended_properties EP - ON EP.major_id = OBJECT_ID(C.TABLE_SCHEMA + '.' + C.TABLE_NAME) - AND EP.minor_id = C.ORDINAL_POSITION - AND EP.name = 'MS_Description' - WHERE - C.TABLE_SCHEMA = '{conf.dbSchema}' - """ - sql2 = f" AND C.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - sql = sql1 + sql2 - elif ds.type == "pg" or ds.type == "excel": - sql1 = """ - SELECT a.attname AS COLUMN_NAME, - pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, - col_description(c.oid, a.attnum) AS COLUMN_COMMENT - FROM pg_catalog.pg_attribute a - JOIN - pg_catalog.pg_class c ON a.attrelid = c.oid - JOIN - pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = current_schema() - AND a.attnum > 0 - AND NOT a.attisdropped \ - """ - sql2 = f" AND c.relname = '{table_name}'" if table_name is not None and table_name != "" else "" - sql = sql1 + sql2 - elif ds.type == "oracle": - sql1 = f""" - SELECT - col.COLUMN_NAME AS "COLUMN_NAME", - (CASE - WHEN col.DATA_TYPE IN ('VARCHAR2', 'CHAR', 'NVARCHAR2', 'NCHAR') - THEN col.DATA_TYPE || '(' || col.DATA_LENGTH || ')' - WHEN col.DATA_TYPE = 'NUMBER' AND col.DATA_PRECISION IS NOT NULL - THEN col.DATA_TYPE || '(' || col.DATA_PRECISION || - CASE WHEN col.DATA_SCALE > 0 THEN ',' || col.DATA_SCALE END || ')' - ELSE col.DATA_TYPE - END) AS "DATA_TYPE", - NVL(com.COMMENTS, '') AS "COLUMN_COMMENT" - FROM - DBA_TAB_COLUMNS col - LEFT JOIN - DBA_COL_COMMENTS com - ON col.OWNER = com.OWNER - AND col.TABLE_NAME = com.TABLE_NAME - AND col.COLUMN_NAME = com.COLUMN_NAME - WHERE - col.OWNER = '{conf.dbSchema}' - """ - sql2 = f" AND col.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - sql = sql1 + sql2 - elif ds.type == "ck": - sql1 = f""" - SELECT - name AS COLUMN_NAME, - type AS DATA_TYPE, - comment AS COLUMN_COMMENT - FROM system.columns - WHERE database = '{conf.database}' - """ - sql2 = f" AND table = '{table_name}'" if table_name is not None and table_name != "" else "" - sql = sql1 + sql2 - - with session.execute(text(sql)) as result: + with session.execute(text(sql), {"param1": p1, "param2": p2}) as result: res = result.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - sql1 = f""" - SELECT - c.COLUMN_NAME AS "COLUMN_NAME", - c.DATA_TYPE AS "DATA_TYPE", - COALESCE(com.COMMENTS, '') AS "COMMENTS" - FROM - ALL_TAB_COLS c - LEFT JOIN - ALL_COL_COMMENTS com - ON c.OWNER = com.OWNER - AND c.TABLE_NAME = com.TABLE_NAME - AND c.COLUMN_NAME = com.COLUMN_NAME - WHERE - c.OWNER = '{conf.dbSchema}' - """ - sql2 = f" AND c.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - cursor.execute(sql1 + sql2) + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, {"param1": p1, "param2": p2}, timeout=conf.timeout) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=conf.timeout, + read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (p1, p2)) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (p1, p2)) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + elif ds.type == 'kingbase': + with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + options=f"-c statement_timeout={conf.timeout * 1000}", + **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql.format(p1, p2)) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + elif ds.type == 'es': + res = get_es_fields(conf, table_name) + res_list = [ColumnSchema(*item) for item in res] + return res_list def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False): @@ -351,12 +437,71 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= return {"fields": columns, "data": result_list, "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: - raise ex + raise ParseSQLResultError(str(ex)) else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute(sql, timeout=conf.timeout) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=conf.timeout, + read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute(sql) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute(sql) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + elif ds.type == 'kingbase': + with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + options=f"-c statement_timeout={conf.timeout * 1000}", + **extra_config_dict) as conn, conn.cursor() as cursor: try: cursor.execute(sql) res = cursor.fetchall() @@ -371,4 +516,19 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= return {"fields": columns, "data": result_list, "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: - raise ex + raise ParseSQLResultError(str(ex)) + elif ds.type == 'es': + try: + res, columns = get_es_data_by_http(conf, sql) + columns = [field.get('name') for field in columns] if origin_column else [field.get('name').lower() for + field in + columns] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple(tuple_item))} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise Exception(str(ex)) diff --git a/backend/apps/db/db_sql.py b/backend/apps/db/db_sql.py new file mode 100644 index 00000000..bced9a22 --- /dev/null +++ b/backend/apps/db/db_sql.py @@ -0,0 +1,310 @@ +# Author: Junjun +# Date: 2025/8/20 +from apps.datasource.models.datasource import CoreDatasource, DatasourceConf + + +def get_version_sql(ds: CoreDatasource, conf: DatasourceConf): + if ds.type == "mysql" or ds.type == "doris": + return """ + SELECT VERSION() + """ + elif ds.type == "sqlServer": + return """ + select SERVERPROPERTY('ProductVersion') + """ + elif ds.type == "pg" or ds.type == "kingbase" or ds.type == "excel": + return """ + SELECT current_setting('server_version') + """ + elif ds.type == "oracle": + return """ + SELECT version FROM v$instance + """ + elif ds.type == "ck": + return """ + select version() + """ + elif ds.type == 'dm': + return """ + SELECT * FROM v$version + """ + elif ds.type == 'redshift': + return '' + + +def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''): + if ds.type == "mysql": + return """ + SELECT + TABLE_NAME, + TABLE_COMMENT + FROM + information_schema.TABLES + WHERE + TABLE_SCHEMA = :param + """, conf.database + elif ds.type == "sqlServer": + return """ + SELECT + TABLE_NAME AS [TABLE_NAME], + ISNULL(ep.value, '') AS [TABLE_COMMENT] + FROM + INFORMATION_SCHEMA.TABLES t + LEFT JOIN + sys.extended_properties ep + ON ep.major_id = OBJECT_ID(t.TABLE_SCHEMA + '.' + t.TABLE_NAME) + AND ep.minor_id = 0 + AND ep.name = 'MS_Description' + WHERE + t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') + AND t.TABLE_SCHEMA = :param + """, conf.dbSchema + elif ds.type == "pg" or ds.type == "excel": + return """ + SELECT c.relname AS TABLE_NAME, + COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT + FROM pg_class c + LEFT JOIN + pg_namespace n ON n.oid = c.relnamespace + LEFT JOIN + pg_description d ON d.objoid = c.oid AND d.objsubid = 0 + WHERE n.nspname = :param + AND c.relkind IN ('r', 'v', 'p', 'm') + AND c.relname NOT LIKE 'pg_%' + AND c.relname NOT LIKE 'sql_%' + ORDER BY c.relname \ + """, conf.dbSchema + elif ds.type == "oracle": + return """ + SELECT + t.TABLE_NAME AS "TABLE_NAME", + NVL(c.COMMENTS, '') AS "TABLE_COMMENT" + FROM ( + SELECT TABLE_NAME, 'TABLE' AS OBJECT_TYPE + FROM DBA_TABLES + WHERE OWNER = :param + UNION ALL + SELECT VIEW_NAME AS TABLE_NAME, 'VIEW' AS OBJECT_TYPE + FROM DBA_VIEWS + WHERE OWNER = :param + ) t + LEFT JOIN DBA_TAB_COMMENTS c + ON t.TABLE_NAME = c.TABLE_NAME + AND c.TABLE_TYPE = t.OBJECT_TYPE + AND c.OWNER = :param + ORDER BY t.TABLE_NAME + """, conf.dbSchema + elif ds.type == "ck": + version = int(db_version.split('.')[0]) + if version < 22: + return """ + SELECT name, null as comment + FROM system.tables + WHERE database = :param + AND engine NOT IN ('Dictionary') + ORDER BY name + """, conf.database + else: + return """ + SELECT name, comment + FROM system.tables + WHERE database = :param + AND engine NOT IN ('Dictionary') + ORDER BY name + """, conf.database + elif ds.type == 'dm': + return """ + select table_name, comments + from all_tab_comments + where owner=:param + AND (table_type = 'TABLE' or table_type = 'VIEW') + """, conf.dbSchema + elif ds.type == 'redshift': + return """ + SELECT + relname AS TableName, + obj_description(relfilenode::regclass, 'pg_class') AS TableDescription + FROM + pg_class + WHERE + relkind in ('r','p', 'f') + AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = %s) + """, conf.dbSchema + elif ds.type == "doris": + return """ + SELECT + TABLE_NAME, + TABLE_COMMENT + FROM + information_schema.TABLES + WHERE + TABLE_SCHEMA = %s + """, conf.database + elif ds.type == "kingbase": + return """ + SELECT c.relname AS TABLE_NAME, + COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT + FROM pg_class c + LEFT JOIN + pg_namespace n ON n.oid = c.relnamespace + LEFT JOIN + pg_description d ON d.objoid = c.oid AND d.objsubid = 0 + WHERE n.nspname = '{0}' + AND c.relkind IN ('r', 'v', 'p', 'm') + AND c.relname NOT LIKE 'pg_%' + AND c.relname NOT LIKE 'sql_%' + ORDER BY c.relname \ + """, conf.dbSchema + elif ds.type == "es": + return "", None + + +def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = None): + if ds.type == "mysql": + sql1 = """ + SELECT + COLUMN_NAME, + DATA_TYPE, + COLUMN_COMMENT + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = :param1 + """ + sql2 = " AND TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.database, table_name + elif ds.type == "sqlServer": + sql1 = """ + SELECT + COLUMN_NAME AS [COLUMN_NAME], + DATA_TYPE AS [DATA_TYPE], + ISNULL(EP.value, '') AS [COLUMN_COMMENT] + FROM + INFORMATION_SCHEMA.COLUMNS C + LEFT JOIN + sys.extended_properties EP + ON EP.major_id = OBJECT_ID(C.TABLE_SCHEMA + '.' + C.TABLE_NAME) + AND EP.minor_id = C.ORDINAL_POSITION + AND EP.name = 'MS_Description' + WHERE + C.TABLE_SCHEMA = :param1 + """ + sql2 = " AND C.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "pg" or ds.type == "excel": + sql1 = """ + SELECT a.attname AS COLUMN_NAME, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, + col_description(c.oid, a.attnum) AS COLUMN_COMMENT + FROM pg_catalog.pg_attribute a + JOIN + pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN + pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = :param1 + AND a.attnum > 0 + AND NOT a.attisdropped \ + """ + sql2 = " AND c.relname = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "redshift": + sql1 = """ + SELECT a.attname AS COLUMN_NAME, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, + col_description(c.oid, a.attnum) AS COLUMN_COMMENT + FROM pg_catalog.pg_attribute a + JOIN + pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN + pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = %s + AND a.attnum > 0 + AND NOT a.attisdropped \ + """ + sql2 = " AND c.relname = %s" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "oracle": + sql1 = """ + SELECT + col.COLUMN_NAME AS "COLUMN_NAME", + (CASE + WHEN col.DATA_TYPE IN ('VARCHAR2', 'CHAR', 'NVARCHAR2', 'NCHAR') + THEN col.DATA_TYPE || '(' || col.DATA_LENGTH || ')' + WHEN col.DATA_TYPE = 'NUMBER' AND col.DATA_PRECISION IS NOT NULL + THEN col.DATA_TYPE || '(' || col.DATA_PRECISION || + CASE WHEN col.DATA_SCALE > 0 THEN ',' || col.DATA_SCALE END || ')' + ELSE col.DATA_TYPE + END) AS "DATA_TYPE", + NVL(com.COMMENTS, '') AS "COLUMN_COMMENT" + FROM + DBA_TAB_COLUMNS col + LEFT JOIN + DBA_COL_COMMENTS com + ON col.OWNER = com.OWNER + AND col.TABLE_NAME = com.TABLE_NAME + AND col.COLUMN_NAME = com.COLUMN_NAME + WHERE + col.OWNER = :param1 + """ + sql2 = " AND col.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "ck": + sql1 = """ + SELECT + name AS COLUMN_NAME, + type AS DATA_TYPE, + comment AS COLUMN_COMMENT + FROM system.columns + WHERE database = :param1 + """ + sql2 = " AND table = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.database, table_name + elif ds.type == 'dm': + sql1 = """ + SELECT + c.COLUMN_NAME AS "COLUMN_NAME", + c.DATA_TYPE AS "DATA_TYPE", + COALESCE(com.COMMENTS, '') AS "COMMENTS" + FROM + ALL_TAB_COLS c + LEFT JOIN + ALL_COL_COMMENTS com + ON c.OWNER = com.OWNER + AND c.TABLE_NAME = com.TABLE_NAME + AND c.COLUMN_NAME = com.COLUMN_NAME + WHERE + c.OWNER = :param1 + """ + sql2 = " AND c.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "doris": + sql1 = """ + SELECT + COLUMN_NAME, + DATA_TYPE, + COLUMN_COMMENT + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = %s + """ + sql2 = " AND TABLE_NAME = %s" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.database, table_name + elif ds.type == "kingbase": + sql1 = """ + SELECT a.attname AS COLUMN_NAME, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, + col_description(c.oid, a.attnum) AS COLUMN_COMMENT + FROM pg_catalog.pg_attribute a + JOIN + pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN + pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = '{0}' + AND a.attnum > 0 + AND NOT a.attisdropped \ + """ + sql2 = " AND c.relname = '{1}'" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "es": + return "", None, None diff --git a/backend/apps/db/es_engine.py b/backend/apps/db/es_engine.py new file mode 100644 index 00000000..f8cf1f2f --- /dev/null +++ b/backend/apps/db/es_engine.py @@ -0,0 +1,117 @@ +# Author: Junjun +# Date: 2025/9/9 + +import json +from base64 import b64encode + +import requests +from elasticsearch import Elasticsearch + +from apps.datasource.models.datasource import DatasourceConf +from common.error import SingleMessageError + + +def get_es_connect(conf: DatasourceConf): + es_client = Elasticsearch( + [conf.host], # ES address + basic_auth=(conf.username, conf.password), + verify_certs=False, + compatibility_mode=True + ) + return es_client + + +# get tables +def get_es_index(conf: DatasourceConf): + es_client = get_es_connect(conf) + indices = es_client.cat.indices(format="json") + res = [] + if indices is not None: + for idx in indices: + index_name = idx.get('index') + desc = '' + # get mapping + mapping = es_client.indices.get_mapping(index=index_name) + mappings = mapping.get(index_name).get("mappings") + if mappings.get('_meta'): + desc = mappings.get('_meta').get('description') + res.append((index_name, desc)) + return res + + +# get fields +def get_es_fields(conf: DatasourceConf, table_name: str): + es_client = get_es_connect(conf) + index_name = table_name + mapping = es_client.indices.get_mapping(index=index_name) + properties = mapping.get(index_name).get("mappings").get("properties") + res = [] + if properties is not None: + for field, config in properties.items(): + field_type = config.get("type") + desc = '' + if config.get("_meta"): + desc = config.get("_meta").get('description') + + if field_type: + res.append((field, field_type, desc)) + else: + # object、nested... + res.append((field, ','.join(list(config.keys())), desc)) + return res + + +# def get_es_data(conf: DatasourceConf, sql: str, table_name: str): +# r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql}) +# if r.json().get('error'): +# print(json.dumps(r.json())) +# +# es_client = get_es_connect(conf) +# response = es_client.search( +# index=table_name, +# body=json.dumps(r.json()) +# ) +# +# # print(response) +# fields = get_es_fields(conf, table_name) +# res = [] +# for hit in response.get('hits').get('hits'): +# item = [] +# if 'fields' in hit: +# result = hit.get('fields') # {'title': ['Python'], 'age': [30]} +# for field in fields: +# v = result.get(field[0]) +# item.append(v[0]) if v else item.append(None) +# res.append(tuple(item)) +# # print(hit['fields']['title'][0]) +# # elif '_source' in hit: +# # print(hit.get('_source')) +# return res, fields + + +def get_es_data_by_http(conf: DatasourceConf, sql: str): + url = conf.host + while url.endswith('/'): + url = url[:-1] + + host = f'{url}/_sql?format=json' + username = f"{conf.username}" + password = f"{conf.password}" + + credentials = f"{username}:{password}" + encoded_credentials = b64encode(credentials.encode()).decode() + + headers = { + "Content-Type": "application/json", + "Authorization": f"Basic {encoded_credentials}" + } + + response = requests.post(host, data=json.dumps({"query": sql}), headers=headers) + + # print(response.json()) + res = response.json() + if res.get('error'): + raise SingleMessageError(json.dumps(res)) + fields = res.get('columns') + result = res.get('rows') + return result, fields diff --git a/backend/apps/db/type.py b/backend/apps/db/type.py deleted file mode 100644 index d106a437..00000000 --- a/backend/apps/db/type.py +++ /dev/null @@ -1,15 +0,0 @@ -# Author: Junjun -# Date: 2025/5/22 -from typing import Dict - - -def db_type_relation() -> Dict: - return { - "mysql": "MySQL", - "sqlServer": "Microsoft SQL Server", - "pg": "PostgreSQL", - "excel": "Excel/CSV", - "oracle": "Oracle", - "ck": "ClickHouse", - "dm": "达梦" - } diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index 19afffe7..433e4727 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -1,6 +1,7 @@ # Author: Junjun # Date: 2025/7/1 - +import json +import traceback from datetime import timedelta import jwt @@ -10,15 +11,17 @@ from jwt.exceptions import InvalidTokenError from pydantic import ValidationError from sqlmodel import select +from starlette.responses import JSONResponse from apps.chat.api.chat import create_chat -from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion +from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion, \ + ChatFinishStep from apps.chat.task.llm import LLMService from apps.system.crud.user import authenticate from apps.system.crud.user import get_db_user from apps.system.models.system_model import UserWsModel from apps.system.models.user import UserModel -from apps.system.schemas.system_schema import BaseUserDTO +from apps.system.schemas.system_schema import BaseUserDTO, AssistantHeader from apps.system.schemas.system_schema import UserInfoDTO from common.core import security from common.core.config import settings @@ -106,8 +109,93 @@ async def mcp_question(session: SessionDep, chat: McpQuestion): raise HTTPException(status_code=400, detail="Inactive user") mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question) - # ask - llm_service = LLMService(session_user, mcp_chat) - llm_service.init_record() - return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream") + try: + llm_service = await LLMService.create(session_user, mcp_chat) + llm_service.init_record() + llm_service.run_task_async(False, chat.stream) + except Exception as e: + traceback.print_exc() + + if chat.stream: + def _err(_e: Exception): + yield str(_e) + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) + if chat.stream: + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + else: + res = llm_service.await_result() + raw_data = {} + for chunk in res: + if chunk: + raw_data = chunk + status_code = 200 + if not raw_data.get('success'): + status_code = 500 + + return JSONResponse( + content=raw_data, + status_code=status_code, + ) + + +@router.post("/mcp_assistant", operation_id="mcp_assistant") +async def mcp_assistant(session: SessionDep, chat: McpAssistant): + session_user = BaseUserDTO(**{ + "id": -1, "account": 'sqlbot-mcp-assistant', "oid": 1, "assistant_id": -1, "password": '', "language": "zh-CN" + }) + # session_user: UserModel = get_db_user(session=session, user_id=1) + # session_user.oid = 1 + c = create_chat(session, session_user, CreateChat(origin=1), False) + + # build assistant param + configuration = {"endpoint": chat.url} + # authorization = [{"key": "x-de-token", + # "value": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1aWQiOjEsIm9pZCI6MSwiZXhwIjoxNzU4NTEyMDA2fQ.3NR-pgnADLdXZtI3dXX5-LuxfGYRvYD9kkr2de7KRP0", + # "target": "header"}] + mcp_assistant_header = AssistantHeader(id=1, name='mcp_assist', domain='', type=1, + configuration=json.dumps(configuration), + certificate=chat.authorization) + + # assistant question + mcp_chat = ChatQuestion(chat_id=c.id, question=chat.question) + # ask + try: + llm_service = await LLMService.create(session_user, mcp_chat, mcp_assistant_header) + llm_service.init_record() + llm_service.run_task_async(False, chat.stream, ChatFinishStep.QUERY_DATA) + except Exception as e: + traceback.print_exc() + + if chat.stream: + def _err(_e: Exception): + yield str(_e) + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) + if chat.stream: + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + else: + res = llm_service.await_result() + raw_data = {} + for chunk in res: + if chunk: + raw_data = chunk + status_code = 200 + if not raw_data.get('success'): + status_code = 500 + + return JSONResponse( + content=raw_data, + status_code=status_code, + ) diff --git a/backend/apps/system/api/aimodel.py b/backend/apps/system/api/aimodel.py index c33fd4ee..a96354da 100644 --- a/backend/apps/system/api/aimodel.py +++ b/backend/apps/system/api/aimodel.py @@ -9,6 +9,7 @@ from apps.system.models.system_model import AiModelDetail from common.core.deps import SessionDep, Trans +from common.utils.crypto import sqlbot_decrypt from common.utils.time import get_timestamp from common.utils.utils import SQLBotLogUtil, prepare_model_arg @@ -102,6 +103,10 @@ async def get_model_by_id( config_list = [AiModelConfigItem(**item) for item in raw] except Exception: pass + if db_model.api_key: + db_model.api_key = await sqlbot_decrypt(db_model.api_key) + if db_model.api_domain: + db_model.api_domain = await sqlbot_decrypt(db_model.api_domain) data = AiModelDetail.model_validate(db_model).model_dump(exclude_unset=True) data.pop("config", None) data["config_list"] = config_list diff --git a/backend/apps/system/api/assistant.py b/backend/apps/system/api/assistant.py index bd93e5d9..ea41e6d7 100644 --- a/backend/apps/system/api/assistant.py +++ b/backend/apps/system/api/assistant.py @@ -31,6 +31,25 @@ async def info(request: Request, response: Response, session: SessionDep, trans: db_model = AssistantModel.model_validate(db_model) response.headers["Access-Control-Allow-Origin"] = db_model.domain origin = request.headers.get("origin") or get_origin_from_referer(request) + if not origin: + raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or '')) + origin = origin.rstrip('/') + if origin != db_model.domain: + raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or '')) + return db_model + +@router.get("/app/{appId}") +async def getApp(request: Request, response: Response, session: SessionDep, trans: Trans, appId: str) -> AssistantModel: + if not appId: + raise Exception('miss assistant appId') + db_model = session.exec(select(AssistantModel).where(AssistantModel.app_id == appId)).first() + if not db_model: + raise RuntimeError(f"assistant application not exist") + db_model = AssistantModel.model_validate(db_model) + response.headers["Access-Control-Allow-Origin"] = db_model.domain + origin = request.headers.get("origin") or get_origin_from_referer(request) + if not origin: + raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or '')) origin = origin.rstrip('/') if origin != db_model.domain: raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or '')) @@ -45,9 +64,15 @@ async def validator(session: SessionDep, id: int, virtual: Optional[int] = Query if not db_model: return AssistantValidator() db_model = AssistantModel.model_validate(db_model) + assistant_oid = 1 + if(db_model.type == 0): + configuration = db_model.configuration + config_obj = json.loads(configuration) if configuration else {} + assistant_oid = config_obj.get('oid', 1) + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) assistantDict = { - "id": virtual, "account": 'sqlbot-inner-assistant', "oid": 1, "assistant_id": id + "id": virtual, "account": 'sqlbot-inner-assistant', "oid": assistant_oid, "assistant_id": id } access_token = create_access_token( assistantDict, expires_delta=access_token_expires @@ -90,12 +115,19 @@ async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] = file.filename = file_name if flag_name == 'logo' or flag_name == 'float_icon': SQLBotFileUtils.check_file(file=file, file_types=[".jpg", ".jpeg", ".png", ".svg"], limit_file_size=(10 * 1024 * 1024)) - SQLBotFileUtils.detete_file(config_obj.get(flag_name)) + if config_obj.get(flag_name): + SQLBotFileUtils.detete_file(config_obj.get(flag_name)) file_id = await SQLBotFileUtils.upload(file) ui_schema_dict[flag_name] = file_id else: raise ValueError(f"Unsupported file flag: {flag_name}") - + + for flag_name in ['logo', 'float_icon']: + file_val = config_obj.get(flag_name) + if file_val and not ui_schema_dict.get(flag_name): + config_obj[flag_name] = None + SQLBotFileUtils.detete_file(file_val) + for attr, value in ui_schema_dict.items(): if attr != 'id' and not attr.startswith("__"): config_obj[attr] = value @@ -103,15 +135,20 @@ async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] = db_model.configuration = json.dumps(config_obj, ensure_ascii=False) session.add(db_model) session.commit() + await clear_ui_cache(db_model.id) + +@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id") +async def clear_ui_cache(id: int): + pass @router.get("", response_model=list[AssistantModel]) async def query(session: SessionDep): - list_result = session.exec(select(AssistantModel).where(AssistantModel.type.in_([0, 1])).order_by(AssistantModel.name, AssistantModel.create_time)).all() + list_result = session.exec(select(AssistantModel).where(AssistantModel.type != 4).order_by(AssistantModel.name, AssistantModel.create_time)).all() return list_result @router.post("") async def add(request: Request, session: SessionDep, creator: AssistantBase): - save(request, session, creator) + await save(request, session, creator) @router.put("") diff --git a/backend/apps/system/api/user.py b/backend/apps/system/api/user.py index d6f33cda..7453e94e 100644 --- a/backend/apps/system/api/user.py +++ b/backend/apps/system/api/user.py @@ -193,7 +193,7 @@ async def batch_del(session: SessionDep, id_list: list[int]): @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id") async def langChange(session: SessionDep, current_user: CurrentUser, trans: Trans, language: UserLanguage): lang = language.language - if lang not in ["zh-CN", "en"]: + if lang not in ["zh-CN", "en", "ko-KR"]: raise Exception(trans('i18n_user.language_not_support', key = lang)) db_user: UserModel = get_db_user(session=session, user_id=current_user.id) db_user.language = lang diff --git a/backend/apps/system/api/workspace.py b/backend/apps/system/api/workspace.py index 0d795289..8cd05ac8 100644 --- a/backend/apps/system/api/workspace.py +++ b/backend/apps/system/api/workspace.py @@ -205,6 +205,8 @@ async def get_one(session: SessionDep, trans: Trans, id: int): @router.delete("/{id}") async def single_delete(session: SessionDep, id: int): + if id == 1: + raise HTTPException(f"Can not delete default workspace") db_model = session.get(WorkspaceModel, id) if not db_model: raise HTTPException(f"WorkspaceModel with id {id} not found") diff --git a/backend/apps/system/crud/aimodel_manage.py b/backend/apps/system/crud/aimodel_manage.py new file mode 100644 index 00000000..77d6cb5e --- /dev/null +++ b/backend/apps/system/crud/aimodel_manage.py @@ -0,0 +1,26 @@ + +from apps.system.models.system_model import AiModelDetail +from common.core.db import engine +from sqlmodel import Session, select +from common.utils.crypto import sqlbot_encrypt +from common.utils.utils import SQLBotLogUtil + +async def async_model_info(): + with Session(engine) as session: + model_list = session.exec(select(AiModelDetail)).all() + any_model_change = False + if model_list: + for model in model_list: + if model.api_domain.startswith("http"): + if model.api_key: + model.api_key = await sqlbot_encrypt(model.api_key) + if model.api_domain: + model.api_domain = await sqlbot_encrypt(model.api_domain) + session.add(model) + any_model_change = True + if any_model_change: + session.commit() + SQLBotLogUtil.info("✅ 异步加密已有模型的密钥和地址完成") + + + \ No newline at end of file diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 0699ec06..912218b9 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -15,9 +15,9 @@ from common.core.config import settings from common.core.db import engine from common.core.sqlbot_cache import cache +from common.utils.aes_crypto import simple_aes_decrypt from common.utils.utils import string_to_numeric_hash - @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id") async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None: db_model = session.get(AssistantModel, assistant_id) @@ -32,7 +32,7 @@ def get_assistant_user(*, id: int): def get_assistant_ds(session: Session, llm_service) -> list[dict]: assistant: AssistantHeader = llm_service.current_assistant type = assistant.type - if type == 0: + if type == 0 or type == 2: configuration = assistant.configuration if configuration: config: dict[any] = json.loads(configuration) @@ -40,9 +40,14 @@ def get_assistant_ds(session: Session, llm_service) -> list[dict]: stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where( CoreDatasource.oid == oid) if not assistant.online: - private_list: list[int] = config.get('private_list') or None + public_list: list[int] = config.get('public_list') or None + if public_list: + stmt = stmt.where(CoreDatasource.id.in_(public_list)) + else: + return [] + """ private_list: list[int] = config.get('private_list') or None if private_list: - stmt = stmt.where(~CoreDatasource.id.in_(private_list)) + stmt = stmt.where(~CoreDatasource.id.in_(private_list)) """ db_ds_list = session.exec(stmt) result_list = [ @@ -117,13 +122,13 @@ def get_ds_from_api(self): res = requests.get(url=endpoint, params=param, headers=header, cookies=cookies, timeout=10) if res.status_code == 200: result_json: dict[any] = json.loads(res.text) - if result_json.get('code') == 0: + if result_json.get('code') == 0 or result_json.get('code') == 200: temp_list = result_json.get('data', []) - self.ds_list = [ - self.convert2schema(item) + temp_ds_list = [ + self.convert2schema(item, config) for item in temp_list ] - + self.ds_list = temp_ds_list return self.ds_list else: raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}") @@ -169,9 +174,19 @@ def get_ds(self, ds_id: int): raise Exception("Datasource list is not found.") raise Exception(f"Datasource with id {ds_id} not found.") - def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema: + def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema: id_marker: str = '' attr_list = ['name', 'type', 'host', 'port', 'user', 'dataBase', 'schema'] + if config.get('encrypt', False): + key = config.get('aes_key', None) + iv = config.get('aes_iv', None) + aes_attrs = ['host', 'user', 'password', 'dataBase', 'db_schema', 'schema'] + for attr in aes_attrs: + if attr in ds_dict and ds_dict[attr]: + try: + ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv) + except Exception as e: + raise Exception(f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}") for attr in attr_list: if attr in ds_dict: id_marker += str(ds_dict.get(attr, '')) + '--sqlbot--' @@ -180,7 +195,6 @@ def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema: ds_dict.pop("schema", None) return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema}) - class AssistantOutDsFactory: @staticmethod def get_instance(assistant: AssistantHeader) -> AssistantOutDs: @@ -197,7 +211,7 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine: password=ds.password, database=ds.dataBase, driver='', - extraJdbc=ds.extraParams, + extraJdbc=ds.extraParams or '', dbSchema=ds.db_schema or '' ) conf.extraJdbc = '' diff --git a/backend/apps/system/middleware/auth.py b/backend/apps/system/middleware/auth.py index 32fa4c71..3ea720c6 100644 --- a/backend/apps/system/middleware/auth.py +++ b/backend/apps/system/middleware/auth.py @@ -1,4 +1,6 @@ +import base64 +import json from typing import Optional from fastapi import Request from fastapi.responses import JSONResponse @@ -8,7 +10,7 @@ from apps.system.models.system_model import AssistantModel from common.core.db import engine from apps.system.crud.assistant import get_assistant_info, get_assistant_user -from apps.system.crud.user import get_user_info +from apps.system.crud.user import get_user_by_account, get_user_info from apps.system.schemas.system_schema import AssistantHeader, UserInfoDTO from common.core import security from common.core.config import settings @@ -34,7 +36,7 @@ async def dispatch(self, request, call_next): trans = await get_i18n(request) #if assistantToken and assistantToken.lower().startswith("assistant "): if assistantToken: - validator: tuple[any] = await self.validateAssistant(assistantToken) + validator: tuple[any] = await self.validateAssistant(assistantToken, trans) if validator[0]: request.state.current_user = validator[1] request.state.assistant = validator[2] @@ -87,14 +89,17 @@ async def validateToken(self, token: Optional[str], trans: I18n): return False, e - async def validateAssistant(self, assistantToken: Optional[str]) -> tuple[any]: + async def validateAssistant(self, assistantToken: Optional[str], trans: I18n) -> tuple[any]: if not assistantToken: return False, f"Miss Token[{settings.TOKEN_KEY}]!" schema, param = get_authorization_scheme_param(assistantToken) - if schema.lower() != "assistant": - return False, f"Token schema error!" - try: + + try: + if schema.lower() == 'embedded': + return await self.validateEmbedded(param, trans) + if schema.lower() != "assistant": + return False, f"Token schema error!" payload = jwt.decode( param, settings.SECRET_KEY, algorithms=[security.ALGORITHM] ) @@ -108,8 +113,66 @@ async def validateAssistant(self, assistantToken: Optional[str]) -> tuple[any]: assistant_info = await get_assistant_info(session=session, assistant_id=payload['assistant_id']) assistant_info = AssistantModel.model_validate(assistant_info) assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True)) + if assistant_info and assistant_info.type == 0: + if payload['oid']: + session_user.oid = int(payload['oid']) + else: + assistant_oid = 1 + configuration = assistant_info.configuration + config_obj = json.loads(configuration) if configuration else {} + assistant_oid = config_obj.get('oid', 1) + session_user.oid = int(assistant_oid) + return True, session_user, assistant_info except Exception as e: SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}") # Return False and the exception message - return False, e \ No newline at end of file + return False, e + + async def validateEmbedded(self, param: str, trans: I18n) -> tuple[any]: + try: + """ payload = jwt.decode( + param, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) """ + payload: dict = jwt.decode( + param, + options={"verify_signature": False, "verify_exp": False}, + algorithms=[security.ALGORITHM] + ) + app_key = payload.get('appId', '') + embeddedId = payload.get('embeddedId', None) + if not embeddedId: + embeddedId = xor_decrypt(app_key) + if not payload['account']: + return False, f"Miss account payload error!" + account = payload['account'] + with Session(engine) as session: + """ session_user = await get_user_info(session = session, user_id = token_data.id) + session_user = UserInfoDTO.model_validate(session_user) """ + session_user = get_user_by_account(session = session, account=account) + if not session_user: + message = trans('i18n_not_exist', msg = trans('i18n_user.account')) + raise Exception(message) + session_user = await get_user_info(session = session, user_id = session_user.id) + + session_user = UserInfoDTO.model_validate(session_user) + if session_user.status != 1: + message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin')) + raise Exception(message) + if not session_user.oid or session_user.oid == 0: + message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin')) + raise Exception(message) + assistant_info = await get_assistant_info(session=session, assistant_id=embeddedId) + assistant_info = AssistantModel.model_validate(assistant_info) + assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True)) + return True, session_user, assistant_info + except Exception as e: + SQLBotLogUtil.exception(f"Embedded validation error: {str(e)}") + # Return False and the exception message + return False, e + +def xor_decrypt(encrypted_str: str, key: int = 0xABCD1234) -> int: + encrypted_bytes = base64.urlsafe_b64decode(encrypted_str) + hex_str = encrypted_bytes.hex() + encrypted_num = int(hex_str, 16) + return encrypted_num ^ key \ No newline at end of file diff --git a/backend/apps/system/models/system_model.py b/backend/apps/system/models/system_model.py index 530a4853..8d782af0 100644 --- a/backend/apps/system/models/system_model.py +++ b/backend/apps/system/models/system_model.py @@ -14,8 +14,8 @@ class AiModelBase: class AiModelDetail(SnowflakeBase, AiModelBase, table=True): __tablename__ = "ai_model" - api_key: str | None = Field(max_length=255, nullable=True) - api_domain: str = Field(max_length=255, nullable=False) + api_key: str | None = Field(nullable=True) + api_domain: str = Field(nullable=False) protocol: int = Field(nullable=False, default = 1) config: str = Field(sa_type = Text()) status: int = Field(nullable=False, default = 1) diff --git a/backend/apps/system/schemas/system_schema.py b/backend/apps/system/schemas/system_schema.py index a9546d45..6505e8fd 100644 --- a/backend/apps/system/schemas/system_schema.py +++ b/backend/apps/system/schemas/system_schema.py @@ -1,8 +1,9 @@ +import re from typing import Optional + from pydantic import BaseModel, Field, field_validator from common.core.schemas import BaseCreatorDTO -import re EMAIL_REGEX = re.compile( r"^[a-zA-Z0-9]+([._-][a-zA-Z0-9]+)*@" @@ -14,106 +15,121 @@ r"(?=.*[~!@#$%^&*()_+\-={}|:\"<>?`\[\];',./])" r"[A-Za-z\d~!@#$%^&*()_+\-={}|:\"<>?`\[\];',./]{8,20}$" ) + + class UserStatus(BaseCreatorDTO): status: int = 1 - - + class UserLanguage(BaseModel): language: str - + class BaseUser(BaseModel): account: str = Field(min_length=1, max_length=100, description="用户账号") oid: int - + + class BaseUserDTO(BaseUser, BaseCreatorDTO): - language: str = Field(pattern=r"^(zh-CN|en)$", default="zh-CN", description="用户语言") + language: str = Field(pattern=r"^(zh-CN|en|ko-KR)$", default="zh-CN", description="用户语言") password: str status: int = 1 + def to_dict(self): return { "id": self.id, "account": self.account, "oid": self.oid } + @field_validator("language") def validate_language(cls, lang: str) -> str: - if not re.fullmatch(r"^(zh-CN|en)$", lang): - raise ValueError("Language must be 'zh-CN' or 'en'") + if not re.fullmatch(r"^(zh-CN|en|ko-KR)$", lang): + raise ValueError("Language must be 'zh-CN', 'en', or 'ko-KR'") return lang + class UserCreator(BaseUser): name: str = Field(min_length=1, max_length=100, description="用户名") email: str = Field(min_length=1, max_length=100, description="用户邮箱") status: int = 1 - oid_list: Optional[list[int]] = None - + oid_list: Optional[list[int]] = None + """ @field_validator("email") def validate_email(cls, lang: str) -> str: if not re.fullmatch(EMAIL_REGEX, lang): raise ValueError("Email format is invalid!") return lang """ + class UserEditor(UserCreator, BaseCreatorDTO): - pass + pass + class UserGrid(UserEditor): create_time: int language: str = "zh-CN" - #space_name: Optional[str] = None + # space_name: Optional[str] = None origin: str = '' - - + + class PwdEditor(BaseModel): pwd: str new_pwd: str - + + class UserWsBase(BaseModel): uid_list: list[int] oid: Optional[int] = None + + class UserWsDTO(UserWsBase): weight: Optional[int] = 0 + class UserWsEditor(BaseModel): uid: int - oid: int + oid: int weight: int = 0 + class UserInfoDTO(UserEditor): language: str = "zh-CN" weight: int = 0 isAdmin: bool = False - + class AssistantBase(BaseModel): name: str domain: str - type: int = 0 + type: int = 0 # 0普通小助手 1高级 4页面嵌入 configuration: Optional[str] = None description: Optional[str] = None + + class AssistantDTO(AssistantBase, BaseCreatorDTO): pass + class AssistantHeader(AssistantDTO): unique: Optional[str] = None certificate: Optional[str] = None online: bool = False - + class AssistantValidator(BaseModel): valid: bool = False id_match: bool = False domain_match: bool = False token: Optional[str] = None - + def __init__( - self, - valid: bool = False, - id_match: bool = False, - domain_match: bool = False, - token: Optional[str] = None, - **kwargs + self, + valid: bool = False, + id_match: bool = False, + domain_match: bool = False, + token: Optional[str] = None, + **kwargs ): super().__init__( valid=valid, @@ -122,23 +138,28 @@ def __init__( token=token, **kwargs ) - + + class WorkspaceUser(UserEditor): weight: int create_time: int - + + class UserWs(BaseCreatorDTO): name: str - + + class UserWsOption(UserWs): account: str - - + + class AssistantFieldSchema(BaseModel): id: Optional[int] = None name: Optional[str] = None type: Optional[str] = None comment: Optional[str] = None + + class AssistantTableSchema(BaseModel): id: Optional[int] = None name: Optional[str] = None @@ -147,6 +168,7 @@ class AssistantTableSchema(BaseModel): sql: Optional[str] = None fields: Optional[list[AssistantFieldSchema]] = None + class AssistantOutDsBase(BaseModel): id: Optional[int] = None name: str @@ -154,8 +176,8 @@ class AssistantOutDsBase(BaseModel): type_name: Optional[str] = None comment: Optional[str] = None description: Optional[str] = None - - + + class AssistantOutDsSchema(AssistantOutDsBase): host: Optional[str] = None port: Optional[int] = None @@ -165,7 +187,8 @@ class AssistantOutDsSchema(AssistantOutDsBase): db_schema: Optional[str] = None extraParams: Optional[str] = None tables: Optional[list[AssistantTableSchema]] = None - + + class AssistantUiSchema(BaseCreatorDTO): theme: Optional[str] = None header_font_color: Optional[str] = None @@ -175,11 +198,7 @@ class AssistantUiSchema(BaseCreatorDTO): x_type: Optional[str] = 'right' x_val: Optional[int] = 0 y_type: Optional[str] = 'bottom' - y_val: Optional[str] = 33 + y_val: Optional[int] = 33 name: Optional[str] = None welcome: Optional[str] = None welcome_desc: Optional[str] = None - - - - \ No newline at end of file diff --git a/backend/apps/template/generate_chart/generator.py b/backend/apps/template/generate_chart/generator.py index 0962cba7..a519b893 100644 --- a/backend/apps/template/generate_chart/generator.py +++ b/backend/apps/template/generate_chart/generator.py @@ -4,3 +4,11 @@ def get_chart_template(): template = get_base_template() return template['template']['chart'] + +def get_base_terminology_template(): + template = get_base_template() + return template['template']['terminology'] + +def get_base_data_training_template(): + template = get_base_template() + return template['template']['data_training'] diff --git a/backend/apps/terminology/__init__.py b/backend/apps/terminology/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/terminology/api/__init__.py b/backend/apps/terminology/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/terminology/api/terminology.py b/backend/apps/terminology/api/terminology.py new file mode 100644 index 00000000..a7eda083 --- /dev/null +++ b/backend/apps/terminology/api/terminology.py @@ -0,0 +1,39 @@ +from typing import Optional + +from fastapi import APIRouter, Query + +from apps.terminology.curd.terminology import page_terminology, create_terminology, update_terminology, \ + delete_terminology +from apps.terminology.models.terminology_model import TerminologyInfo +from common.core.deps import SessionDep, CurrentUser, Trans + +router = APIRouter(tags=["Terminology"], prefix="/system/terminology") + + +@router.get("/page/{current_page}/{page_size}") +async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int, + word: Optional[str] = Query(None, description="搜索术语(可选)")): + current_page, page_size, total_count, total_pages, _list = page_terminology(session, current_page, page_size, word, + current_user.oid) + + return { + "current_page": current_page, + "page_size": page_size, + "total_count": total_count, + "total_pages": total_pages, + "data": _list + } + + +@router.put("") +async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: TerminologyInfo): + oid = current_user.oid + if info.id: + return update_terminology(session, info, oid, trans) + else: + return create_terminology(session, info, oid, trans) + + +@router.delete("") +async def delete(session: SessionDep, id_list: list[int]): + delete_terminology(session, id_list) diff --git a/backend/apps/terminology/curd/__init__.py b/backend/apps/terminology/curd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py new file mode 100644 index 00000000..f5a382eb --- /dev/null +++ b/backend/apps/terminology/curd/terminology.py @@ -0,0 +1,620 @@ +import datetime +import logging +import traceback +from typing import List, Optional, Any +from xml.dom.minidom import parseString + +import dicttoxml +from sqlalchemy import and_, or_, select, func, delete, update, union, text, BigInteger +from sqlalchemy.orm import aliased +from sqlalchemy.orm.session import Session + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.models.datasource import CoreDatasource +from apps.template.generate_chart.generator import get_base_terminology_template +from apps.terminology.models.terminology_model import Terminology, TerminologyInfo +from common.core.config import settings +from common.core.deps import SessionDep, Trans +from common.utils.embedding_threads import run_save_terminology_embeddings + + +def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None, + oid: Optional[int] = 1): + _list: List[TerminologyInfo] = [] + + child = aliased(Terminology) + + current_page = max(1, current_page) + page_size = max(10, page_size) + + total_count = 0 + total_pages = 0 + + if name and name.strip() != "": + keyword_pattern = f"%{name.strip()}%" + # 步骤1:先找到所有匹配的节点ID(无论是父节点还是子节点) + matched_ids_subquery = ( + select(Terminology.id) + .where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid)) # LIKE查询条件 + .subquery() + ) + + # 步骤2:找到这些匹配节点的所有父节点(包括自身如果是父节点) + parent_ids_subquery = ( + select(Terminology.id) + .where( + (Terminology.id.in_(matched_ids_subquery)) | + (Terminology.id.in_( + select(Terminology.pid) + .where(Terminology.id.in_(matched_ids_subquery)) + .where(Terminology.pid.isnot(None)) + )) + ) + .where(Terminology.pid.is_(None)) # 只取父节点 + ) + + count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery()) + total_count = session.execute(count_stmt).scalar() + total_pages = (total_count + page_size - 1) // page_size + + if current_page > total_pages: + current_page = 1 + + # 步骤3:获取分页后的父节点ID + paginated_parent_ids = ( + parent_ids_subquery + .order_by(Terminology.create_time.desc()) + .offset((current_page - 1) * page_size) + .limit(page_size) + .subquery() + ) + + # 步骤4:获取这些父节点的childrenNames + children_subquery = ( + select( + child.pid, + func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words') + ) + .where(child.pid.isnot(None)) + .group_by(child.pid) + .subquery() + ) + + # 创建子查询来获取数据源名称,添加类型转换 + datasource_names_subquery = ( + select( + func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'), + Terminology.id.label('term_id') + ) + .where(Terminology.id.in_(paginated_parent_ids)) + .subquery() + ) + + # 主查询 + stmt = ( + select( + Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words, + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') + ) + .outerjoin( + children_subquery, + Terminology.id == children_subquery.c.pid + ) + # 关联数据源名称子查询和 CoreDatasource 表 + .outerjoin( + datasource_names_subquery, + datasource_names_subquery.c.term_id == Terminology.id + ) + .outerjoin( + CoreDatasource, + CoreDatasource.id == datasource_names_subquery.c.ds_id + ) + .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)) + .group_by( + Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words + ) + .order_by(Terminology.create_time.desc()) + ) + else: + parent_ids_subquery = ( + select(Terminology.id) + .where(and_(Terminology.pid.is_(None), Terminology.oid == oid)) # 只取父节点 + ) + count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery()) + total_count = session.execute(count_stmt).scalar() + total_pages = (total_count + page_size - 1) // page_size + + if current_page > total_pages: + current_page = 1 + + paginated_parent_ids = ( + parent_ids_subquery + .order_by(Terminology.create_time.desc()) + .offset((current_page - 1) * page_size) + .limit(page_size) + .subquery() + ) + + children_subquery = ( + select( + child.pid, + func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words') + ) + .where(child.pid.isnot(None)) + .group_by(child.pid) + .subquery() + ) + + # 创建子查询来获取数据源名称 + datasource_names_subquery = ( + select( + func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'), + Terminology.id.label('term_id') + ) + .where(Terminology.id.in_(paginated_parent_ids)) + .subquery() + ) + + stmt = ( + select( + Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words, + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') + ) + .outerjoin( + children_subquery, + Terminology.id == children_subquery.c.pid + ) + # 关联数据源名称子查询和 CoreDatasource 表 + .outerjoin( + datasource_names_subquery, + datasource_names_subquery.c.term_id == Terminology.id + ) + .outerjoin( + CoreDatasource, + CoreDatasource.id == datasource_names_subquery.c.ds_id + ) + .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)) + .group_by(Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words + ) + .order_by(Terminology.create_time.desc()) + ) + + result = session.execute(stmt) + + for row in result: + _list.append(TerminologyInfo( + id=row.id, + word=row.word, + create_time=row.create_time, + description=row.description, + other_words=row.other_words if row.other_words else [], + specific_ds=row.specific_ds if row.specific_ds is not None else False, + datasource_ids=row.datasource_ids if row.datasource_ids is not None else [], + datasource_names=row.datasource_names if row.datasource_names is not None else [], + )) + + return current_page, page_size, total_count, total_pages, _list + + +def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans): + create_time = datetime.datetime.now() + + specific_ds = info.specific_ds if info.specific_ds is not None else False + datasource_ids = info.datasource_ids if info.datasource_ids is not None else [] + + if specific_ds: + if not datasource_ids: + raise Exception(trans("i18n_terminology.datasource_cannot_be_none")) + + parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid, + specific_ds=specific_ds, + datasource_ids=datasource_ids) + + words = [info.word] + for child in info.other_words: + if child in words: + raise Exception(trans("i18n_terminology.cannot_be_repeated")) + else: + words.append(child) + + # 基础查询条件(word 和 oid 必须满足) + base_query = and_( + Terminology.word.in_(words), + Terminology.oid == oid + ) + + # 构建查询 + query = session.query(Terminology).filter(base_query) + + if specific_ds: + # 仅当 specific_ds=False 时,检查数据源条件 + query = query.where( + or_( + or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)), + and_( + Terminology.specific_ds == True, + Terminology.datasource_ids.isnot(None), + text(""" + EXISTS ( + SELECT 1 FROM jsonb_array_elements(datasource_ids) AS elem + WHERE elem::text::int = ANY(:datasource_ids) + ) + """) # 检查是否包含任意目标值 + ) + ) + ) + query = query.params(datasource_ids=datasource_ids) + + # 转换为 EXISTS 查询并获取结果 + exists = session.query(query.exists()).scalar() + + if exists: + raise Exception(trans("i18n_terminology.exists_in_db")) + + result = Terminology(**parent.model_dump()) + + session.add(parent) + session.flush() + session.refresh(parent) + + result.id = parent.id + session.commit() + + _list: List[Terminology] = [] + if info.other_words: + for other_word in info.other_words: + if other_word.strip() == "": + continue + _list.append( + Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid, + specific_ds=specific_ds, datasource_ids=datasource_ids)) + session.bulk_save_objects(_list) + session.flush() + session.commit() + + # embedding + run_save_terminology_embeddings([result.id]) + + return result.id + + +def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans): + count = session.query(Terminology).filter( + Terminology.oid == oid, + Terminology.id == info.id + ).count() + if count == 0: + raise Exception(trans('i18n_terminology.terminology_not_exists')) + + specific_ds = info.specific_ds if info.specific_ds is not None else False + datasource_ids = info.datasource_ids if info.datasource_ids is not None else [] + + if specific_ds: + if not datasource_ids: + raise Exception(trans("i18n_terminology.datasource_cannot_be_none")) + + words = [info.word] + for child in info.other_words: + if child in words: + raise Exception(trans("i18n_terminology.cannot_be_repeated")) + else: + words.append(child) + + # 基础查询条件(word 和 oid 必须满足) + base_query = and_( + Terminology.word.in_(words), + Terminology.oid == oid, + or_( + Terminology.pid != info.id, + and_(Terminology.pid.is_(None), Terminology.id != info.id) + ), + Terminology.id != info.id + ) + + # 构建查询 + query = session.query(Terminology).filter(base_query) + + if specific_ds: + # 仅当 specific_ds=False 时,检查数据源条件 + query = query.where( + or_( + or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)), + and_( + Terminology.specific_ds == True, + Terminology.datasource_ids.isnot(None), + text(""" + EXISTS ( + SELECT 1 FROM jsonb_array_elements(datasource_ids) AS elem + WHERE elem::text::int = ANY(:datasource_ids) + ) + """) # 检查是否包含任意目标值 + ) + ) + ) + query = query.params(datasource_ids=datasource_ids) + + # 转换为 EXISTS 查询并获取结果 + exists = session.query(query.exists()).scalar() + + if exists: + raise Exception(trans("i18n_terminology.exists_in_db")) + + stmt = update(Terminology).where(and_(Terminology.id == info.id)).values( + word=info.word, + description=info.description, + specific_ds=specific_ds, + datasource_ids=datasource_ids + ) + session.execute(stmt) + session.commit() + + stmt = delete(Terminology).where(and_(Terminology.pid == info.id)) + session.execute(stmt) + session.commit() + + create_time = datetime.datetime.now() + _list: List[Terminology] = [] + if info.other_words: + for other_word in info.other_words: + if other_word.strip() == "": + continue + _list.append( + Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid, + specific_ds=specific_ds, datasource_ids=datasource_ids)) + session.bulk_save_objects(_list) + session.flush() + session.commit() + + # embedding + run_save_terminology_embeddings([info.id]) + + return info.id + + +def delete_terminology(session: SessionDep, ids: list[int]): + stmt = delete(Terminology).where(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))) + session.execute(stmt) + session.commit() + + +# def run_save_embeddings(ids: List[int]): +# executor.submit(save_embeddings, ids) +# +# +# def fill_empty_embeddings(): +# executor.submit(run_fill_empty_embeddings) + + +def run_fill_empty_embeddings(session: Session): + if not settings.EMBEDDING_ENABLED: + return + stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None))) + stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct() + combined_stmt = union(stmt1, stmt2) + results = session.execute(combined_stmt).scalars().all() + save_embeddings(session, results) + + +def save_embeddings(session: Session, ids: List[int]): + if not settings.EMBEDDING_ENABLED: + return + + if not ids or len(ids) == 0: + return + try: + + _list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all() + + _words_list = [item.word for item in _list] + + model = EmbeddingModelCache.get_model() + + results = model.embed_documents(_words_list) + + for index in range(len(results)): + item = results[index] + stmt = update(Terminology).where(and_(Terminology.id == _list[index].id)).values(embedding=item) + session.execute(stmt) + session.commit() + + except Exception: + traceback.print_exc() + + +embedding_sql = f""" +SELECT id, pid, word, similarity +FROM +(SELECT id, pid, word, oid, specific_ds, datasource_ids, +( 1 - (embedding <=> :embedding_array) ) AS similarity +FROM terminology AS child +) TEMP +WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid +AND (specific_ds = false OR specific_ds IS NULL) +ORDER BY similarity DESC +LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT} +""" + +embedding_sql_with_datasource = f""" +SELECT id, pid, word, similarity +FROM +(SELECT id, pid, word, oid, specific_ds, datasource_ids, +( 1 - (embedding <=> :embedding_array) ) AS similarity +FROM terminology AS child +) TEMP +WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid +AND ( + (specific_ds = false OR specific_ds IS NULL) + OR + (specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource)) +) +ORDER BY similarity DESC +LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT} +""" + + +def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasource: int = None): + if word.strip() == "": + return [] + + _list: List[Terminology] = [] + + stmt = ( + select( + Terminology.id, + Terminology.pid, + Terminology.word, + ) + .where( + and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid) + ) + ) + + if datasource is not None: + stmt = stmt.where( + or_( + or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)), + and_( + Terminology.specific_ds == True, + Terminology.datasource_ids.isnot(None), + text("datasource_ids @> jsonb_build_array(:datasource)") + ) + ) + ) + else: + stmt = stmt.where(or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None))) + + # 执行查询 + params: dict[str, Any] = {'sentence': word} + if datasource is not None: + params['datasource'] = datasource + + results = session.execute(stmt, params).fetchall() + + for row in results: + _list.append(Terminology(id=row.id, word=row.word, pid=row.pid)) + + if settings.EMBEDDING_ENABLED: + with session.begin_nested(): + try: + model = EmbeddingModelCache.get_model() + + embedding = model.embed_query(word) + + if datasource is not None: + results = session.execute(text(embedding_sql_with_datasource), + {'embedding_array': str(embedding), 'oid': oid, + 'datasource': datasource}).fetchall() + else: + results = session.execute(text(embedding_sql), + {'embedding_array': str(embedding), 'oid': oid}).fetchall() + + for row in results: + _list.append(Terminology(id=row.id, word=row.word, pid=row.pid)) + + except Exception: + traceback.print_exc() + session.rollback() + + _map: dict = {} + _ids: list[int] = [] + for row in _list: + if row.id in _ids or row.pid in _ids: + continue + if row.pid is not None: + _ids.append(row.pid) + else: + _ids.append(row.id) + + if len(_ids) == 0: + return [] + + t_list = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter( + or_(Terminology.id.in_(_ids), Terminology.pid.in_(_ids))).all() + for row in t_list: + pid = str(row.pid) if row.pid is not None else str(row.id) + if _map.get(pid) is None: + _map[pid] = {'words': [], 'description': row.description} + _map[pid]['words'].append(row.word) + + _results: list[dict] = [] + for key in _map.keys(): + _results.append(_map.get(key)) + + return _results + + +def get_example(): + _obj = { + 'terminologies': [ + {'words': ['GDP', '国内生产总值'], + 'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。'}, + ] + } + return to_xml_string(_obj, 'example') + + +def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str: + item_name_func = lambda x: 'terminology' if x == 'terminologies' else 'word' if x == 'words' else 'item' + dicttoxml.LOG.setLevel(logging.ERROR) + xml = dicttoxml.dicttoxml(_dict, + cdata=['word', 'description'], + custom_root=root, + item_func=item_name_func, + xml_declaration=False, + encoding='utf-8', + attr_type=False).decode('utf-8') + pretty_xml = parseString(xml).toprettyxml() + + if pretty_xml.startswith('') + 1 + pretty_xml = pretty_xml[end_index:].lstrip() + + # 替换所有 XML 转义字符 + escape_map = { + '<': '<', + '>': '>', + '&': '&', + '"': '"', + ''': "'" + } + for escaped, original in escape_map.items(): + pretty_xml = pretty_xml.replace(escaped, original) + + return pretty_xml + + +def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1, + datasource: Optional[int] = None) -> str: + if not oid: + oid = 1 + _results = select_terminology_by_word(session, question, oid, datasource) + if _results and len(_results) > 0: + terminology = to_xml_string(_results) + template = get_base_terminology_template().format(terminologies=terminology) + return template + else: + return '' diff --git a/backend/apps/terminology/models/__init__.py b/backend/apps/terminology/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/terminology/models/terminology_model.py b/backend/apps/terminology/models/terminology_model.py new file mode 100644 index 00000000..b9048659 --- /dev/null +++ b/backend/apps/terminology/models/terminology_model.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import List, Optional + +from pgvector.sqlalchemy import VECTOR +from pydantic import BaseModel +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity, Boolean +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import SQLModel, Field + + +class Terminology(SQLModel, table=True): + __tablename__ = "terminology" + id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) + oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1)) + pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True)) + create_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) + word: Optional[str] = Field(max_length=255) + description: Optional[str] = Field(sa_column=Column(Text, nullable=True)) + embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) + specific_ds: Optional[bool] = Field(sa_column=Column(Boolean, default=False)) + datasource_ids: Optional[list[int]] = Field(sa_column=Column(JSONB), default=[]) + + +class TerminologyInfo(BaseModel): + id: Optional[int] = None + create_time: Optional[datetime] = None + word: Optional[str] = None + description: Optional[str] = None + other_words: Optional[List[str]] = [] + specific_ds: Optional[bool] = False + datasource_ids: Optional[list[int]] = [] + datasource_names: Optional[list[str]] = [] diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 1d1a6bfe..ca0468b8 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -47,24 +47,25 @@ def all_cors_origins(self) -> list[str]: POSTGRES_SERVER: str = 'localhost' POSTGRES_PORT: int = 5432 POSTGRES_USER: str = 'root' - POSTGRES_PASSWORD: str = "123456" + POSTGRES_PASSWORD: str = "Password123@pg" POSTGRES_DB: str = "sqlbot" SQLBOT_DB_URL: str = '' - #SQLBOT_DB_URL: str = 'mysql+pymysql://root:Password123%40mysql@127.0.0.1:3306/sqlbot' - - TOKEN_KEY: str = "X-SQLBOT-TOKEN" + # SQLBOT_DB_URL: str = 'mysql+pymysql://root:Password123%40mysql@127.0.0.1:3306/sqlbot' + + TOKEN_KEY: str = "X-SQLBOT-TOKEN" DEFAULT_PWD: str = "SQLBot@123456" ASSISTANT_TOKEN_KEY: str = "X-SQLBOT-ASSISTANT-TOKEN" - + CACHE_TYPE: Literal["redis", "memory", "None"] = "memory" - CACHE_REDIS_URL: str | None = None # Redis URL, e.g., "redis://[[username]:[password]]@localhost:6379/0" - + CACHE_REDIS_URL: str | None = None # Redis URL, e.g., "redis://[[username]:[password]]@localhost:6379/0" + LOG_LEVEL: str = "INFO" # DEBUG, INFO, WARNING, ERROR LOG_DIR: str = "logs" LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s:%(lineno)d - %(message)s" SQL_DEBUG: bool = False - + UPLOAD_DIR: str = "/opt/sqlbot/data/file" + SQLBOT_KEY_EXPIRED: int = 100 # License key expiration timestamp, 0 means no expiration @computed_field # type: ignore[prop-decorator] @property @@ -83,6 +84,29 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: MCP_IMAGE_PATH: str = '/opt/sqlbot/images' EXCEL_PATH: str = '/opt/sqlbot/data/excel' MCP_IMAGE_HOST: str = 'http://localhost:3000' - SERVER_IMAGE_HOST: str = '' + SERVER_IMAGE_HOST: str = 'http://YOUR_SERVE_IP:MCP_PORT/images/' + + LOCAL_MODEL_PATH: str = '/opt/sqlbot/models' + DEFAULT_EMBEDDING_MODEL: str = 'shibing624/text2vec-base-chinese' + EMBEDDING_ENABLED: bool = True + EMBEDDING_DEFAULT_SIMILARITY: float = 0.4 + EMBEDDING_TERMINOLOGY_SIMILARITY: float = EMBEDDING_DEFAULT_SIMILARITY + EMBEDDING_DATA_TRAINING_SIMILARITY: float = EMBEDDING_DEFAULT_SIMILARITY + EMBEDDING_DEFAULT_TOP_COUNT: int = 5 + EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT + EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT + + PARSE_REASONING_BLOCK_ENABLED: bool = True + DEFAULT_REASONING_CONTENT_START: str = '' + DEFAULT_REASONING_CONTENT_END: str = '' + + PG_POOL_SIZE: int = 20 + PG_MAX_OVERFLOW: int = 30 + PG_POOL_RECYCLE: int = 3600 + PG_POOL_PRE_PING: bool = True + + TABLE_EMBEDDING_ENABLED: bool = False + TABLE_EMBEDDING_COUNT: int = 10 + settings = Settings() # type: ignore diff --git a/backend/common/core/db.py b/backend/common/core/db.py index 37eff29b..d4adf346 100644 --- a/backend/common/core/db.py +++ b/backend/common/core/db.py @@ -1,14 +1,18 @@ from sqlmodel import Session, create_engine, SQLModel - from common.core.config import settings +engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI), + pool_size=settings.PG_POOL_SIZE, + max_overflow=settings.PG_MAX_OVERFLOW, + pool_recycle=settings.PG_POOL_RECYCLE, + pool_pre_ping=settings.PG_POOL_PRE_PING) -engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) def get_session(): with Session(engine) as session: yield session + def init_db(): SQLModel.metadata.create_all(engine) diff --git a/backend/common/core/pagination.py b/backend/common/core/pagination.py index ab791147..f4a88445 100644 --- a/backend/common/core/pagination.py +++ b/backend/common/core/pagination.py @@ -14,6 +14,8 @@ def _process_result_row(self, row: Row) -> Dict[str, Any]: result_dict = {} if isinstance(row, int): return {'id': row} + if isinstance(row, SQLModel) and not hasattr(row, '_fields'): + return row.model_dump() for item, key in zip(row, row._fields): if isinstance(item, SQLModel): result_dict.update(item.model_dump()) diff --git a/backend/common/core/response_middleware.py b/backend/common/core/response_middleware.py index 6842893e..c60a959d 100644 --- a/backend/common/core/response_middleware.py +++ b/backend/common/core/response_middleware.py @@ -1,18 +1,33 @@ import json -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import JSONResponse + from starlette.exceptions import HTTPException +from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request +from starlette.responses import JSONResponse + from common.core.config import settings from common.utils.utils import SQLBotLogUtil + + class ResponseMiddleware(BaseHTTPMiddleware): def __init__(self, app): super().__init__(app) async def dispatch(self, request, call_next): response = await call_next(request) - - if isinstance(response, JSONResponse) or request.url.path == f"{settings.API_V1_STR}/openapi.json": + + direct_paths = [ + f"{settings.API_V1_STR}/mcp/mcp_question", + f"{settings.API_V1_STR}/mcp/mcp_assistant" + ] + + route = request.scope.get("route") + # 获取定义的路径模式,例如 '/items/{item_id}' + path_pattern = '' if not route else route.path_format + + if (isinstance(response, JSONResponse) + or request.url.path == f"{settings.API_V1_STR}/openapi.json" + or path_pattern in direct_paths): return response if response.status_code != 200: return response @@ -21,9 +36,9 @@ async def dispatch(self, request, call_next): body = b"" async for chunk in response.body_iterator: body += chunk - + raw_data = json.loads(body.decode()) - + if isinstance(raw_data, dict) and all(k in raw_data for k in ["code", "data", "msg"]): return JSONResponse( content=raw_data, @@ -33,13 +48,13 @@ async def dispatch(self, request, call_next): if k.lower() not in ("content-length", "content-type") } ) - + wrapped_data = { "code": 0, "data": raw_data, "msg": None } - + return JSONResponse( content=wrapped_data, status_code=response.status_code, @@ -58,7 +73,7 @@ async def dispatch(self, request, call_next): if k.lower() not in ("content-length", "content-type") } ) - + return response @@ -72,7 +87,6 @@ async def http_exception_handler(request: Request, exc: HTTPException): headers={"Access-Control-Allow-Origin": "*"} ) - @staticmethod async def global_exception_handler(request: Request, exc: Exception): SQLBotLogUtil.error(f"Unhandled Exception: {str(exc)}", exc_info=True) @@ -81,4 +95,3 @@ async def global_exception_handler(request: Request, exc: Exception): content=str(exc), headers={"Access-Control-Allow-Origin": "*"} ) - diff --git a/backend/common/core/sqlbot_cache.py b/backend/common/core/sqlbot_cache.py index ec6d5187..15018019 100644 --- a/backend/common/core/sqlbot_cache.py +++ b/backend/common/core/sqlbot_cache.py @@ -68,7 +68,7 @@ def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - if not settings.CACHE_TYPE or settings.CACHE_TYPE.lower() == "none": + if not settings.CACHE_TYPE or settings.CACHE_TYPE.lower() == "none" or not is_cache_initialized(): return await func(*args, **kwargs) # 生成缓存键 cache_key = used_key_builder( @@ -96,7 +96,7 @@ def clear_cache( def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - if not settings.CACHE_TYPE or settings.CACHE_TYPE.lower() == "none": + if not settings.CACHE_TYPE or settings.CACHE_TYPE.lower() == "none" or not is_cache_initialized(): return await func(*args, **kwargs) cache_key = custom_key_builder( func=func, @@ -106,9 +106,9 @@ async def wrapper(*args, **kwargs): cacheName=cacheName, keyExpression=keyExpression, ) - ket_list = cache_key if isinstance(cache_key, list) else [cache_key] + key_list = cache_key if isinstance(cache_key, list) else [cache_key] backend = FastAPICache.get_backend() - for temp_cache_key in ket_list: + for temp_cache_key in key_list: if await backend.get(temp_cache_key): if settings.CACHE_TYPE.lower() == "redis": redis = backend.redis @@ -138,4 +138,21 @@ def init_sqlbot_cache(): SQLBotLogUtil.info(f"SQLBot 使用Redis缓存, 可使用多进程模式") else: SQLBotLogUtil.warning("SQLBot 未启用缓存, 可使用多进程模式") - \ No newline at end of file + + +def is_cache_initialized() -> bool: + # 检查必要的属性是否存在 + if not hasattr(FastAPICache, "_backend") or not hasattr(FastAPICache, "_prefix"): + return False + + # 检查属性值是否为 None + if FastAPICache._backend is None or FastAPICache._prefix is None: + return False + + # 尝试获取后端确认 + try: + backend = FastAPICache.get_backend() + return backend is not None + except (AssertionError, AttributeError, Exception) as e: + SQLBotLogUtil.debug(f"缓存初始化检查失败: {str(e)}") + return False \ No newline at end of file diff --git a/backend/common/error.py b/backend/common/error.py index ec1d3420..09882e2e 100644 --- a/backend/common/error.py +++ b/backend/common/error.py @@ -4,4 +4,16 @@ def __init__(self, message): self.message = message def __str__(self): - return self.message \ No newline at end of file + return self.message + + +class SQLBotDBConnectionError(Exception): + pass + + +class SQLBotDBError(Exception): + pass + + +class ParseSQLResultError(Exception): + pass diff --git a/backend/common/utils/aes_crypto.py b/backend/common/utils/aes_crypto.py new file mode 100644 index 00000000..e75ab5a6 --- /dev/null +++ b/backend/common/utils/aes_crypto.py @@ -0,0 +1,16 @@ +from typing import Optional +from common.core.config import settings +from sqlbot_xpack.aes_utils import SecureEncryption + +simple_aes_iv_text = 'sqlbot_em_aes_iv' +def sqlbot_aes_encrypt(text: str, key: Optional[str] = None) -> str: + return SecureEncryption.encrypt_to_single_string(text, key or settings.SECRET_KEY) + +def sqlbot_aes_decrypt(text: str, key: Optional[str] = None) -> str: + return SecureEncryption.decrypt_from_single_string(text, key or settings.SECRET_KEY) + +def simple_aes_encrypt(text: str, key: Optional[str] = None, ivtext: Optional[str] = None) -> str: + return SecureEncryption.simple_aes_encrypt(text, key or settings.SECRET_KEY[:32], ivtext or simple_aes_iv_text) + +def simple_aes_decrypt(text: str, key: Optional[str] = None, ivtext: Optional[str] = None) -> str: + return SecureEncryption.simple_aes_decrypt(text, key or settings.SECRET_KEY[:32], ivtext or simple_aes_iv_text) \ No newline at end of file diff --git a/backend/common/utils/crypto.py b/backend/common/utils/crypto.py index ce5d7b95..dcd1a13c 100644 --- a/backend/common/utils/crypto.py +++ b/backend/common/utils/crypto.py @@ -1,4 +1,7 @@ -from sqlbot_xpack.core import sqlbot_decrypt as xpack_sqlbot_decrypt +from sqlbot_xpack.core import sqlbot_decrypt as xpack_sqlbot_decrypt, sqlbot_encrypt as xpack_sqlbot_encrypt async def sqlbot_decrypt(text: str) -> str: - return await xpack_sqlbot_decrypt(text) \ No newline at end of file + return await xpack_sqlbot_decrypt(text) + +async def sqlbot_encrypt(text: str) -> str: + return await xpack_sqlbot_encrypt(text) \ No newline at end of file diff --git a/backend/common/utils/embedding_threads.py b/backend/common/utils/embedding_threads.py new file mode 100644 index 00000000..a38b66f0 --- /dev/null +++ b/backend/common/utils/embedding_threads.py @@ -0,0 +1,33 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import List + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from common.core.config import settings + +executor = ThreadPoolExecutor(max_workers=200) + +engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) +session_maker = sessionmaker(bind=engine) +session = session_maker() + + +def run_save_terminology_embeddings(ids: List[int]): + from apps.terminology.curd.terminology import save_embeddings + executor.submit(save_embeddings, session, ids) + + +def fill_empty_terminology_embeddings(): + from apps.terminology.curd.terminology import run_fill_empty_embeddings + executor.submit(run_fill_empty_embeddings, session) + + +def run_save_data_training_embeddings(ids: List[int]): + from apps.data_training.curd.data_training import save_embeddings + executor.submit(save_embeddings, session, ids) + + +def fill_empty_data_training_embeddings(): + from apps.data_training.curd.data_training import run_fill_empty_embeddings + executor.submit(run_fill_empty_embeddings, session) diff --git a/backend/common/utils/random.py b/backend/common/utils/random.py new file mode 100644 index 00000000..879f4225 --- /dev/null +++ b/backend/common/utils/random.py @@ -0,0 +1,6 @@ +import secrets +import string + +def get_random_string(length=16): + alphabet = string.ascii_letters + string.digits + return ''.join(secrets.choice(alphabet) for _ in range(length)) diff --git a/backend/common/utils/whitelist.py b/backend/common/utils/whitelist.py index e8654738..beef0b5b 100644 --- a/backend/common/utils/whitelist.py +++ b/backend/common/utils/whitelist.py @@ -27,11 +27,12 @@ "/system/config/key", "/images/*", "/sse", - "/system/appearance*", + "/system/appearance/ui", + "/system/appearance/picture/*", "/system/assistant/validator*", "/system/assistant/info/*", + "/system/assistant/app/*", "/system/assistant/picture/*", - "/system/embedded*", "/datasource/uploadExcel" ] diff --git a/backend/locales/en.json b/backend/locales/en.json index acee9415..049c47aa 100644 --- a/backend/locales/en.json +++ b/backend/locales/en.json @@ -1,40 +1,59 @@ { - "i18n_default_workspace": "Default workspace", - "i18n_ds_name_exist": "Name already exists", - "i18n_concat_admin": "Please contact administrator!", - "i18n_exist": "{msg} already exists!", - "i18n_not_exist": "{msg} not exists", - "i18n_error": "{key} error!", - "i18n_miss_args": "Missing {key} parameter!", - "i18n_format_invalid": "{key} format is incorrect!", - "i18n_login": { - "account_pwd_error": "Account or password error!", - "no_associated_ws": "No associated workspace, {msg}", - "user_disable": "Account disabled, {msg}" - }, - "i18n_user": { - "account": "Account", - "email": "Email", - "password": "Password", - "language_not_support": "System does not support [{key}] language!", - "ws_miss": "The current user is not in the workspace [{ws}]!" - }, - "i18n_ws": { - "title": "Workspace" - }, - "i18n_permission": { - "only_admin": "Only administrators can call this!", - "no_permission": "No permission to access {url}{msg}", - "authenticate_invalid": "Authenticate invalid [{msg}]", - "token_expired": "Token has expired" - }, - "i18n_llm": { - "validate_error": "Validation failed [{msg}]", - "delete_default_error": "Cannot delete default model [{key}]!", - "miss_default": "The default large language model has not been configured" - }, - "i18n_ds_invalid": "Datasource Invalid", - "i18n_embedded": { - "invalid_origin": "Domain verification failed [{origin}]" - } + "i18n_default_workspace": "Default workspace", + "i18n_ds_name_exist": "Name already exists", + "i18n_concat_admin": "Please contact administrator!", + "i18n_exist": "{msg} already exists!", + "i18n_name": "Name", + "i18n_not_exist": "{msg} not exists", + "i18n_error": "{key} error!", + "i18n_miss_args": "Missing {key} parameter!", + "i18n_format_invalid": "{key} format is incorrect!", + "i18n_login": { + "account_pwd_error": "Account or password error!", + "no_associated_ws": "No associated workspace, {msg}", + "user_disable": "Account disabled, {msg}" + }, + "i18n_user": { + "account": "Account", + "email": "Email", + "password": "Password", + "language_not_support": "System does not support [{key}] language!", + "ws_miss": "The current user is not in the workspace [{ws}]!" + }, + "i18n_ws": { + "title": "Workspace" + }, + "i18n_permission": { + "only_admin": "Only administrators can call this!", + "no_permission": "No permission to access {url}{msg}", + "authenticate_invalid": "Authenticate invalid [{msg}]", + "token_expired": "Token has expired" + }, + "i18n_llm": { + "validate_error": "Validation failed [{msg}]", + "delete_default_error": "Cannot delete default model [{key}]!", + "miss_default": "The default large language model has not been configured" + }, + "i18n_ds_invalid": "Datasource Invalid", + "i18n_embedded": { + "invalid_origin": "Domain verification failed [{origin}]" + }, + "i18n_terminology": { + "terminology_not_exists": "Terminology does not exists", + "datasource_cannot_be_none": "Datasource cannot be none or empty", + "cannot_be_repeated": "Term name, synonyms cannot be repeated", + "exists_in_db": "Term name, synonyms exists" + }, + "i18n_data_training": { + "datasource_cannot_be_none": "Datasource cannot be none or empty", + "data_training_not_exists": "Example does not exists", + "exists_in_db": "Question exists" + }, + "i18n_custom_prompt": { + "exists_in_db": "Prompt name exists", + "not_exists": "Prompt does not exists" + }, + "i18n_excel_export": { + "data_is_empty": "The form data is empty, cannot export data" + } } \ No newline at end of file diff --git a/backend/locales/ko-KR.json b/backend/locales/ko-KR.json new file mode 100644 index 00000000..b426ecc3 --- /dev/null +++ b/backend/locales/ko-KR.json @@ -0,0 +1,55 @@ +{ + "i18n_default_workspace": "기본 워크스페이스", + "i18n_ds_name_exist": "이미 존재하는 이름입니다", + "i18n_concat_admin": "관리자에게 문의해 주세요!", + "i18n_exist": "{msg}이(가) 이미 존재합니다!", + "i18n_name": "이름", + "i18n_not_exist": "{msg}을(를) 찾을 수 없습니다", + "i18n_error": "{key} 오류!", + "i18n_miss_args": "{key} 매개변수가 없습니다!", + "i18n_format_invalid": "{key} 형식이 올바르지 않습니다!", + "i18n_login": { + "account_pwd_error": "계정 또는 비밀번호가 올바르지 않습니다!", + "no_associated_ws": "연결된 워크스페이스가 없습니다. {msg}", + "user_disable": "계정이 비활성화되었습니다. {msg}" + }, + "i18n_user": { + "account": "계정", + "email": "이메일", + "password": "비밀번호", + "language_not_support": "[{key}] 언어는 지원하지 않습니다!", + "ws_miss": "현재 사용자는 워크스페이스 [{ws}]에 속해 있지 않습니다!" + }, + "i18n_ws": { + "title": "워크스페이스" + }, + "i18n_permission": { + "only_admin": "관리자만 호출할 수 있습니다!", + "no_permission": "{url}{msg}에 접근 권한이 없습니다", + "authenticate_invalid": "인증이 실패했습니다 [{msg}]", + "token_expired": "토큰이 만료되었습니다" + }, + "i18n_llm": { + "validate_error": "검증에 실패했습니다 [{msg}]", + "delete_default_error": "기본 모델 [{key}]은 삭제할 수 없습니다!", + "miss_default": "기본 LLM이 아직 설정되지 않았습니다" + }, + "i18n_ds_invalid": "데이터 소스가 유효하지 않습니다", + "i18n_embedded": { + "invalid_origin": "도메인 검증에 실패했습니다 [{origin}]" + }, + "i18n_terminology": { + "terminology_not_exists": "용어를 찾을 수 없습니다", + "datasource_cannot_be_none": "데이터 소스를 선택해 주세요", + "cannot_be_repeated": "용어 이름과 동의어는 중복될 수 없습니다", + "exists_in_db": "용어 이름 또는 동의어가 이미 존재합니다" + }, + "i18n_data_training": { + "datasource_cannot_be_none": "데이터 소스를 선택해 주세요", + "data_training_not_exists": "예제가 존재하지 않습니다", + "exists_in_db": "질문이 이미 존재합니다" + }, + "i18n_excel_export": { + "data_is_empty": "양식 데이터가 없어 내보낼 수 없습니다" + } +} \ No newline at end of file diff --git a/backend/locales/zh-CN.json b/backend/locales/zh-CN.json index 77ea01aa..010943d0 100644 --- a/backend/locales/zh-CN.json +++ b/backend/locales/zh-CN.json @@ -1,41 +1,59 @@ { - "i18n_default_workspace": "默认工作空间", - "i18n_ds_name_exist": "名称已存在", - "i18n_concat_admin": "请联系管理员!", - "i18n_exist": "{msg}已存在!", - "i18n_not_exist": "{msg}不存在!", - "i18n_error": "{key}错误!", - "i18n_miss_args": "缺失{key}参数!", - "i18n_format_invalid": "{key}格式不正确!", - "i18n_login": { - "account_pwd_error": "账号或密码错误!", - "no_associated_ws": "没有关联的工作空间,{msg}", - "user_disable": "账号已禁用,{msg}" - }, - "i18n_user": { - "account": "账号", - "email": "邮箱", - "password": "密码", - "language_not_support": "系统不支持[{key}]语言!", - "ws_miss": "当前用户不在工作空间[{ws}]中!" - }, - "i18n_ws": { - "title": "工作空间" - }, - "i18n_permission": { - "only_admin": "仅支持管理员调用!", - "no_permission": "无权调用{url}{msg}", - "authenticate_invalid": "认证无效【{msg}】", - "token_expired": "Token 已过期" - - }, - "i18n_llm": { - "validate_error": "校验失败[{msg}]", - "delete_default_error": "无法删除默认模型[{key}]!", - "miss_default": "尚未配置默认大语言模型" - }, - "i18n_ds_invalid": "数据源连接无效", - "i18n_embedded": { - "invalid_origin": "域名校验失败【{origin}】" - } + "i18n_default_workspace": "默认工作空间", + "i18n_ds_name_exist": "名称已存在", + "i18n_concat_admin": "请联系管理员!", + "i18n_exist": "{msg}已存在!", + "i18n_name": "名称", + "i18n_not_exist": "{msg}不存在!", + "i18n_error": "{key}错误!", + "i18n_miss_args": "缺失{key}参数!", + "i18n_format_invalid": "{key}格式不正确!", + "i18n_login": { + "account_pwd_error": "账号或密码错误!", + "no_associated_ws": "没有关联的工作空间,{msg}", + "user_disable": "账号已禁用,{msg}" + }, + "i18n_user": { + "account": "账号", + "email": "邮箱", + "password": "密码", + "language_not_support": "系统不支持[{key}]语言!", + "ws_miss": "当前用户不在工作空间[{ws}]中!" + }, + "i18n_ws": { + "title": "工作空间" + }, + "i18n_permission": { + "only_admin": "仅支持管理员调用!", + "no_permission": "无权调用{url}{msg}", + "authenticate_invalid": "认证无效【{msg}】", + "token_expired": "Token 已过期" + }, + "i18n_llm": { + "validate_error": "校验失败[{msg}]", + "delete_default_error": "无法删除默认模型[{key}]!", + "miss_default": "尚未配置默认大语言模型" + }, + "i18n_ds_invalid": "数据源连接无效", + "i18n_embedded": { + "invalid_origin": "域名校验失败【{origin}】" + }, + "i18n_terminology": { + "terminology_not_exists": "该术语不存在", + "datasource_cannot_be_none": "数据源不能为空", + "cannot_be_repeated": "术语名称,同义词不能重复", + "exists_in_db": "术语名称,同义词已存在" + }, + "i18n_data_training": { + "datasource_cannot_be_none": "数据源不能为空", + "data_training_not_exists": "该示例不存在", + "exists_in_db": "该问题已存在" + }, + "i18n_custom_prompt": { + "exists_in_db": "模版名称已存在", + "not_exists": "该模版不存在" + }, + "i18n_excel_export": { + "data_is_empty": "表单数据为空,无法导出数据" + } } \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index 7a639e5a..34e1ad67 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,3 +1,5 @@ +import os + import sqlbot_xpack from alembic.config import Config from fastapi import FastAPI @@ -10,11 +12,13 @@ from alembic import command from apps.api import api_router +from apps.system.crud.aimodel_manage import async_model_info from apps.system.crud.assistant import init_dynamic_cors from apps.system.middleware.auth import TokenMiddleware from common.core.config import settings from common.core.response_middleware import ResponseMiddleware, exception_handler from common.core.sqlbot_cache import init_sqlbot_cache +from common.utils.embedding_threads import fill_empty_terminology_embeddings, fill_empty_data_training_embeddings from common.utils.utils import SQLBotLogUtil @@ -23,13 +27,24 @@ def run_migrations(): command.upgrade(alembic_cfg, "head") +def init_terminology_embedding_data(): + fill_empty_terminology_embeddings() + + +def init_data_training_embedding_data(): + fill_empty_data_training_embeddings() + + @asynccontextmanager async def lifespan(app: FastAPI): run_migrations() init_sqlbot_cache() init_dynamic_cors(app) + init_terminology_embedding_data() + init_data_training_embedding_data() SQLBotLogUtil.info("✅ SQLBot 初始化完成") await sqlbot_xpack.core.clean_xpack_cache() + await async_model_info() # 异步加密已有模型的密钥和地址 yield SQLBotLogUtil.info("SQLBot 应用关闭") @@ -48,7 +63,9 @@ def custom_generate_unique_id(route: APIRoute) -> str: mcp_app = FastAPI() # mcp server, images path -mcp_app.mount("/images", StaticFiles(directory=settings.MCP_IMAGE_PATH), name="images") +images_path = settings.MCP_IMAGE_PATH +os.makedirs(images_path, exist_ok=True) +mcp_app.mount("/images", StaticFiles(directory=images_path), name="images") mcp = FastApiMCP( app, @@ -56,7 +73,7 @@ def custom_generate_unique_id(route: APIRoute) -> str: description="SQLBot MCP Server", describe_all_responses=True, describe_full_response_schema=True, - include_operations=["get_datasource_list", "get_model_list", "mcp_question", "mcp_start"] + include_operations=["get_datasource_list", "get_model_list", "mcp_question", "mcp_start", "mcp_assistant"] ) mcp.mount(mcp_app) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index bbdf7b9d..d07fbb00 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlbot" -version = "1.0.1" +version = "1.2.0" description = "" requires-python = "==3.11.*" dependencies = [ @@ -23,7 +23,10 @@ dependencies = [ "langchain-core>=0.3,<0.4", "langchain-openai>=0.3,<0.4", "langchain-community>=0.3,<0.4", + "langchain-huggingface>=0.2.0", + "sentence-transformers>=4.0.2", "langgraph>=0.3,<0.4", + "pgvector>=0.4.1", "dashscope>=1.14.0,<2.0.0", "pymysql (>=1.1.1,<2.0.0)", "cryptography (>=44.0.3,<45.0.0)", @@ -36,7 +39,7 @@ dependencies = [ "pyyaml (>=6.0.2,<7.0.0)", "fastapi-mcp (>=0.3.4,<0.4.0)", "tabulate>=0.9.0", - "sqlbot-xpack==0.0.3.16", + "sqlbot-xpack>=0.0.3.40,<1.0.0", "fastapi-cache2>=0.2.2", "sqlparse>=0.5.3", "redis>=6.2.0", @@ -44,8 +47,30 @@ dependencies = [ "python-calamine>=0.4.0", "xlrd>=2.0.2", "clickhouse-sqlalchemy>=0.3.2", - "dmpython>=2.5.22", + "dicttoxml>=1.7.16", + "dmpython>=2.5.22; platform_system != 'Darwin'", + "redshift-connector>=2.1.8", + "elasticsearch[requests] (>=7.10,<8.0)", ] + +[project.optional-dependencies] +cpu = [ + "torch>=2.7.0", +] +cu128 = [ + "torch>=2.7.0", +] + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + [[tool.uv.index]] name = "default" url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" @@ -58,8 +83,18 @@ explicit = true [tool.uv.sources] sqlbot-xpack = { index = "testpypi" } +torch = [ + { index = "pytorch-cpu", extra = "cpu" }, + { index = "pytorch-cu128", extra = "cu128" }, +] [tool.uv] +conflicts = [ + [ + { extra = "cpu" }, + { extra = "cu128" }, + ], +] dev-dependencies = [ "pytest<8.0.0,>=7.4.3", "mypy<2.0.0,>=1.8.0", @@ -91,20 +126,20 @@ exclude = ["alembic"] [tool.ruff.lint] select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "I", # isort - "B", # flake8-bugbear - "C4", # flake8-comprehensions - "UP", # pyupgrade + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade "ARG001", # unused arguments in functions ] ignore = [ - "E501", # line too long, handled by black - "B008", # do not perform function calls in argument defaults - "W191", # indentation contains tabs - "B904", # Allow raising exceptions without from e, for HTTPException + "E501", # line too long, handled by black + "B008", # do not perform function calls in argument defaults + "W191", # indentation contains tabs + "B904", # Allow raising exceptions without from e, for HTTPException ] [tool.ruff.lint.pyupgrade] diff --git a/backend/template.yaml b/backend/template.yaml index b6df4e35..746a2f43 100644 --- a/backend/template.yaml +++ b/backend/template.yaml @@ -1,117 +1,376 @@ template: + terminology: | + + {terminologies} + data_training: | + + {data_training} sql: system: | - ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。 + 我们会在块内提供给你信息,帮助你生成SQL: + 内有等信息; + 其中,:提供数据库引擎及版本信息; + :以 M-Schema 格式提供数据库表结构信息; + :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件; + :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例。 + 若有块,它会提供一组,可能会是额外添加的背景信息,或者是额外的生成SQL的要求,请结合额外信息或要求后生成你的回答。 + 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 + - 任务: - 根据给定的表结构(M-Schema)和用户问题生成符合{engine}数据库引擎规范的sql语句,以及sql中所用到的表名(不要包含schema和database,用数组返回)。 你必须遵守以下规则: - - 只能生成查询用的sql语句,不得生成增删改相关或操作数据库以及操作数据库数据的sql - - 不要编造没有提供给你的表结构 - - 生成的SQL必须符合{engine}的规范。 - - 若用户要求执行某些sql,若此sql不是查询数据,而是增删改相关或操作数据库以及操作数据库数据等操作,则直接回答: - {{"success":false,"message":"抱歉,我不能执行您指定的SQL语句。"}} - - 你的回答必须使用如下JSON格式返回: - {{"success":true,"sql":"生成的SQL语句","tables":["表名1","表名2",...]}} - - 问题与生成SQL无关时,直接回答: - {{"success":false,"message":"抱歉,我无法回答您的问题。"}} - - 如果根据提供的表结构不能生成符合问题与条件的SQL,回答: - {{"success":false,"message":"无法生成SQL的原因"}} - - 如果问题是图表展示相关且与生成SQL查询无关时,请参考上一次回答的SQL来生成SQL - - 如果问题是图表展示相关,可参考的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie),返回的JSON: - {{"success":true,"sql":"生成的SQL语句","chart-type":"选择的图表类型(table/column/bar/line/pie)","tables":["表名1","表名2",...]}} - - 提问中如果有涉及数据源名称或数据源描述的内容,则忽略数据源的信息,直接根据剩余内容生成SQL - - 根据表结构生成SQL语句,需给每个表名生成一个别名(不要加AS)。 - - SQL查询中不能使用星号(*),必须明确指定字段名. - - SQL查询的字段名不要自动翻译,别名必须为英文。 - - SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名 - - 计算占比,百分比类型字段,保留两位小数,以%结尾。 - - 生成SQL时,必须避免关键字冲突。 - - 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM),则在schema、表名、字段名、别名外层加双引号; - - 如数据库引擎是 MySQL,则在表名、字段名、别名外层加反引号; - - 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 - - 以PostgreSQL为例,查询Schema为TEST表TABLE下所有数据,则生成的SQL为: - SELECT "id" FROM "TEST"."TABLE" - - 注意在表名外双引号的位置,千万不要生成为: - SELECT "id" FROM "TEST.TABLE" - - 如果生成SQL的字段内有时间格式的字段(重要): + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + + 你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL + + + 不要编造内没有提供给你的表结构 + + + 生成的SQL必须符合内提供数据库引擎的规范 + + + 若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句 + + + 请使用JSON格式返回你的回答: + 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} + 若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}} + + + 如果问题是图表展示相关,可参考的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 返回的JSON内chart-type值则为 table/column/bar/line/pie 中的一个 + 图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table + + + 如果问题是图表展示相关且与生成SQL查询无关时,请参考上一次回答的SQL来生成SQL + + + 返回的JSON字段中,tables字段为你回答的SQL中所用到的表名,不要包含schema和database,用数组返回 + + + 提问中如果有涉及数据源名称或数据源描述的内容,则忽略数据源的信息,直接根据剩余内容生成SQL + + + 根据表结构生成SQL语句,需给每个表名生成一个别名(不要加AS) + + + SQL查询中不能使用星号(*),必须明确指定字段名 + + + SQL查询的字段名不要自动翻译,别名必须为英文 + + + SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名 + + + 计算占比,百分比类型字段,保留两位小数,以%结尾 + + + 生成SQL时,必须避免与数据库关键字冲突 + + + 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号; + 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号; + 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 + + 以 PostgreSQL 为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: + SELECT "id" FROM "TEST"."TABLE" LIMIT 1000 + - 注意在表名外双引号的位置,千万不要生成为: + SELECT "id" FROM "TEST.TABLE" LIMIT 1000 + 以 Microsoft SQL Server 为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: + SELECT TOP 1000 [id] FROM [TEST].[TABLE] + - 注意在表名外方括号的位置,千万不要生成为: + SELECT TOP 1000 [id] FROM [TEST.TABLE] + 以 MySQL 为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: + SELECT `id` FROM `TEST`.`TABLE` LIMIT 1000 + - 注意在表名外反引号的位置,千万不要生成为: + SELECT `id` FROM `TEST.TABLE` LIMIT 1000 + + + + 如果生成SQL的字段内有时间格式的字段: - 若提问中没有指定查询顺序,则默认按时间升序排序 - 若提问是时间,且没有指定具体格式,则格式化为yyyy-MM-dd HH:mm:ss的格式 - 若提问是日期,且没有指定具体格式,则格式化为yyyy-MM-dd的格式 - 若提问是年月,且没有指定具体格式,则格式化为yyyy-MM的格式 - 若提问是年,且没有指定具体格式,则格式化为yyyy的格式 - 生成的格式化语法需要适配对应的数据库引擎。 - - 生成的SQL查询结果可以用来进行图表展示,需要注意排序字段的排序优先级,例如: - - 柱状图或折线图:适合展示在横轴的字段优先排序,若SQL包含分类字段,则分类字段次一级排序 - - ### M-Schema格式简单的解释如下: - ``` - 【DB_ID】 [Database名] - 【Schema】 - # Table: [Database名].[Table名], [表描述(若没有则为空)] - [ - ([字段名1]:[字段1的类型], [字段1的描述(这一行的逗号后都是描述,若没有则为空)]), - ([字段名2]:[字段2的类型], [字段2的描述(这一行的逗号后都是描述,若没有则为空)]), - ([字段名3]:[字段3的类型], [字段3的描述(这一行的逗号后都是描述,若没有则为空)]), - ... - ] - ``` - - ### 提供表结构如下: + + + 生成的SQL查询结果可以用来进行图表展示,需要注意排序字段的排序优先级,例如: + - 柱状图或折线图:适合展示在横轴的字段优先排序,若SQL包含分类字段,则分类字段次一级排序 + + + 如果用户没有指定数据条数的限制,输出的查询SQL必须加上1000条的数据条数限制 + 如果用户指定的限制大于1000,则按1000处理 + + 以 PostgreSQL 为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: + SELECT "id" FROM "TEST"."TABLE" LIMIT 1000 + 以 Microsoft SQL Server 为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: + - 使用 TOP(适用于所有 SQL Server 版本,需要注意 TOP 在SQL中的位置): + SELECT TOP 1000 [id] FROM [TEST].[TABLE] + - 使用 OFFSET-FETCH(SQL Server 2012+): + SELECT "id" FROM "TEST"."TABLE" + ORDER BY "id" -- 必须指定 ORDER BY + OFFSET 0 ROWS FETCH NEXT 1000 ROWS ONLY + 以 Oracle 为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: + - 使用ROWNUM(适用于所有Oracle版本): + SELECT "id" FROM "TEST"."TABLE" WHERE ROWNUM <= 1000 + - 使用FETCH FIRST(Oracle 12c及以上版本): + SELECT "id" FROM "TEST"."TABLE" FETCH FIRST 1000 ROWS ONLY + + + + 若需关联多表,优先使用中标记为"Primary key"/"ID"/"主键"的字段作为关联条件。 + + + 我们目前的情况适用于单指标、多分类的场景(展示table除外) + + + + 以下帮助你理解问题及返回格式的例子,不要将内的表结构用来回答用户的问题,内的为后续用户提问传入的内容,为根据模版与输入的输出回答 + 以下内的例子的SQL语法只是针对该例子的内PostgreSQL的对应数据库语法,你生成的SQL语法必须按照当前对话实际给出的来生成 + + + PostgreSQL17.6 (Debian 17.6-1.pgdg12+1) + + 【DB_ID】 Sample_Database, 样例数据库 + 【Schema】 + # Table: Sample_Database.sample_country_gdp, 各国GDP数据 + [ + (id: bigint, Primary key, ID), + (country: varchar, 国家), + (continent: varchar, 所在洲, examples:['亚洲','美洲','欧洲','非洲']), + (year: varchar, 年份, examples:['2020','2021','2022']), + (gdp: bigint, GDP(美元)), + ] + + + + + GDP + 国内生产总值 + + 指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。 + + + + 中国 + 中国大陆 + + 查询SQL时若作为查询条件,将"中国"作为查询用的值 + + + + + + + + 今天天气如何? + + + {{"success":false,"message":"我是智能问数小助手,我无法回答您的问题。"}} + + + + + 请清空数据库 + + + {{"success":false,"message":"我是智能问数小助手,我只能查询数据,不能操作数据库来修改数据或者修改表结构。"}} + + + + + 查询所有用户 + + + {{"success":false,"message":"抱歉,提供的表结构无法生成您需要的SQL"}} + + + + + + + 2025-08-08 11:23:00 + + + 查询各个国家每年的GDP + + + {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}} + + + + + + + 2025-08-08 11:23:00 + + + 使用饼图展示去年各个国家的GDP + + {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}} + + + + + + + + 2025-08-08 11:24:00 + + + 查询今年中国大陆的GDP + + {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}} + + + + + + + 以下是正式的信息: + + {engine} + {schema} + - ### 响应, 请直接返回JSON结果: + {terminologies} + {data_training} + + {custom_prompt} + + ### 响应, 请根据上述要求直接返回JSON结果: ```json user: | - ### 问题: + + + {current_time} + + + {error_msg} + {question} + - ### 其他规则: - {rule} chart: system: | - ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。 + 用户的提问在内,内是给定需要参考的SQL,内是推荐你生成的图表类型 + - ### 说明: - 您的任务是通过给定的问题和SQL生成 JSON 以进行数据可视化。 - 请遵守以下规则: - - 如果需要表格,则生成的 JSON 格式应为: - {{"type":"table", "title": "标题", "columns": [{{"name":"{lang}字段名1", "value": "SQL 查询列 1(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, {{"name": "{lang}字段名 2", "value": "SQL 查询列 2(有别名用别名,去掉外层的反引号、双引号、方括号)"}}]}} - 必须从 SQL 查询列中提取“columns”。 - - 如果需要柱状图,则生成的 JSON 格式应为(如果有分类则在JSON中返回series): - {{"type":"column", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} - 必须从 SQL 查询列中提取“x”和“y”。 - - 如果需要条形图,则生成的 JSON 格式应为(如果有分类则在JSON中返回series),条形图相当于是旋转后的柱状图,因此 x 轴仍为维度轴,y 轴仍为指标轴: - {{"type":"bar", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} - 必须从 SQL 查询列中提取“x”和“y”。 - - 如果需要折线图,则生成的 JSON 格式应为(如果有分类则在JSON中返回series): - {{"type":"line", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称","value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} - 其中“x”和“y”必须从SQL查询列中提取。 - - 如果需要饼图,则生成的 JSON 格式应为: - {{"type":"pie", "title": "标题", "axis": {{"y": {{"name":"值轴的{lang}名称","value":"SQL 查询数值的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} - 其中“y”和“series”必须从SQL查询列中提取。 - - 如果答案未知或者与生成JSON无关,则生成的 JSON 格式应为: - {{"type":"error", "reason": "抱歉,我无法回答您的问题。"}} - - JSON中生成的标题需要尽量精简 - - ### 示例: - 如果 SQL 为: SELECT products_sales_data.category, AVG(products_sales_data.price) AS average_price FROM products_sales_data GROUP BY products_sales_data.category; - 问题是:每个商品分类的平均价格 - 则生成的 JSON 可以是: {{"type":"table", "title": "每个商品分类的平均价格", "columns": [{{"name":"商品分类","value":"category"}}, {{"name":"平均价格","value":"average_price"}}]}} + 你必须遵守以下规则: + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + + 支持的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 提供给你的值则为 table/column/bar/line/pie 中的一个,若没有推荐类型,则由你自己选择一个合适的类型。 + 图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table + + + 不需要你提供创建图表的代码,你只需要负责根据要求生成JSON配置项 + + + 用户提问的内容只是参考,主要以内的SQL为准 + + + 若用户提问内就是参考SQL,则以内的SQL为准进行推测,选择合适的图表类型展示 + + + 你需要在JSON内生成一个图表的标题,放在"title"字段内,这个标题需要尽量精简 + + + 如果需要表格,JSON格式应为: + {{"type":"table", "title": "标题", "columns": [{{"name":"{lang}字段名1", "value": "SQL 查询列 1(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, {{"name": "{lang}字段名 2", "value": "SQL 查询列 2(有别名用别名,去掉外层的反引号、双引号、方括号)"}}]}} + 必须从 SQL 查询列中提取“columns” + + + 如果需要柱状图,JSON格式应为(如果有分类则在JSON中返回series): + {{"type":"column", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 柱状图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。 + + + 如果需要条形图,JSON格式应为(如果有分类则在JSON中返回series),条形图相当于是旋转后的柱状图,因此 x 轴仍为维度轴,y 轴仍为指标轴: + {{"type":"bar", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 条形图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"和"y"与"series"。 + + + 如果需要折线图,JSON格式应为(如果有分类则在JSON中返回series): + {{"type":"line", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称","value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 折线图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。 + + + 如果需要饼图,JSON格式应为: + {{"type":"pie", "title": "标题", "axis": {{"y": {{"name":"值轴的{lang}名称","value":"SQL 查询数值的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 饼图使用一个分类字段(series)和一个数值字段(y),其中必须从SQL查询列中提取"y"与"series"。 + + + 如果SQL中没有分类列,那么JSON内的series字段不需要出现 + + + 如果SQL查询结果中存在可用于数据分类的字段(如国家、产品类型等),则必须提供series配置。如果不存在,则无需在JSON中包含series字段。 + + + 我们目前的情况适用于单指标、多分类的场景(展示table除外),若SQL中包含多指标列,请选择一个最符合提问情况的指标作为值轴 + + + 如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}} + 可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。" + + + - ### 响应, 请直接返回JSON结果: + ### 以下帮助你理解问题及返回格式的例子,不要将内的表结构用来回答用户的问题 + + + + + SELECT `u`.`email` AS `email`, `u`.`id` AS `id`, `u`.`account` AS `account`, `u`.`enable` AS `enable`, `u`.`create_time` AS `create_time`, `u`.`language` AS `language`, `u`.`default_oid` AS `default_oid`, `u`.`name` AS `name`, `u`.`phone` AS `phone`, FROM `per_user` `u` LIMIT 1000 + 查询所有用户信息 + + + + {{"type":"table","title":"所有用户信息","columns":[{{"name":"邮箱","value":"email"}},{{"name":"ID","value":"id"}},{{"name":"账号","value":"account"}},{{"name":"启用状态","value":"enable"}},{{"name":"创建时间","value":"create_time"}},{{"name":"语言","value":"language"}},{{"name":"所属组织ID","value":"default_oid"}},{{"name":"姓名","value":"name"}},{{"name":"Phone","value":"phone"}}]}} + + + + + SELECT `o`.`name` AS `org_name`, COUNT(`u`.`id`) AS `user_count` FROM `per_user` `u` JOIN `per_org` `o` ON `u`.`default_oid` = `o`.`id` GROUP BY `o`.`name` ORDER BY `user_count` DESC LIMIT 1000 + 饼图展示各个组织的人员数量 + pie + + + {{"type":"pie","title":"组织人数统计","axis":{{"y":{{"name":"人数","value":"user_count"}},"series":{{"name":"组织名称","value":"org_name"}}}}}} + + + + + + ### 响应, 请根据上述要求直接返回JSON结果: ```json user: | - ### SQL: - {sql} - - ### 问题: + {question} - - ### 其他规则: - {rule} + + + {sql} + + + {chart_type} + + guess: system: | ### 请使用语言:{lang} 回答,不需要输出深度思考过程 @@ -147,32 +406,79 @@ template: {old_questions} analysis: system: | - ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定的数据分析数据,并给出你的分析结果。 + 我们会在块内提供给你信息,帮助你进行分析: + 内有等信息; + :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件。 + 若有块,它会提供一组,可能会是额外添加的背景信息,或者是额外的分析要求,请结合额外信息或要求后生成你的回答。 + 用户会在提问中提供给你信息: + 块内是提供给你的数,以JSON格式给出; + 块内提供给你对应的字段或字段别名。 + - ### 说明: - 你是一个数据分析师,你的任务是根据给定的数据分析数据,并给出你的分析结果。 + 你必须遵守以下规则: + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + + + + {terminologies} + + {custom_prompt} user: | - ### 字段(字段别名): + {fields} + - ### 数据: + {data} + predict: system: | - ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定的数据进行数据预测,并给出你的预测结果。 + 若有块,它会提供一组,可能会是额外添加的背景信息,或者是额外的分析要求,请结合额外信息或要求后生成你的回答。 + 用户会在提问中提供给你信息: + 块内是提供给你的数据,以JSON格式给出; + 块内提供给你对应的字段或字段别名。 + - ### 说明: - 你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以JSON格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json数组的格式返回,返回的格式需要与传入的数据格式保持一致。 + 你必须遵守以下规则: + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + + 预测的数据是一段可以展示趋势的数据,至少2个周期 + + + 返回的预测数据必须与用户提供的数据同样的格式,使用JSON数组的形式返回 + + + 无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式):"抱歉,该数据无法进行预测。"(若有原因,则额外返回无法预测的原因) + + + 预测的数据不需要返回用户提供的原有数据,请直接返回你预测的部份 + + + {custom_prompt} + + ### 响应, 请根据上述要求直接返回JSON结果: ```json - 无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式,需要翻译为 {lang} 输出):"抱歉,该数据无法进行预测。(有原因则返回无法预测的原因)" - 如果可以预测,则不需要返回原有数据,直接返回预测的部份 user: | - ### 字段(字段别名): + {fields} + - ### 数据: + {data} + datasource: system: | ### 请使用语言:{lang} 回答 @@ -208,8 +514,8 @@ template: - 如果存在冗余的过滤条件则进行去重后再生成新SQL。 - 给过滤条件中的字段前加上表别名(如果没有表别名则加表名),如:table.field。 - 生成SQL时,必须避免关键字冲突: - - 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM),则在schema、表名、字段名、别名外层加双引号; - - 如数据库引擎是 MySQL,则在表名、字段名、别名外层加反引号; + - 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号; + - 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号; - 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 - 生成的SQL使用JSON格式返回: {{"success":true,"sql":"生成的SQL语句"}} diff --git a/docker-compose.yaml b/docker-compose.yaml index 71b177df..5cd98599 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,8 +1,9 @@ services: sqlbot: - image: dataease/sqlbot:v1.0.0 + image: dataease/sqlbot container_name: sqlbot restart: always + privileged: true networks: - sqlbot-network ports: @@ -10,16 +11,16 @@ services: - 8001:8001 environment: # Database configuration - POSTGRES_SERVER: sqlbot-db + POSTGRES_SERVER: localhost POSTGRES_PORT: 5432 POSTGRES_DB: sqlbot - POSTGRES_USER: sqlbot - POSTGRES_PASSWORD: sqlbot + POSTGRES_USER: root + POSTGRES_PASSWORD: Password123@pg # Project basic settings PROJECT_NAME: "SQLBot" DEFAULT_PWD: "SQLBot@123456" # MCP settings - SERVER_IMAGE_HOST: https://YOUR_SERVE_IP:MCP_PORT/images/ + SERVER_IMAGE_HOST: http://YOUR_SERVE_IP:MCP_PORT/images/ # Auth & Security SECRET_KEY: y5txe1mRmS_JpOrUzFzHEu-kIQn3lf7ll0AOv9DQh0s # CORS settings @@ -29,28 +30,10 @@ services: SQL_DEBUG: False volumes: - ./data/sqlbot/excel:/opt/sqlbot/data/excel + - ./data/sqlbot/file:/opt/sqlbot/data/file - ./data/sqlbot/images:/opt/sqlbot/images - - ./data/sqlbot/logs:/opt/sqlbot/logs - depends_on: - sqlbot-db: - condition: service_healthy - - sqlbot-db: - image: postgres:17.5 - container_name: sqlbot-db - restart: always - networks: - - sqlbot-network - volumes: + - ./data/sqlbot/logs:/opt/sqlbot/app/logs - ./data/postgresql:/var/lib/postgresql/data - environment: - POSTGRES_DB: sqlbot - POSTGRES_USER: sqlbot - POSTGRES_PASSWORD: sqlbot - healthcheck: - test: ["CMD-SHELL", "pg_isready"] - interval: 3s - timeout: 5s - retries: 5 + networks: sqlbot-network: diff --git a/frontend/embedded.html b/frontend/embedded.html index 12c49178..6938017f 100644 --- a/frontend/embedded.html +++ b/frontend/embedded.html @@ -2,7 +2,7 @@ - + SQLBot diff --git a/frontend/index.html b/frontend/index.html index 12c49178..97c846b8 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -2,7 +2,7 @@ - + SQLBot diff --git a/frontend/package.json b/frontend/package.json index c06d2008..50e4dbce 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -18,6 +18,7 @@ "dependencies": { "@antv/g2": "^5.3.3", "@antv/s2": "^2.4.3", + "@antv/x6": "^2.18.1", "@eslint/js": "^9.28.0", "@highlightjs/vue-plugin": "^2.1.0", "@npkg/tinymce-plugins": "^0.0.7", diff --git a/frontend/public/assistant.js b/frontend/public/assistant.js index 4b986199..38a963b2 100644 --- a/frontend/public/assistant.js +++ b/frontend/public/assistant.js @@ -8,8 +8,9 @@ header_font_color: 'rgb(100, 106, 115)', x_type: 'right', y_type: 'bottom', - x_value: '-5', - y_value: '33', + x_val: '30', + y_val: '30', + float_icon_drag: false, } const script_id_prefix = 'sqlbot-assistant-float-script-' const guideHtml = ` @@ -35,8 +36,8 @@ const chatButtonHtml = (data) => `
- - + + @@ -68,7 +69,7 @@ const getChatContainerHtml = (data) => { return `
- +
@@ -156,18 +157,67 @@ closeviewport.classList.remove('sqlbot-assistant-viewportnone') } } - const drag = (e) => { + if (data.float_icon_drag) { + chat_button.setAttribute('draggable', 'true') + + let startX = 0 + let startY = 0 + const img = new Image() + img.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=' + chat_button.addEventListener('dragstart', (e) => { + startX = e.clientX - chat_button.offsetLeft + startY = e.clientY - chat_button.offsetTop + e.dataTransfer.setDragImage(img, 0, 0) + }) + + chat_button.addEventListener('drag', (e) => { + if (e.clientX && e.clientY) { + const left = e.clientX - startX + const top = e.clientY - startY + + const maxX = window.innerWidth - chat_button.offsetWidth + const maxY = window.innerHeight - chat_button.offsetHeight + + chat_button.style.left = Math.min(Math.max(0, left), maxX) + 'px' + chat_button.style.top = Math.min(Math.max(0, top), maxY) + 'px' + } + }) + + let touchStartX = 0 + let touchStartY = 0 + + chat_button.addEventListener('touchstart', (e) => { + touchStartX = e.touches[0].clientX - chat_button.offsetLeft + touchStartY = e.touches[0].clientY - chat_button.offsetTop + e.preventDefault() + }) + + chat_button.addEventListener('touchmove', (e) => { + const left = e.touches[0].clientX - touchStartX + const top = e.touches[0].clientY - touchStartY + + const maxX = window.innerWidth - chat_button.offsetWidth + const maxY = window.innerHeight - chat_button.offsetHeight + + chat_button.style.left = Math.min(Math.max(0, left), maxX) + 'px' + chat_button.style.top = Math.min(Math.max(0, top), maxY) + 'px' + + e.preventDefault() + }) + } + /* const drag = (e) => { if (['touchmove', 'touchstart'].includes(e.type)) { - chat_button.style.top = e.touches[0].clientY - chat_button_img.naturalHeight / 2 + 'px' - chat_button.style.left = e.touches[0].clientX - chat_button_img.naturalWidth / 2 + 'px' + chat_button.style.top = e.touches[0].clientY - chat_button_img.clientHeight / 2 + 'px' + chat_button.style.left = e.touches[0].clientX - chat_button_img.clientHeight / 2 + 'px' } else { - chat_button.style.top = e.y - chat_button_img.naturalHeight / 2 + 'px' - chat_button.style.left = e.x - chat_button_img.naturalWidth / 2 + 'px' + chat_button.style.top = e.y - chat_button_img.clientHeight / 2 + 'px' + chat_button.style.left = e.x - chat_button_img.clientHeight / 2 + 'px' } - chat_button.style.width = chat_button_img.naturalWidth + 'px' - chat_button.style.height = chat_button_img.naturalHeight + 'px' + chat_button.style.width = chat_button_img.clientHeight + 'px' + chat_button.style.height = chat_button_img.clientHeight + 'px' } - if (data.is_draggable) { + if (data.float_icon_drag) { + chat_button.setAttribute('draggable', 'true') chat_button.addEventListener('drag', drag) chat_button.addEventListener('dragover', (e) => { e.preventDefault() @@ -175,7 +225,7 @@ chat_button.addEventListener('dragend', drag) chat_button.addEventListener('touchstart', drag) chat_button.addEventListener('touchmove', drag) - } + } */ viewport.onclick = viewport_func closeviewport.onclick = viewport_func } @@ -234,14 +284,14 @@ height: 64px; box-shadow: 1px 1px 1px 9999px rgba(0,0,0,.6); position: absolute; - ${data.x_type}: ${data.x_value}px; - ${data.y_type}: ${data.y_value}px; + ${data.x_type}: ${data.x_val}px; + ${data.y_type}: ${data.y_val}px; z-index: 10001; } #sqlbot-assistant .sqlbot-assistant-tips { position: fixed; - ${data.x_type}:calc(${data.x_value}px + 75px); - ${data.y_type}: calc(${data.y_value}px + 0px); + ${data.x_type}:calc(${data.x_val}px + 75px); + ${data.y_type}: calc(${data.y_val}px + 0px); padding: 22px 24px 24px; border-radius: 6px; color: #ffffff; @@ -297,76 +347,76 @@ display:none; } @media only screen and (max-width: 768px) { - #sqlbot-assistant-chat-container { - width: 100%; - height: 70%; - right: 0 !important; - } - } - - #sqlbot-assistant .sqlbot-assistant-chat-button{ - position: fixed; - ${data.x_type}: ${data.x_value}px; - ${data.y_type}: ${data.y_value}px; - cursor: pointer; - z-index:10000; - } - #sqlbot-assistant #sqlbot-assistant-chat-container{ - z-index:10000;position: relative; - border-radius: 8px; - //border: 1px solid #ffffff; - background: linear-gradient(188deg, rgba(235, 241, 255, 0.20) 39.6%, rgba(231, 249, 255, 0.20) 94.3%), #EFF0F1; - box-shadow: 0px 4px 8px 0px rgba(31, 35, 41, 0.10); - position: fixed;bottom: 16px;right: 16px;overflow: hidden; + #sqlbot-assistant-chat-container { + width: 100%; + height: 70%; + right: 0 !important; } + } - .ed-overlay-dialog { - margin-top: 50px; - } - .ed-drawer { - margin-top: 50px; - } + #sqlbot-assistant .sqlbot-assistant-chat-button{ + position: fixed; + ${data.x_type}: ${data.x_val}px; + ${data.y_type}: ${data.y_val}px; + cursor: pointer; + z-index:10000; + } + #sqlbot-assistant #sqlbot-assistant-chat-container{ + z-index:10000;position: relative; + border-radius: 8px; + //border: 1px solid #ffffff; + background: linear-gradient(188deg, rgba(235, 241, 255, 0.20) 39.6%, rgba(231, 249, 255, 0.20) 94.3%), #EFF0F1; + box-shadow: 0px 4px 8px 0px rgba(31, 35, 41, 0.10); + position: fixed;bottom: 16px;right: 16px;overflow: hidden; + } - #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate{ - top: 18px; - right: 15px; - position: absolute; - display: flex; - align-items: center; - line-height: 18px; - } - #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate .sqlbot-assistant-chat-close{ - margin-left:15px; - cursor: pointer; - } - #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate .sqlbot-assistant-openviewport{ + .ed-overlay-dialog { + margin-top: 50px; + } + .ed-drawer { + margin-top: 50px; + } - cursor: pointer; - } - #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate .sqlbot-assistant-closeviewport{ + #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate{ + top: 18px; + right: 15px; + position: absolute; + display: flex; + align-items: center; + line-height: 18px; + } + #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate .sqlbot-assistant-chat-close{ + margin-left:15px; + cursor: pointer; + } + #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate .sqlbot-assistant-openviewport{ - cursor: pointer; - } - #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-viewportnone{ - display:none; + cursor: pointer; + } + #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-operate .sqlbot-assistant-closeviewport{ + + cursor: pointer; + } + #sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-viewportnone{ + display:none; + } + #sqlbot-assistant #sqlbot-assistant-chat-container #sqlbot-assistant-chat-iframe-${data.id} { + height:100%; + width:100%; + border: none; + } + #sqlbot-assistant #sqlbot-assistant-chat-container { + animation: appear .4s ease-in-out; + } + @keyframes appear { + from { + height: 0;; } - #sqlbot-assistant #sqlbot-assistant-chat-container #sqlbot-assistant-chat-iframe-${data.id}{ - height:100%; - width:100%; - border: none; -} - #sqlbot-assistant #sqlbot-assistant-chat-container { - animation: appear .4s ease-in-out; - } - @keyframes appear { - from { - height: 0;; - } - to { - height: 600px; - } - }`.replaceAll('#sqlbot-assistant ', `#${sqlbot_assistantId} `) + to { + height: 600px; + } + }`.replaceAll('#sqlbot-assistant ', `#${sqlbot_assistantId} `) root.appendChild(style) } function getParam(src, key) { @@ -448,6 +498,7 @@ function loadScript(src, id) { const domain_url = getDomain(src) const online = getParam(src, 'online') + const userFlag = getParam(src, 'userFlag') let url = `${domain_url}/api/v1/system/assistant/info/${id}` if (domain_url.includes('5173')) { url = url.replace('5173', '8000') @@ -460,7 +511,10 @@ } const data = res.data const config_json = data.configuration - let tempData = Object.assign(defaultData, { id, domain_url, name: data.name }) + let tempData = Object.assign(defaultData, data) + if (tempData.configuration) { + delete tempData.configuration + } if (config_json) { const config = JSON.parse(config_json) if (config) { @@ -468,7 +522,20 @@ tempData = Object.assign(tempData, config) } } + tempData['id'] = id + tempData['domain_url'] = domain_url + + if (tempData['float_icon'] && !tempData['float_icon'].startsWith('http://')) { + tempData['float_icon'] = + `${domain_url}/api/v1/system/assistant/picture/${tempData['float_icon']}` + + if (domain_url.includes('5173')) { + tempData['float_icon'] = tempData['float_icon'].replace('5173', '8000') + } + } + tempData['online'] = online && online.toString().toLowerCase() == 'true' + tempData['userFlag'] = userFlag initsqlbot_assistant(tempData) if (data.type == 1) { registerMessageEvent(id, tempData) @@ -643,7 +710,7 @@ contentWindow.postMessage(params, url) } } - window.sqlbot_assistant_handler[id]['refresh'] = (online) => { + window.sqlbot_assistant_handler[id]['refresh'] = (online, userFlag) => { if (online != null && typeof online != 'boolean') { throw new Error('The parameter can only be of type boolean') } @@ -654,12 +721,35 @@ if (online != null) { new_url = updateParam(new_url, 'online', online) } + if (userFlag != null) { + new_url = updateParam(new_url, 'userFlag', userFlag) + } iframe.src = 'about:blank' setTimeout(() => { iframe.src = new_url }, 500) } } + window.sqlbot_assistant_handler[id]['destroy'] = () => { + const sqlbot_root_id = 'sqlbot-assistant-root-' + id + const container_div = document.getElementById(sqlbot_root_id) + if (container_div) { + const root_div = container_div.parentNode + if (root_div?.parentNode) { + root_div.parentNode.removeChild(root_div) + } + } + + const scriptDom = document.getElementById(`sqlbot-assistant-float-script-${id}`) + if (scriptDom) { + scriptDom.parentNode.removeChild(scriptDom) + } + const propName = script_id_prefix + id + '-state' + if (window[propName]) { + delete window[propName] + } + delete window.sqlbot_assistant_handler[id] + } } // window.addEventListener('load', init) const executeWhenReady = (fn) => { diff --git a/frontend/src/api/chat.ts b/frontend/src/api/chat.ts index c1b526c7..76c436dd 100644 --- a/frontend/src/api/chat.ts +++ b/frontend/src/api/chat.ts @@ -332,5 +332,9 @@ export const chatApi = { return request.fetchStream(`/chat/recommend_questions/${record_id}`, {}, controller) }, checkLLMModel: () => request.get('/system/aimodel/default', { requestOptions: { silent: true } }), - export2Excel: (data: any) => request.post('/chat/excel/export', data, { responseType: 'blob' }), + export2Excel: (data: any) => + request.post('/chat/excel/export', data, { + responseType: 'blob', + requestOptions: { customError: true }, + }), } diff --git a/frontend/src/api/datasource.ts b/frontend/src/api/datasource.ts index 3274b3f4..4a733185 100644 --- a/frontend/src/api/datasource.ts +++ b/frontend/src/api/datasource.ts @@ -3,6 +3,8 @@ import { request } from '@/utils/request' export const datasourceApi = { check: (data: any) => request.post('/datasource/check', data), check_by_id: (id: any) => request.get(`/datasource/check/${id}`), + relationGet: (id: any) => request.post(`/table_relation/get/${id}`), + relationSave: (dsId: any, data: any) => request.post(`/table_relation/save/${dsId}`, data), add: (data: any) => request.post('/datasource/add', data), list: () => request.get('/datasource/list'), update: (data: any) => request.post('/datasource/update', data), diff --git a/frontend/src/api/embedded.ts b/frontend/src/api/embedded.ts index 6002ca90..5c70c721 100644 --- a/frontend/src/api/embedded.ts +++ b/frontend/src/api/embedded.ts @@ -6,3 +6,15 @@ export const saveAssistant = (data: any) => request.post('/system/assistant', da export const getOne = (id: any) => request.get(`/system/assistant/${id}`) export const delOne = (id: any) => request.delete(`/system/assistant/${id}`) export const dsApi = (id: any) => request.get(`/datasource/ws/${id}`) + +export const embeddedApi = { + getList: (pageNum: any, pageSize: any, params: any) => + request.get(`/system/embedded/${pageNum}/${pageSize}`, { + params, + }), + secret: (id: any) => request.patch(`/system/embedded/secret/${id}`), + updateEmbedded: (data: any) => request.put('/system/embedded', data), + addEmbedded: (data: any) => request.post('/system/embedded', data), + deleteEmbedded: (params: any) => request.delete('/system/embedded', { data: params }), + getOne: (id: any) => request.get(`/system/embedded/${id}`), +} diff --git a/frontend/src/api/professional.ts b/frontend/src/api/professional.ts new file mode 100644 index 00000000..5b934e7e --- /dev/null +++ b/frontend/src/api/professional.ts @@ -0,0 +1,11 @@ +import { request } from '@/utils/request' + +export const professionalApi = { + getList: (pageNum: any, pageSize: any, params: any) => + request.get(`/system/terminology/page/${pageNum}/${pageSize}`, { + params, + }), + updateEmbedded: (data: any) => request.put('/system/terminology', data), + deleteEmbedded: (params: any) => request.delete('/system/terminology', { data: params }), + getOne: (id: any) => request.get(`/system/terminology/${id}`), +} diff --git a/frontend/src/api/prompt.ts b/frontend/src/api/prompt.ts new file mode 100644 index 00000000..01d73488 --- /dev/null +++ b/frontend/src/api/prompt.ts @@ -0,0 +1,11 @@ +import { request } from '@/utils/request' + +export const promptApi = { + getList: (pageNum: any, pageSize: any, type: any, params: any) => + request.get(`/system/custom_prompt/${type}/page/${pageNum}/${pageSize}`, { + params, + }), + updateEmbedded: (data: any) => request.put(`/system/custom_prompt`, data), + deleteEmbedded: (params: any) => request.delete('/system/custom_prompt', { data: params }), + getOne: (id: any) => request.get(`/system/custom_prompt/${id}`), +} diff --git a/frontend/src/api/system.ts b/frontend/src/api/system.ts index 9713ec9e..aa7db56d 100644 --- a/frontend/src/api/system.ts +++ b/frontend/src/api/system.ts @@ -3,8 +3,26 @@ import { request } from '@/utils/request' export const modelApi = { queryAll: (keyword?: string) => request.get('/system/aimodel', { params: keyword ? { keyword } : {} }), - add: (data: any) => request.post('/system/aimodel', data), - edit: (data: any) => request.put('/system/aimodel', data), + add: (data: any) => { + const param = data + if (param.api_key) { + param.api_key = LicenseGenerator.sqlbotEncrypt(data.api_key) + } + if (param.api_domain) { + param.api_domain = LicenseGenerator.sqlbotEncrypt(data.api_domain) + } + return request.post('/system/aimodel', param) + }, + edit: (data: any) => { + const param = data + if (param.api_key) { + param.api_key = LicenseGenerator.sqlbotEncrypt(data.api_key) + } + if (param.api_domain) { + param.api_domain = LicenseGenerator.sqlbotEncrypt(data.api_domain) + } + return request.put('/system/aimodel', param) + }, delete: (id: number) => request.delete(`/system/aimodel/${id}`), query: (id: number) => request.get(`/system/aimodel/${id}`), setDefault: (id: number) => request.put(`/system/aimodel/default/${id}`), diff --git a/frontend/src/api/training.ts b/frontend/src/api/training.ts new file mode 100644 index 00000000..c88ad13a --- /dev/null +++ b/frontend/src/api/training.ts @@ -0,0 +1,11 @@ +import { request } from '@/utils/request' + +export const trainingApi = { + getList: (pageNum: any, pageSize: any, params: any) => + request.get(`/system/data-training/page/${pageNum}/${pageSize}`, { + params, + }), + updateEmbedded: (data: any) => request.put('/system/data-training', data), + deleteEmbedded: (params: any) => request.delete('/system/data-training', { data: params }), + getOne: (id: any) => request.get(`/system/data-training/${id}`), +} diff --git a/frontend/src/assets/datasource/icon_doris.png b/frontend/src/assets/datasource/icon_doris.png new file mode 100644 index 00000000..8a80e96f Binary files /dev/null and b/frontend/src/assets/datasource/icon_doris.png differ diff --git a/frontend/src/assets/datasource/icon_es.png b/frontend/src/assets/datasource/icon_es.png new file mode 100644 index 00000000..5d8f6e6e Binary files /dev/null and b/frontend/src/assets/datasource/icon_es.png differ diff --git a/frontend/src/assets/datasource/icon_kingbase.png b/frontend/src/assets/datasource/icon_kingbase.png new file mode 100644 index 00000000..79b171f6 Binary files /dev/null and b/frontend/src/assets/datasource/icon_kingbase.png differ diff --git a/frontend/src/assets/datasource/icon_redshift.png b/frontend/src/assets/datasource/icon_redshift.png new file mode 100644 index 00000000..a6c6bdcf Binary files /dev/null and b/frontend/src/assets/datasource/icon_redshift.png differ diff --git a/frontend/src/assets/embedded/Card.png b/frontend/src/assets/embedded/Card.png new file mode 100644 index 00000000..04eb68ff Binary files /dev/null and b/frontend/src/assets/embedded/Card.png differ diff --git a/frontend/src/assets/embedded/icon_invisible_outlined.svg b/frontend/src/assets/embedded/icon_invisible_outlined.svg new file mode 100644 index 00000000..2a95dc15 --- /dev/null +++ b/frontend/src/assets/embedded/icon_invisible_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/assets/embedded/icon_refresh_outlined.svg b/frontend/src/assets/embedded/icon_refresh_outlined.svg new file mode 100644 index 00000000..ca174ee9 --- /dev/null +++ b/frontend/src/assets/embedded/icon_refresh_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/assets/embedded/icon_sidebar_outlined_nofill.svg b/frontend/src/assets/embedded/icon_sidebar_outlined_nofill.svg new file mode 100644 index 00000000..ad723ea3 --- /dev/null +++ b/frontend/src/assets/embedded/icon_sidebar_outlined_nofill.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/assets/embedded/icon_visible_outlined.svg b/frontend/src/assets/embedded/icon_visible_outlined.svg new file mode 100644 index 00000000..f4dff258 --- /dev/null +++ b/frontend/src/assets/embedded/icon_visible_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/assets/embedded/info-yellow.svg b/frontend/src/assets/embedded/info-yellow.svg new file mode 100644 index 00000000..4b0f9fd1 --- /dev/null +++ b/frontend/src/assets/embedded/info-yellow.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/frontend/src/assets/img/Default-avatar.svg b/frontend/src/assets/img/Default-avatar.svg new file mode 100644 index 00000000..50a7c89f --- /dev/null +++ b/frontend/src/assets/img/Default-avatar.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/frontend/src/assets/img/none-dashboard.svg b/frontend/src/assets/img/none-dashboard.svg new file mode 100644 index 00000000..01b130d5 --- /dev/null +++ b/frontend/src/assets/img/none-dashboard.svg @@ -0,0 +1,67 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/assets/model/icon_common_openai.png b/frontend/src/assets/model/icon_common_openai.png new file mode 100644 index 00000000..809fa7cc Binary files /dev/null and b/frontend/src/assets/model/icon_common_openai.png differ diff --git a/frontend/src/assets/permission/icon_custom-tools_colorful.png b/frontend/src/assets/permission/icon_custom-tools_colorful.png deleted file mode 100644 index 41b09345..00000000 Binary files a/frontend/src/assets/permission/icon_custom-tools_colorful.png and /dev/null differ diff --git a/frontend/src/assets/permission/icon_custom-tools_colorful.svg b/frontend/src/assets/permission/icon_custom-tools_colorful.svg new file mode 100644 index 00000000..affb781a --- /dev/null +++ b/frontend/src/assets/permission/icon_custom-tools_colorful.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/frontend/src/assets/permission/icon_dashboard.svg b/frontend/src/assets/permission/icon_dashboard.svg new file mode 100644 index 00000000..11f5dbb5 --- /dev/null +++ b/frontend/src/assets/permission/icon_dashboard.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/assets/svg/401.svg b/frontend/src/assets/svg/401.svg new file mode 100644 index 00000000..d02779fa --- /dev/null +++ b/frontend/src/assets/svg/401.svg @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/assets/svg/LOGO-custom.svg b/frontend/src/assets/svg/LOGO-custom.svg new file mode 100644 index 00000000..24a7390f --- /dev/null +++ b/frontend/src/assets/svg/LOGO-custom.svg @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/assets/svg/avatar_personal.svg b/frontend/src/assets/svg/avatar_personal.svg index 34de8b7f..c25f512a 100644 --- a/frontend/src/assets/svg/avatar_personal.svg +++ b/frontend/src/assets/svg/avatar_personal.svg @@ -1,5 +1,5 @@ - - + + diff --git a/frontend/src/assets/svg/chart/icon_bar_outlined.svg b/frontend/src/assets/svg/chart/icon_bar_outlined.svg index 2fecb922..5b34a326 100644 --- a/frontend/src/assets/svg/chart/icon_bar_outlined.svg +++ b/frontend/src/assets/svg/chart/icon_bar_outlined.svg @@ -1,3 +1,3 @@ - + diff --git a/frontend/src/assets/svg/chart/icon_chart-line.svg b/frontend/src/assets/svg/chart/icon_chart-line.svg index 4899f0b1..0f3f6780 100644 --- a/frontend/src/assets/svg/chart/icon_chart-line.svg +++ b/frontend/src/assets/svg/chart/icon_chart-line.svg @@ -1,4 +1,4 @@ - + diff --git a/frontend/src/assets/svg/chart/icon_pie_outlined.svg b/frontend/src/assets/svg/chart/icon_pie_outlined.svg index 3efd2a8c..1b5fb283 100644 --- a/frontend/src/assets/svg/chart/icon_pie_outlined.svg +++ b/frontend/src/assets/svg/chart/icon_pie_outlined.svg @@ -1,4 +1,4 @@ - + diff --git a/frontend/src/assets/svg/icon_magnify_outlined.svg b/frontend/src/assets/svg/icon_magnify_outlined.svg index 15d021ce..05ef4413 100644 --- a/frontend/src/assets/svg/icon_magnify_outlined.svg +++ b/frontend/src/assets/svg/icon_magnify_outlined.svg @@ -1,3 +1,3 @@ - - + + diff --git a/frontend/src/assets/svg/icon_mindnote_outlined.svg b/frontend/src/assets/svg/icon_mindnote_outlined.svg new file mode 100644 index 00000000..d3456f2b --- /dev/null +++ b/frontend/src/assets/svg/icon_mindnote_outlined.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/assets/svg/icon_sidebar_outlined_nofill.svg b/frontend/src/assets/svg/icon_sidebar_outlined_nofill.svg new file mode 100644 index 00000000..ad723ea3 --- /dev/null +++ b/frontend/src/assets/svg/icon_sidebar_outlined_nofill.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/assets/svg/logo-custom_small.svg b/frontend/src/assets/svg/logo-custom_small.svg new file mode 100644 index 00000000..e73df44e --- /dev/null +++ b/frontend/src/assets/svg/logo-custom_small.svg @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/assets/svg/tool-bar.svg b/frontend/src/assets/svg/tool-bar.svg new file mode 100644 index 00000000..e0462b3d --- /dev/null +++ b/frontend/src/assets/svg/tool-bar.svg @@ -0,0 +1,67 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/components/Language-selector/index.vue b/frontend/src/components/Language-selector/index.vue index 5ab19f0a..7d9ee8d1 100644 --- a/frontend/src/components/Language-selector/index.vue +++ b/frontend/src/components/Language-selector/index.vue @@ -1,21 +1,20 @@
- {{ selectedLanguage === 'zh-CN' ? '中文' : 'English' }} + {{ displayLanguageName }}
@@ -29,13 +28,24 @@ import { useUserStore } from '@/stores/user' import { ArrowDown } from '@element-plus/icons-vue' import { userApi } from '@/api/auth' -const { locale } = useI18n() +const { t, locale } = useI18n() const userStore = useUserStore() +const languageOptions = computed(() => [ + { value: 'en', label: t('common.english') }, + { value: 'zh-CN', label: t('common.simplified_chinese') }, + { value: 'ko-KR', label: t('common.korean') }, +]) + const selectedLanguage = computed(() => { return userStore.language }) +const displayLanguageName = computed(() => { + const current = languageOptions.value.find((item) => item.value === selectedLanguage.value) + return current?.label ?? t('common.language') +}) + const changeLanguage = (lang: string) => { locale.value = lang userStore.setLanguage(lang) @@ -63,4 +73,4 @@ const changeLanguage = (lang: string) => { .selected-lang { color: var(--el-color-primary); } - + \ No newline at end of file diff --git a/frontend/src/components/layout/LayoutDsl.vue b/frontend/src/components/layout/LayoutDsl.vue index 6309e43a..1c17fcd6 100644 --- a/frontend/src/components/layout/LayoutDsl.vue +++ b/frontend/src/components/layout/LayoutDsl.vue @@ -1,6 +1,8 @@ diff --git a/frontend/src/views/chat/ChatList.vue b/frontend/src/views/chat/ChatList.vue index ae853c48..04d4b682 100644 --- a/frontend/src/views/chat/ChatList.vue +++ b/frontend/src/views/chat/ChatList.vue @@ -221,7 +221,12 @@ const handleConfirmPassword = () => { {{ chat.brief ?? 'Untitled' }} diff --git a/frontend/src/views/chat/ChatListContainer.vue b/frontend/src/views/chat/ChatListContainer.vue index 7aa0a8d4..760d902f 100644 --- a/frontend/src/views/chat/ChatListContainer.vue +++ b/frontend/src/views/chat/ChatListContainer.vue @@ -19,6 +19,7 @@ const props = withDefaults( currentChatId?: number currentChat?: ChatInfo loading?: boolean + appName?: string }>(), { chatList: () => [], @@ -26,6 +27,7 @@ const props = withDefaults( currentChat: () => new ChatInfo(), loading: false, inPopover: false, + appName: '', } ) @@ -43,7 +45,7 @@ const emits = defineEmits([ ]) const assistantStore = useAssistantStore() -const isAssistant = computed(() => assistantStore.getAssistant) +const isCompletePage = computed(() => !assistantStore.getAssistant || assistantStore.getEmbedded) const search = ref() @@ -145,7 +147,7 @@ const createNewChat = async () => { } async function doCreateNewChat() { - if (isAssistant.value) { + if (!isCompletePage.value) { return } chatCreatorRef.value?.showDs() @@ -214,7 +216,7 @@ function onChatRenamed(chat: Chat) {
-
{{ t('qa.title') }}
+
{{ appName || t('qa.title') }}
@@ -252,7 +254,7 @@ function onChatRenamed(chat: Chat) { /> - + @@ -317,6 +319,9 @@ function onChatRenamed(chat: Chat) { .search { height: 32px; width: 100%; + :deep(.ed-input__wrapper) { + background-color: #f5f6f7; + } } } diff --git a/frontend/src/views/chat/ChatRow.vue b/frontend/src/views/chat/ChatRow.vue index f6afe2d0..ffa4176f 100644 --- a/frontend/src/views/chat/ChatRow.vue +++ b/frontend/src/views/chat/ChatRow.vue @@ -1,6 +1,8 @@