diff --git a/cognee/api/client.py b/cognee/api/client.py index 7c9a5c6442..53c9f97620 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -1,24 +1,11 @@ """ FastAPI server for the Cognee API. """ -from datetime import datetime import os -from uuid import UUID -import aiohttp import uvicorn import logging import sentry_sdk -from typing import List, Union, Optional, Literal -from typing_extensions import Annotated -from fastapi import FastAPI, HTTPException, Form, UploadFile, Query, Depends -from fastapi.responses import JSONResponse, FileResponse, Response +from fastapi import FastAPI +from fastapi.responses import JSONResponse, Response from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel - -from cognee.api.DTO import InDTO, OutDTO -from cognee.api.v1.search import SearchType -from cognee.modules.users.models import User -from cognee.modules.users.methods import get_authenticated_user -from cognee.modules.pipelines.models import PipelineRunStatus - # Set up logging logging.basicConfig( @@ -65,9 +52,12 @@ async def lifespan(app: FastAPI): from cognee.api.v1.users.routers import get_auth_router, get_register_router,\ get_reset_password_router, get_verify_router, get_users_router - -from cognee.api.v1.permissions.get_permissions_router import get_permissions_router - +from cognee.api.v1.permissions.routers import get_permissions_router +from cognee.api.v1.settings.routers import get_settings_router +from cognee.api.v1.datasets.routers import get_datasets_router +from cognee.api.v1.cognify.routers import get_cognify_router +from cognee.api.v1.search.routers import get_search_router +from cognee.api.v1.add.routers import get_add_router from fastapi import Request from fastapi.encoders import jsonable_encoder @@ -137,261 +127,35 @@ def health_check(): """ return Response(status_code = 200) +app.include_router( + get_datasets_router(), + prefix="/api/v1/datasets", + tags=["datasets"] +) -class ErrorResponseDTO(BaseModel): - message: str - - -class DatasetDTO(OutDTO): - id: UUID - name: str - created_at: datetime - updated_at: Optional[datetime] = None - owner_id: UUID - -@app.get("/api/v1/datasets", response_model = list[DatasetDTO]) -async def get_datasets(user: User = Depends(get_authenticated_user)): - try: - from cognee.modules.data.methods import get_datasets - datasets = await get_datasets(user.id) - - return datasets - except Exception as error: - logger.error(f"Error retrieving datasets: {str(error)}") - raise HTTPException(status_code = 500, detail = f"Error retrieving datasets: {str(error)}") from error - - -@app.delete("/api/v1/datasets/{dataset_id}", response_model = None, responses = { 404: { "model": ErrorResponseDTO }}) -async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)): - from cognee.modules.data.methods import get_dataset, delete_dataset - - dataset = await get_dataset(user.id, dataset_id) - - if dataset is None: - raise HTTPException( - status_code = 404, - detail = f"Dataset ({dataset_id}) not found." 
- ) - - await delete_dataset(dataset) - - -@app.get("/api/v1/datasets/{dataset_id}/graph", response_model = str) -async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authenticated_user)): - from cognee.shared.utils import render_graph - from cognee.infrastructure.databases.graph import get_graph_engine - - try: - graph_client = await get_graph_engine() - graph_url = await render_graph(graph_client.graph) - - return JSONResponse( - status_code = 200, - content = str(graph_url), - ) - except: - return JSONResponse( - status_code = 409, - content = "Graphistry credentials are not set. Please set them in your .env file.", - ) - - -class DataDTO(OutDTO): - id: UUID - name: str - created_at: datetime - updated_at: Optional[datetime] = None - extension: str - mime_type: str - raw_data_location: str - -@app.get("/api/v1/datasets/{dataset_id}/data", response_model = list[DataDTO], responses = { 404: { "model": ErrorResponseDTO }}) -async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)): - from cognee.modules.data.methods import get_dataset_data, get_dataset - - dataset = await get_dataset(user.id, dataset_id) - - if dataset is None: - return JSONResponse( - status_code = 404, - content = ErrorResponseDTO(f"Dataset ({dataset_id}) not found."), - ) - - dataset_data = await get_dataset_data(dataset_id = dataset.id) - - if dataset_data is None: - return [] - - return dataset_data - - -@app.get("/api/v1/datasets/status", response_model = dict[str, PipelineRunStatus]) -async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None, user: User = Depends(get_authenticated_user)): - from cognee.api.v1.datasets.datasets import datasets as cognee_datasets - - try: - datasets_statuses = await cognee_datasets.get_status(datasets) - - return datasets_statuses - except Exception as error: - return JSONResponse( - status_code = 409, - content = {"error": str(error)} - ) - - -@app.get("/api/v1/datasets/{dataset_id}/data/{data_id}/raw", response_class = FileResponse) -async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)): - from cognee.modules.data.methods import get_dataset, get_dataset_data - - dataset = await get_dataset(user.id, dataset_id) - - if dataset is None: - return JSONResponse( - status_code = 404, - content = { - "detail": f"Dataset ({dataset_id}) not found." - } - ) - - dataset_data = await get_dataset_data(dataset.id) - - if dataset_data is None: - raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.") - - data = [data for data in dataset_data if str(data.id) == data_id][0] - - if data is None: - return JSONResponse( - status_code = 404, - content = { - "detail": f"Data ({data_id}) not found in dataset ({dataset_id})." 
- } - ) - - return data.raw_data_location - - -@app.post("/api/v1/add", response_model = None) -async def add( - data: List[UploadFile], - datasetId: str = Form(...), - user: User = Depends(get_authenticated_user), -): - """ This endpoint is responsible for adding data to the graph.""" - from cognee.api.v1.add import add as cognee_add - try: - if isinstance(data, str) and data.startswith("http"): - if "github" in data: - # Perform git clone if the URL is from GitHub - repo_name = data.split("/")[-1].replace(".git", "") - os.system(f"git clone {data} .data/{repo_name}") - await cognee_add( - "data://.data/", - f"{repo_name}", - ) - else: - # Fetch and store the data from other types of URL using curl - async with aiohttp.ClientSession() as session: - async with session.get(data) as resp: - if resp.status == 200: - file_data = await resp.read() - with open(f".data/{data.split('/')[-1]}", "wb") as f: - f.write(file_data) - await cognee_add( - "data://.data/", - f"{data.split('/')[-1]}", - ) - else: - await cognee_add( - data, - datasetId, - user = user, - ) - except Exception as error: - return JSONResponse( - status_code = 409, - content = {"error": str(error)} - ) - - -class CognifyPayloadDTO(BaseModel): - datasets: List[str] - -@app.post("/api/v1/cognify", response_model = None) -async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)): - """ This endpoint is responsible for the cognitive processing of the content.""" - from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify - try: - await cognee_cognify(payload.datasets, user) - except Exception as error: - return JSONResponse( - status_code = 409, - content = {"error": str(error)} - ) - - -class SearchPayloadDTO(InDTO): - search_type: SearchType - query: str - -@app.post("/api/v1/search", response_model = list) -async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): - """ This endpoint is responsible for searching for nodes in the graph.""" - from cognee.api.v1.search import search as cognee_search - - try: - results = await cognee_search(payload.search_type, payload.query, user) - - return results - except Exception as error: - return JSONResponse( - status_code = 409, - content = {"error": str(error)} - ) - -from cognee.modules.settings.get_settings import LLMConfig, VectorDBConfig - -class LLMConfigDTO(OutDTO, LLMConfig): - pass - -class VectorDBConfigDTO(OutDTO, VectorDBConfig): - pass - -class SettingsDTO(OutDTO): - llm: LLMConfigDTO - vector_db: VectorDBConfigDTO - -@app.get("/api/v1/settings", response_model = SettingsDTO) -async def get_settings(user: User = Depends(get_authenticated_user)): - from cognee.modules.settings import get_settings as get_cognee_settings - return get_cognee_settings() - - -class LLMConfigDTO(InDTO): - provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]] - model: str - api_key: str - -class VectorDBConfigDTO(InDTO): - provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]] - url: str - api_key: str - -class SettingsPayloadDTO(InDTO): - llm: Optional[LLMConfigDTO] = None - vector_db: Optional[VectorDBConfigDTO] = None - -@app.post("/api/v1/settings", response_model = None) -async def save_settings(new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)): - from cognee.modules.settings import save_llm_config, save_vector_db_config +app.include_router( + get_add_router(), + prefix="/api/v1/add", + tags=["add"] +) - if new_settings.llm is 
not None: - await save_llm_config(new_settings.llm) +app.include_router( + get_cognify_router(), + prefix="/api/v1/cognify", + tags=["cognify"] +) - if new_settings.vector_db is not None: - await save_vector_db_config(new_settings.vector_db) +app.include_router( + get_search_router(), + prefix="/api/v1/search", + tags=["search"] +) +app.include_router( + get_settings_router(), + prefix="/api/v1/settings", + tags=["settings"] +) def start_api_server(host: str = "0.0.0.0", port: int = 8000): """ diff --git a/cognee/api/v1/add/routers/__init__.py b/cognee/api/v1/add/routers/__init__.py new file mode 100644 index 0000000000..eebb250ab4 --- /dev/null +++ b/cognee/api/v1/add/routers/__init__.py @@ -0,0 +1 @@ +from .get_add_router import get_add_router \ No newline at end of file diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py new file mode 100644 index 0000000000..1f45d0c956 --- /dev/null +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -0,0 +1,60 @@ +from fastapi import Form, UploadFile, Depends +from fastapi.responses import JSONResponse +from fastapi import APIRouter +from typing import List +import aiohttp +import subprocess +import logging +import os +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_authenticated_user + +logger = logging.getLogger(__name__) + +def get_add_router() -> APIRouter: + router = APIRouter() + + @router.post("/", response_model=None) + async def add( + data: List[UploadFile], + datasetId: str = Form(...), + user: User = Depends(get_authenticated_user), + ): + """ This endpoint is responsible for adding data to the graph.""" + from cognee.api.v1.add import add as cognee_add + try: + if isinstance(data, str) and data.startswith("http"): + if "github" in data: + # Perform git clone if the URL is from GitHub + repo_name = data.split("/")[-1].replace(".git", "") + subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True) + await cognee_add( + "data://.data/", + f"{repo_name}", + ) + else: + # Fetch and store the data from other types of URL using curl + async with aiohttp.ClientSession() as session: + async with session.get(data) as resp: + if resp.status == 200: + file_data = await resp.read() + filename = os.path.basename(data) + with open(f".data/{filename}", "wb") as f: + f.write(file_data) + await cognee_add( + "data://.data/", + f"{data.split('/')[-1]}", + ) + else: + await cognee_add( + data, + datasetId, + user=user, + ) + except Exception as error: + return JSONResponse( + status_code=409, + content={"error": str(error)} + ) + + return router \ No newline at end of file diff --git a/cognee/api/v1/cognify/routers/__init__.py b/cognee/api/v1/cognify/routers/__init__.py new file mode 100644 index 0000000000..c6d52bfa24 --- /dev/null +++ b/cognee/api/v1/cognify/routers/__init__.py @@ -0,0 +1 @@ +from .get_cognify_router import get_cognify_router \ No newline at end of file diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py new file mode 100644 index 0000000000..9616fa71cf --- /dev/null +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -0,0 +1,27 @@ +from fastapi import APIRouter +from typing import List +from pydantic import BaseModel +from cognee.modules.users.models import User +from fastapi.responses import JSONResponse +from cognee.modules.users.methods import get_authenticated_user +from fastapi import Depends + +class CognifyPayloadDTO(BaseModel): + datasets: 
List[str] + +def get_cognify_router() -> APIRouter: + router = APIRouter() + + @router.post("/", response_model=None) + async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)): + """ This endpoint is responsible for the cognitive processing of the content.""" + from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify + try: + await cognee_cognify(payload.datasets, user) + except Exception as error: + return JSONResponse( + status_code=409, + content={"error": str(error)} + ) + + return router \ No newline at end of file diff --git a/cognee/api/v1/datasets/routers/__init__.py b/cognee/api/v1/datasets/routers/__init__.py new file mode 100644 index 0000000000..f03428fd6d --- /dev/null +++ b/cognee/api/v1/datasets/routers/__init__.py @@ -0,0 +1 @@ +from .get_datasets_router import get_datasets_router \ No newline at end of file diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py new file mode 100644 index 0000000000..f27c6c2ad4 --- /dev/null +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -0,0 +1,178 @@ +import logging +from fastapi import APIRouter +from datetime import datetime +from uuid import UUID +from typing import List, Optional +from typing_extensions import Annotated +from fastapi import HTTPException, Query, Depends +from fastapi.responses import JSONResponse, FileResponse +from pydantic import BaseModel + +from cognee.api.DTO import OutDTO +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_authenticated_user +from cognee.modules.pipelines.models import PipelineRunStatus + +logger = logging.getLogger(__name__) + +class ErrorResponseDTO(BaseModel): + message: str + +class DatasetDTO(OutDTO): + id: UUID + name: str + created_at: datetime + updated_at: Optional[datetime] = None + owner_id: UUID + +class DataDTO(OutDTO): + id: UUID + name: str + created_at: datetime + updated_at: Optional[datetime] = None + extension: str + mime_type: str + raw_data_location: str + +def get_datasets_router() -> APIRouter: + router = APIRouter() + + @router.get("/", response_model=list[DatasetDTO]) + async def get_datasets(user: User = Depends(get_authenticated_user)): + try: + from cognee.modules.data.methods import get_datasets + datasets = await get_datasets(user.id) + + return datasets + except Exception as error: + logger.error(f"Error retrieving datasets: {str(error)}") + raise HTTPException(status_code=500, detail=f"Error retrieving datasets: {str(error)}") from error + + @router.delete("/{dataset_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}}) + async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)): + from cognee.modules.data.methods import get_dataset, delete_dataset + + dataset = await get_dataset(user.id, dataset_id) + + if dataset is None: + raise HTTPException( + status_code=404, + detail=f"Dataset ({dataset_id}) not found." 
+            )
+
+        await delete_dataset(dataset)
+
+    @router.delete("/{dataset_id}/data/{data_id}", response_model=None, responses={404: {"model": ErrorResponseDTO}})
+    async def delete_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
+        from cognee.modules.data.methods import get_data, delete_data
+        from cognee.modules.data.methods import get_dataset
+
+        # Check if user has permission to access dataset and data by trying to get the dataset
+        dataset = await get_dataset(user.id, dataset_id)
+
+        # TODO: Handle this differently if the user doesn't have permission to access the data?
+        if dataset is None:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Dataset ({dataset_id}) not found."
+            )
+
+        data = await get_data(data_id)
+
+        if data is None:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Data ({data_id}) not found."
+            )
+
+        await delete_data(data)
+
+    @router.get("/{dataset_id}/graph", response_model=str)
+    async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authenticated_user)):
+        from cognee.shared.utils import render_graph
+        from cognee.infrastructure.databases.graph import get_graph_engine
+
+        try:
+            graph_client = await get_graph_engine()
+            graph_url = await render_graph(graph_client.graph)
+
+            return JSONResponse(
+                status_code=200,
+                content=str(graph_url),
+            )
+        except Exception:
+            return JSONResponse(
+                status_code=409,
+                content="Graphistry credentials are not set. Please set them in your .env file.",
+            )
+
+    @router.get("/{dataset_id}/data", response_model=list[DataDTO],
+                responses={404: {"model": ErrorResponseDTO}})
+    async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
+        from cognee.modules.data.methods import get_dataset_data, get_dataset
+
+        dataset = await get_dataset(user.id, dataset_id)
+
+        if dataset is None:
+            return JSONResponse(
+                status_code=404,
+                content={"message": f"Dataset ({dataset_id}) not found."},
+            )
+
+        dataset_data = await get_dataset_data(dataset_id=dataset.id)
+
+        if dataset_data is None:
+            return []
+
+        return dataset_data
+
+    @router.get("/status", response_model=dict[str, PipelineRunStatus])
+    async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None,
+                                 user: User = Depends(get_authenticated_user)):
+        from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
+
+        try:
+            datasets_statuses = await cognee_datasets.get_status(datasets)
+
+            return datasets_statuses
+        except Exception as error:
+            return JSONResponse(
+                status_code=409,
+                content={"error": str(error)}
+            )
+
+    @router.get("/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
+    async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
+        from cognee.modules.data.methods import get_dataset, get_dataset_data
+
+        dataset = await get_dataset(user.id, dataset_id)
+
+        if dataset is None:
+            return JSONResponse(
+                status_code=404,
+                content={
+                    "detail": f"Dataset ({dataset_id}) not found."
+                }
+            )
+
+        dataset_data = await get_dataset_data(dataset.id)
+
+        if dataset_data is None:
+            raise HTTPException(status_code=404, detail=f"No data found in dataset ({dataset_id}).")
+
+        matching_data = [data for data in dataset_data if str(data.id) == data_id]
+
+        # Check if matching_data contains an element
+        if len(matching_data) == 0:
+            return JSONResponse(
+                status_code=404,
+                content={
+                    "detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
+ } + ) + + data = matching_data[0] + + return data.raw_data_location + + return router \ No newline at end of file diff --git a/cognee/api/v1/permissions/routers/__init__.py b/cognee/api/v1/permissions/routers/__init__.py new file mode 100644 index 0000000000..986b52c3e7 --- /dev/null +++ b/cognee/api/v1/permissions/routers/__init__.py @@ -0,0 +1 @@ +from .get_permissions_router import get_permissions_router \ No newline at end of file diff --git a/cognee/api/v1/permissions/get_permissions_router.py b/cognee/api/v1/permissions/routers/get_permissions_router.py similarity index 100% rename from cognee/api/v1/permissions/get_permissions_router.py rename to cognee/api/v1/permissions/routers/get_permissions_router.py diff --git a/cognee/api/v1/search/routers/__init__.py b/cognee/api/v1/search/routers/__init__.py new file mode 100644 index 0000000000..c3b199f5f0 --- /dev/null +++ b/cognee/api/v1/search/routers/__init__.py @@ -0,0 +1 @@ +from .get_search_router import get_search_router \ No newline at end of file diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py new file mode 100644 index 0000000000..5df49635ff --- /dev/null +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -0,0 +1,31 @@ +from cognee.api.v1.search import SearchType +from fastapi.responses import JSONResponse +from cognee.modules.users.models import User +from fastapi import Depends, APIRouter +from cognee.api.DTO import InDTO +from cognee.modules.users.methods import get_authenticated_user + + +class SearchPayloadDTO(InDTO): + search_type: SearchType + query: str + +def get_search_router() -> APIRouter: + router = APIRouter() + + @router.post("/", response_model = list) + async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): + """ This endpoint is responsible for searching for nodes in the graph.""" + from cognee.api.v1.search import search as cognee_search + + try: + results = await cognee_search(payload.search_type, payload.query, user) + + return results + except Exception as error: + return JSONResponse( + status_code = 409, + content = {"error": str(error)} + ) + + return router \ No newline at end of file diff --git a/cognee/api/v1/settings/routers/__init__.py b/cognee/api/v1/settings/routers/__init__.py new file mode 100644 index 0000000000..363d26610f --- /dev/null +++ b/cognee/api/v1/settings/routers/__init__.py @@ -0,0 +1 @@ +from .get_settings_router import get_settings_router \ No newline at end of file diff --git a/cognee/api/v1/settings/routers/get_settings_router.py b/cognee/api/v1/settings/routers/get_settings_router.py new file mode 100644 index 0000000000..31692382be --- /dev/null +++ b/cognee/api/v1/settings/routers/get_settings_router.py @@ -0,0 +1,51 @@ +from fastapi import APIRouter +from cognee.api.DTO import InDTO, OutDTO +from typing import Union, Optional, Literal +from cognee.modules.users.methods import get_authenticated_user +from fastapi import Depends +from cognee.modules.users.models import User +from cognee.modules.settings.get_settings import LLMConfig, VectorDBConfig + +class LLMConfigOutputDTO(OutDTO, LLMConfig): + pass + +class VectorDBConfigOutputDTO(OutDTO, VectorDBConfig): + pass + +class SettingsDTO(OutDTO): + llm: LLMConfigOutputDTO + vector_db: VectorDBConfigOutputDTO + +class LLMConfigInputDTO(InDTO): + provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]] + model: str + api_key: str + +class VectorDBConfigInputDTO(InDTO): + provider: 
Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]] + url: str + api_key: str + +class SettingsPayloadDTO(InDTO): + llm: Optional[LLMConfigInputDTO] = None + vector_db: Optional[VectorDBConfigInputDTO] = None + +def get_settings_router() -> APIRouter: + router = APIRouter() + + @router.get("/", response_model=SettingsDTO) + async def get_settings(user: User = Depends(get_authenticated_user)): + from cognee.modules.settings import get_settings as get_cognee_settings + return get_cognee_settings() + + @router.post("/", response_model=None) + async def save_settings(new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)): + from cognee.modules.settings import save_llm_config, save_vector_db_config + + if new_settings.llm is not None: + await save_llm_config(new_settings.llm) + + if new_settings.vector_db is not None: + await save_vector_db_config(new_settings.vector_db) + + return router \ No newline at end of file diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index edde075658..febfe19312 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -89,10 +89,22 @@ async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: O """ Delete data in given table based on id. Table must have an id Column. """ - async with self.get_async_session() as session: - TableModel = await self.get_table(table_name, schema_name) - await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) - await session.commit() + if self.engine.dialect.name == "sqlite": + async with self.get_async_session() as session: + TableModel = await self.get_table(table_name, schema_name) + + # Foreign key constraints are disabled by default in SQLite (for backwards compatibility), + # so must be enabled for each database connection/session separately. 
+ await session.execute(text("PRAGMA foreign_keys = ON;")) + + await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) + await session.commit() + else: + async with self.get_async_session() as session: + TableModel = await self.get_table(table_name, schema_name) + await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) + await session.commit() + async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table: """ diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index b13346cfb6..d6c0d48457 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -24,7 +24,6 @@ def __init__( self.api_key = api_key self.embedding_engine = embedding_engine self.db_uri: str = connection_string - self.engine = create_async_engine(self.db_uri) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) diff --git a/cognee/modules/data/methods/__init__.py b/cognee/modules/data/methods/__init__.py index a904060f57..34f9433595 100644 --- a/cognee/modules/data/methods/__init__.py +++ b/cognee/modules/data/methods/__init__.py @@ -6,6 +6,8 @@ from .get_datasets import get_datasets from .get_datasets_by_name import get_datasets_by_name from .get_dataset_data import get_dataset_data +from .get_data import get_data # Delete from .delete_dataset import delete_dataset +from .delete_data import delete_data \ No newline at end of file diff --git a/cognee/modules/data/methods/delete_data.py b/cognee/modules/data/methods/delete_data.py new file mode 100644 index 0000000000..7560762e19 --- /dev/null +++ b/cognee/modules/data/methods/delete_data.py @@ -0,0 +1,19 @@ +from cognee.modules.data.models import Data +from cognee.infrastructure.databases.relational import get_relational_engine + + +async def delete_data(data: Data): + """Delete a data record from the database. + + Args: + data (Data): The data object to be deleted. + + Raises: + ValueError: If the data object is invalid. + """ + if not hasattr(data, '__tablename__'): + raise ValueError("The provided data object is missing the required '__tablename__' attribute.") + + db_engine = get_relational_engine() + + return await db_engine.delete_data_by_id(data.__tablename__, data.id) diff --git a/cognee/modules/data/methods/delete_dataset.py b/cognee/modules/data/methods/delete_dataset.py index b0fe96c427..c2205144d0 100644 --- a/cognee/modules/data/methods/delete_dataset.py +++ b/cognee/modules/data/methods/delete_dataset.py @@ -4,4 +4,4 @@ async def delete_dataset(dataset: Dataset): db_engine = get_relational_engine() - return await db_engine.delete_table(dataset.id) + return await db_engine.delete_data_by_id(dataset.__tablename__, dataset.id) diff --git a/cognee/modules/data/methods/get_data.py b/cognee/modules/data/methods/get_data.py new file mode 100644 index 0000000000..b07401463c --- /dev/null +++ b/cognee/modules/data/methods/get_data.py @@ -0,0 +1,20 @@ +from uuid import UUID +from typing import Optional +from cognee.infrastructure.databases.relational import get_relational_engine +from ..models import Data + +async def get_data(data_id: UUID) -> Optional[Data]: + """Retrieve data by ID. 
+ + Args: + data_id (UUID): ID of the data to retrieve + + Returns: + Optional[Data]: The requested data object if found, None otherwise + """ + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + data = await session.get(Data, data_id) + + return data \ No newline at end of file diff --git a/cognee/modules/data/methods/get_dataset.py b/cognee/modules/data/methods/get_dataset.py index f66b707a10..9f46fa223d 100644 --- a/cognee/modules/data/methods/get_dataset.py +++ b/cognee/modules/data/methods/get_dataset.py @@ -1,8 +1,9 @@ +from typing import Optional from uuid import UUID from cognee.infrastructure.databases.relational import get_relational_engine from ..models import Dataset -async def get_dataset(user_id: UUID, dataset_id: UUID) -> Dataset: +async def get_dataset(user_id: UUID, dataset_id: UUID) -> Optional[Dataset]: db_engine = get_relational_engine() async with db_engine.get_async_session() as session: diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py index feb9e3bffe..0645215395 100644 --- a/cognee/modules/data/models/Data.py +++ b/cognee/modules/data/models/Data.py @@ -20,9 +20,11 @@ class Data(Base): updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc)) datasets: Mapped[List["Dataset"]] = relationship( + "Dataset", secondary = DatasetData.__tablename__, back_populates = "data", lazy = "noload", + cascade="all, delete" ) def to_json(self) -> dict: diff --git a/cognee/modules/data/models/Dataset.py b/cognee/modules/data/models/Dataset.py index 7e35ce982f..5cf5d2351b 100644 --- a/cognee/modules/data/models/Dataset.py +++ b/cognee/modules/data/models/Dataset.py @@ -19,9 +19,11 @@ class Dataset(Base): owner_id = Column(UUID, index = True) data: Mapped[List["Data"]] = relationship( + "Data", secondary = DatasetData.__tablename__, back_populates = "datasets", lazy = "noload", + cascade="all, delete" ) def to_json(self) -> dict: diff --git a/cognee/modules/data/models/DatasetData.py b/cognee/modules/data/models/DatasetData.py index b156d8d374..ed9d3c64c8 100644 --- a/cognee/modules/data/models/DatasetData.py +++ b/cognee/modules/data/models/DatasetData.py @@ -7,5 +7,5 @@ class DatasetData(Base): created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc)) - dataset_id = Column(UUID, ForeignKey("datasets.id"), primary_key = True) - data_id = Column(UUID, ForeignKey("data.id"), primary_key = True) + dataset_id = Column(UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key = True) + data_id = Column(UUID, ForeignKey("data.id", ondelete="CASCADE"), primary_key = True)
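
For orientation, here is a minimal sketch of how the relocated routers and the new per-item delete endpoint could be exercised once the server started by `client.py` is running. The base URL, the bearer-token auth header, and the dataset/data identifiers are illustrative assumptions, not values taken from this change.

```python
# Illustrative sketch only. The endpoints mirror the routers mounted in client.py
# (GET /api/v1/datasets/ and DELETE /api/v1/datasets/{dataset_id}/data/{data_id});
# the base URL, auth header format, and IDs below are assumptions, not part of the diff.
import asyncio
import aiohttp

API_URL = "http://localhost:8000/api/v1"

async def demo(token: str, dataset_id: str, data_id: str):
    headers = {"Authorization": f"Bearer {token}"}
    async with aiohttp.ClientSession(headers=headers) as session:
        # List datasets through the relocated datasets router.
        async with session.get(f"{API_URL}/datasets/") as resp:
            print(resp.status, await resp.json())

        # Delete one data item; the ON DELETE CASCADE foreign keys added to
        # DatasetData remove the association rows along with it.
        async with session.delete(f"{API_URL}/datasets/{dataset_id}/data/{data_id}") as resp:
            print(resp.status)

# asyncio.run(demo("<token>", "<dataset-id>", "<data-id>"))
```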