diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2c5178241c820f..5a6903d3d576c2 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,7 +3,7 @@ from typing import Optional from flask import Flask, current_app -from sqlalchemy.orm import load_only +from sqlalchemy.orm import Session, load_only from configs import dify_config from core.rag.data_post_processor.data_post_processor import DataPostProcessor @@ -144,7 +144,8 @@ def external_retrieve( @classmethod def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: - return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + return session.query(Dataset).filter(Dataset.id == dataset_id).first() @classmethod def keyword_search( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3fca48be22e002..5c0360b0647027 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,7 @@ from flask import Flask, current_app from sqlalchemy import Float, and_, or_, text from sqlalchemy import cast as sqlalchemy_cast +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -598,7 +599,8 @@ def _retriever( metadata_condition: Optional[MetadataCondition] = None, ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b34d62d669d234..f05d93d83ea4a9 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -144,6 +144,8 @@ def _run(self) -> NodeRunResult: # type: ignore error=str(e), error_type=type(e).__name__, ) + finally: + db.session.close() def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] @@ -171,6 +173,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: .all() ) + # avoid blocking at retrieval + db.session.close() + for dataset in results: # pass if dataset is not available if not dataset: