Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
optimize: batch embeddings during indexing and add a Qdrant write_consistency_factor configuration parameter
  • Loading branch information
hobo-l-20230331 committed Jul 1, 2025
commit 44aedbb38935ac0a7ed4c719faf1c2e8f7b80c44
2 changes: 2 additions & 0 deletions api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
grpc_port: int = 6334
prefer_grpc: bool = False
replication_factor: int = 1
write_consistency_factor: int = 1

def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
Expand Down Expand Up @@ -127,6 +128,7 @@ def create_collection(self, collection_name: str, vector_size: int):
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
replication_factor=self._client_config.replication_factor,
write_consistency_factor=self._client_config.write_consistency_factor,
)

# create group_id payload index
Expand Down
20 changes: 18 additions & 2 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Optional

Expand All @@ -13,6 +15,8 @@
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Whitelist

logger = logging.getLogger(__name__)


class AbstractVectorFactory(ABC):
@abstractmethod
Expand Down Expand Up @@ -173,8 +177,20 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:

def create(self, texts: Optional[list] = None, **kwargs):
    """Embed the given documents in batches and write each batch to the vector store.

    Batching keeps the embedding request sizes bounded instead of sending
    one huge payload to the embedding provider, and flushes each batch to
    the vector processor as soon as its embeddings are ready.

    :param texts: documents to index; each is assumed to expose a
        ``page_content`` attribute (presumably ``Document`` — confirm with callers).
        ``None`` or an empty list is a no-op.
    :param kwargs: passed through unchanged to ``self._vector_processor.create``.
    """
    if not texts:
        return
    start = time.time()
    batch_size = 1000
    # Ceiling division: the original code computed `len(texts) + batch_size - 1`
    # WITHOUT dividing by batch_size, so the "total batches" in the log lines
    # was wildly wrong (e.g. 3499 instead of 3 for 2500 texts).
    total_batches = (len(texts) + batch_size - 1) // batch_size
    # Lazy %-style args: avoids formatting work when INFO is disabled.
    logger.info("start embedding %s texts in %s batches", len(texts), total_batches)
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        batch_no = i // batch_size + 1
        batch_start = time.time()
        logger.info("Processing batch %s/%s (%s texts)", batch_no, total_batches, len(batch))
        batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
        logger.info("Embedding batch %s/%s took %.3fs", batch_no, total_batches, time.time() - batch_start)
        self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
    logger.info("Embedding %s texts took %.3fs", len(texts), time.time() - start)

def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
Expand Down