Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
grpc_port: int = 6334
prefer_grpc: bool = False
replication_factor: int = 1
write_consistency_factor: int = 1

def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
Expand Down Expand Up @@ -127,6 +128,7 @@ def create_collection(self, collection_name: str, vector_size: int):
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
replication_factor=self._client_config.replication_factor,
write_consistency_factor=self._client_config.write_consistency_factor,
)

# create group_id payload index
Expand Down
20 changes: 18 additions & 2 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Optional

Expand All @@ -13,6 +15,8 @@
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Whitelist

logger = logging.getLogger(__name__)


class AbstractVectorFactory(ABC):
@abstractmethod
Expand Down Expand Up @@ -173,8 +177,20 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:

def create(self, texts: Optional[list] = None, **kwargs):
    """Embed *texts* in batches and persist them via the vector processor.

    Batching bounds peak memory when indexing large datasets: each batch is
    embedded and written before the next one is loaded.

    Args:
        texts: Documents to index; each must expose ``page_content``.
            ``None`` or an empty list is a no-op.
        **kwargs: Forwarded unchanged to ``self._vector_processor.create``.
    """
    if not texts:
        return
    start = time.time()
    logger.info("start embedding %s texts %s", len(texts), start)
    batch_size = 1000
    # Ceiling division: number of batches needed to cover all texts.
    # (The previous expression omitted "// batch_size" and logged a bogus total.)
    total_batches = (len(texts) + batch_size - 1) // batch_size
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        batch_no = i // batch_size + 1
        batch_start = time.time()
        logger.info("Processing batch %s/%s (%s texts)", batch_no, total_batches, len(batch))
        batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
        logger.info(
            "Embedding batch %s/%s took %.3fs", batch_no, total_batches, time.time() - batch_start
        )
        self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
    logger.info("Embedding %s texts took %.3fs", len(texts), time.time() - start)

def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
Expand Down