diff --git a/cognee/api/v1/add/add_v2.py b/cognee/api/v1/add/add_v2.py index 631d963e56..637c4a1878 100644 --- a/cognee/api/v1/add/add_v2.py +++ b/cognee/api/v1/add/add_v2.py @@ -2,7 +2,7 @@ from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user from cognee.modules.pipelines import run_tasks, Task -from cognee.tasks.ingestion import ingest_data_with_metadata +from cognee.tasks.ingestion import ingest_data_with_metadata, resolve_data_directories from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables @@ -14,6 +14,7 @@ async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_nam user = await get_default_user() tasks = [ + Task(resolve_data_directories), Task(ingest_data_with_metadata, dataset_name, user) ] diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py index 31e3fa67d8..1ba96a2323 100644 --- a/cognee/api/v1/datasets/routers/get_datasets_router.py +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -76,7 +76,7 @@ async def delete_data(dataset_id: str, data_id: str, user: User = Depends(get_au message=f"Dataset ({dataset_id}) not found." ) - data = await get_data(data_id) + data = await get_data(user.id, data_id) if data is None: raise EntityNotFoundError( @@ -141,6 +141,7 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset @router.get("/{dataset_id}/data/{data_id}/raw", response_class=FileResponse) async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)): + from cognee.modules.data.methods import get_data from cognee.modules.data.methods import get_dataset, get_dataset_data dataset = await get_dataset(user.id, dataset_id) @@ -164,7 +165,10 @@ async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_a if len(matching_data) == 0: raise EntityNotFoundError(message= f"Data ({data_id}) not found in dataset ({dataset_id}).") - data = matching_data[0] + data = await get_data(user.id, data_id) + + if data is None: + raise EntityNotFoundError(message=f"Data ({data_id}) not found in dataset ({dataset_id}).") return data.raw_data_location diff --git a/cognee/api/v1/permissions/routers/get_permissions_router.py b/cognee/api/v1/permissions/routers/get_permissions_router.py index 8d012d6002..2b30f62fd2 100644 --- a/cognee/api/v1/permissions/routers/get_permissions_router.py +++ b/cognee/api/v1/permissions/routers/get_permissions_router.py @@ -1,46 +1,63 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse from sqlalchemy.orm import Session +from sqlalchemy.future import select +from sqlalchemy import insert +from sqlalchemy.exc import IntegrityError +from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundError from cognee.modules.users import get_user_db -from cognee.modules.users.models import User, Group, Permission +from cognee.modules.users.models import User, Group, Permission, UserGroup, GroupPermission def get_permissions_router() -> APIRouter: permissions_router = APIRouter() @permissions_router.post("/groups/{group_id}/permissions") - async def give_permission_to_group(group_id: int, permission: str, db: Session = Depends(get_user_db)): - group = db.query(Group).filter(Group.id == group_id).first() + async def give_permission_to_group(group_id: str, permission: str, db: Session = Depends(get_user_db)): + group = (await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first() if not group: raise GroupNotFoundError - permission = db.query(Permission).filter(Permission.name == permission).first() + permission_entity = ( + await db.session.execute(select(Permission).where(Permission.name == permission))).scalars().first() - if not permission: - permission = Permission(name = permission) - db.add(permission) + if not permission_entity: + stmt = insert(Permission).values(name=permission) + await db.session.execute(stmt) + permission_entity = ( + await db.session.execute(select(Permission).where(Permission.name == permission))).scalars().first() - group.permissions.append(permission) + try: + # add permission to group + await db.session.execute( + insert(GroupPermission).values(group_id=group.id, permission_id=permission_entity.id)) + except IntegrityError as e: + raise EntityAlreadyExistsError(message="Group permission already exists.") - db.commit() + await db.session.commit() return JSONResponse(status_code = 200, content = {"message": "Permission assigned to group"}) @permissions_router.post("/users/{user_id}/groups") - async def add_user_to_group(user_id: int, group_id: int, db: Session = Depends(get_user_db)): - user = db.query(User).filter(User.id == user_id).first() - group = db.query(Group).filter(Group.id == group_id).first() + async def add_user_to_group(user_id: str, group_id: str, db: Session = Depends(get_user_db)): + user = (await db.session.execute(select(User).where(User.id == user_id))).scalars().first() + group = (await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first() if not user: raise UserNotFoundError elif not group: raise GroupNotFoundError - user.groups.append(group) + try: + # Add association directly to the association table + stmt = insert(UserGroup).values(user_id=user_id, group_id=group_id) + await db.session.execute(stmt) + except IntegrityError as e: + raise EntityAlreadyExistsError(message="User is already part of group.") - db.commit() + await db.session.commit() return JSONResponse(status_code = 200, content = {"message": "User added to group"}) diff --git a/cognee/api/v1/search/search_v2.py b/cognee/api/v1/search/search_v2.py index d77aa5fa81..6a5da4648b 100644 --- a/cognee/api/v1/search/search_v2.py +++ b/cognee/api/v1/search/search_v2.py @@ -14,11 +14,13 @@ from cognee.tasks.chunks import query_chunks from cognee.tasks.graph import query_graph_connections from cognee.tasks.summarization import query_summaries +from cognee.tasks.completion import query_completion class SearchType(Enum): SUMMARIES = "SUMMARIES" INSIGHTS = "INSIGHTS" CHUNKS = "CHUNKS" + COMPLETION = "COMPLETION" async def search(query_type: SearchType, query_text: str, user: User = None) -> list: if user is None: @@ -50,6 +52,7 @@ async def specific_search(query_type: SearchType, query: str, user) -> list: SearchType.SUMMARIES: query_summaries, SearchType.INSIGHTS: query_graph_connections, SearchType.CHUNKS: query_chunks, + SearchType.COMPLETION: query_completion, } search_task = search_tasks.get(query_type) diff --git a/cognee/infrastructure/llm/prompts/answer_simple_question.txt b/cognee/infrastructure/llm/prompts/answer_simple_question.txt new file mode 100644 index 0000000000..351e1e5e99 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/answer_simple_question.txt @@ -0,0 +1 @@ +Answer the question using the provided context. Be as brief as possible. \ No newline at end of file diff --git a/cognee/modules/data/exceptions/__init__.py b/cognee/modules/data/exceptions/__init__.py index fa8468c880..6f74c627e5 100644 --- a/cognee/modules/data/exceptions/__init__.py +++ b/cognee/modules/data/exceptions/__init__.py @@ -6,4 +6,5 @@ from .exceptions import ( UnstructuredLibraryImportError, + UnauthorizedDataAccessError, ) \ No newline at end of file diff --git a/cognee/modules/data/exceptions/exceptions.py b/cognee/modules/data/exceptions/exceptions.py index 3b1aac52c8..5117f3caca 100644 --- a/cognee/modules/data/exceptions/exceptions.py +++ b/cognee/modules/data/exceptions/exceptions.py @@ -7,5 +7,14 @@ def __init__( message: str = "Import error. Unstructured library is not installed.", name: str = "UnstructuredModuleImportError", status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) + +class UnauthorizedDataAccessError(CogneeApiError): + def __init__( + self, + message: str = "User does not have permission to access this data.", + name: str = "UnauthorizedDataAccessError", + status_code=status.HTTP_401_UNAUTHORIZED, ): super().__init__(message, name, status_code) \ No newline at end of file diff --git a/cognee/modules/data/methods/get_data.py b/cognee/modules/data/methods/get_data.py index b07401463c..d7daff29bc 100644 --- a/cognee/modules/data/methods/get_data.py +++ b/cognee/modules/data/methods/get_data.py @@ -1,12 +1,14 @@ from uuid import UUID from typing import Optional from cognee.infrastructure.databases.relational import get_relational_engine +from ..exceptions import UnauthorizedDataAccessError from ..models import Data -async def get_data(data_id: UUID) -> Optional[Data]: +async def get_data(user_id: UUID, data_id: UUID) -> Optional[Data]: """Retrieve data by ID. Args: + user_id (UUID): user ID data_id (UUID): ID of the data to retrieve Returns: @@ -17,4 +19,7 @@ async def get_data(data_id: UUID) -> Optional[Data]: async with db_engine.get_async_session() as session: data = await session.get(Data, data_id) + if data and data.owner_id != user_id: + raise UnauthorizedDataAccessError(message=f"User {user_id} is not authorized to access data {data_id}") + return data \ No newline at end of file diff --git a/cognee/modules/users/models/GroupPermission.py b/cognee/modules/users/models/GroupPermission.py new file mode 100644 index 0000000000..eaf3630b48 --- /dev/null +++ b/cognee/modules/users/models/GroupPermission.py @@ -0,0 +1,11 @@ +from datetime import datetime, timezone +from sqlalchemy import Column, ForeignKey, DateTime, UUID +from cognee.infrastructure.databases.relational import Base + +class GroupPermission(Base): + __tablename__ = "group_permissions" + + created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc)) + + group_id = Column(UUID, ForeignKey("groups.id"), primary_key = True) + permission_id = Column(UUID, ForeignKey("permissions.id"), primary_key = True) diff --git a/cognee/modules/users/models/__init__.py b/cognee/modules/users/models/__init__.py index 7dc1bf8ca4..a713798d57 100644 --- a/cognee/modules/users/models/__init__.py +++ b/cognee/modules/users/models/__init__.py @@ -1,5 +1,7 @@ from .User import User from .Group import Group +from .UserGroup import UserGroup +from .GroupPermission import GroupPermission from .Resource import Resource from .Permission import Permission from .ACL import ACL diff --git a/cognee/tasks/completion/__init__.py b/cognee/tasks/completion/__init__.py new file mode 100644 index 0000000000..1bf0fa6bbb --- /dev/null +++ b/cognee/tasks/completion/__init__.py @@ -0,0 +1 @@ +from .query_completion import query_completion \ No newline at end of file diff --git a/cognee/tasks/completion/exceptions/__init__.py b/cognee/tasks/completion/exceptions/__init__.py new file mode 100644 index 0000000000..5f80e6eccd --- /dev/null +++ b/cognee/tasks/completion/exceptions/__init__.py @@ -0,0 +1,9 @@ +""" +Custom exceptions for the Cognee API. + +This module defines a set of exceptions for handling various compute errors +""" + +from .exceptions import ( + NoRelevantDataFound, +) \ No newline at end of file diff --git a/cognee/tasks/completion/exceptions/exceptions.py b/cognee/tasks/completion/exceptions/exceptions.py new file mode 100644 index 0000000000..9b64c01a6d --- /dev/null +++ b/cognee/tasks/completion/exceptions/exceptions.py @@ -0,0 +1,11 @@ +from cognee.exceptions import CogneeApiError +from fastapi import status + +class NoRelevantDataFound(CogneeApiError): + def __init__( + self, + message: str = "Search did not find any data.", + name: str = "NoRelevantDataFound", + status_code=status.HTTP_404_NOT_FOUND, + ): + super().__init__(message, name, status_code) \ No newline at end of file diff --git a/cognee/tasks/completion/query_completion.py b/cognee/tasks/completion/query_completion.py new file mode 100644 index 0000000000..5324676f86 --- /dev/null +++ b/cognee/tasks/completion/query_completion.py @@ -0,0 +1,36 @@ +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.tasks.completion.exceptions import NoRelevantDataFound +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt + + +async def query_completion(query: str) -> list: + """ + Parameters: + - query (str): The query string to compute. + + Returns: + - list: Answer to the query. + """ + vector_engine = get_vector_engine() + + found_chunks = await vector_engine.search("document_chunk_text", query, limit = 1) + + if len(found_chunks) == 0: + raise NoRelevantDataFound + + args = { + "question": query, + "context": found_chunks[0].payload["text"], + } + user_prompt = render_prompt("context_for_question.txt", args) + system_prompt = read_query_prompt("answer_simple_question.txt") + + llm_client = get_llm_client() + computed_answer = await llm_client.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) + + return [computed_answer] diff --git a/cognee/tasks/ingestion/__init__.py b/cognee/tasks/ingestion/__init__.py index f569267a17..8b873b2736 100644 --- a/cognee/tasks/ingestion/__init__.py +++ b/cognee/tasks/ingestion/__init__.py @@ -3,3 +3,4 @@ from .save_data_item_to_storage import save_data_item_to_storage from .save_data_item_with_metadata_to_storage import save_data_item_with_metadata_to_storage from .ingest_data_with_metadata import ingest_data_with_metadata +from .resolve_data_directories import resolve_data_directories diff --git a/cognee/tasks/ingestion/resolve_data_directories.py b/cognee/tasks/ingestion/resolve_data_directories.py new file mode 100644 index 0000000000..9807568056 --- /dev/null +++ b/cognee/tasks/ingestion/resolve_data_directories.py @@ -0,0 +1,37 @@ +import os +from typing import List, Union, BinaryIO + +async def resolve_data_directories(data: Union[BinaryIO, List[BinaryIO], str, List[str]], include_subdirectories: bool = True): + """ + Resolves directories by replacing them with their contained files. + + Args: + data: A single file, directory, or binary stream, or a list of such items. + include_subdirectories: Whether to include files in subdirectories recursively. + + Returns: + A list of resolved files and binary streams. + """ + # Ensure `data` is a list + if not isinstance(data, list): + data = [data] + + resolved_data = [] + + for item in data: + if isinstance(item, str): # Check if the item is a path + if os.path.isdir(item): # If it's a directory + if include_subdirectories: + # Recursively add all files in the directory and subdirectories + for root, _, files in os.walk(item): + resolved_data.extend([os.path.join(root, f) for f in files]) + else: + # Add all files (not subdirectories) in the directory + resolved_data.extend( + [os.path.join(item, f) for f in os.listdir(item) if os.path.isfile(os.path.join(item, f))] + ) + else: # If it's a file or text add it directly + resolved_data.append(item) + else: # If it's not a string add it directly + resolved_data.append(item) + return resolved_data