Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ async def batch_search(
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
if len(data_point_ids) == 1:
results = await collection.delete(f"id = '{data_point_ids[0]}'")
else:
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
Comment on lines +167 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix potential SQL injection vulnerability and add input validation.

The current implementation has several issues that need to be addressed:

  1. Using string interpolation for SQL queries is unsafe and could lead to SQL injection. While LanceDB might have internal protections, it's better to use parameterized queries when available.
  2. There's no validation for empty data_point_ids list.
  3. The tuple conversion for multiple IDs might fail if data_point_ids contains only one item.

Consider applying these improvements:

 async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
     connection = await self.get_connection()
     collection = await connection.open_table(collection_name)
+    if not data_point_ids:
+        return None
     if len(data_point_ids) == 1:
-        results = await collection.delete(f"id = '{data_point_ids[0]}'")
+        results = await collection.delete("id = ?", [data_point_ids[0]])
     else:
-        results = await collection.delete(f"id IN {tuple(data_point_ids)}")
+        placeholders = ','.join(['?' for _ in data_point_ids])
+        results = await collection.delete(f"id IN ({placeholders})", data_point_ids)
     return results

Note: If LanceDB doesn't support parameterized queries, please verify their documentation for the recommended way to safely handle user input in queries.

Committable suggestion was skipped due to low confidence.

return results

async def prune(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
for chunk_index, chunk in enumerate(data_chunks):
chunk_classification = chunk_classifications[chunk_index]
classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))
classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))

for classification_subclass in chunk_classification.label.subclass:
classification_data_points.append(uuid5(NAMESPACE_OID, classification_subclass.value))
Expand All @@ -39,7 +38,7 @@ class Keyword(BaseModel):
if await vector_engine.has_collection(collection_name):
existing_data_points = await vector_engine.retrieve(
collection_name,
list(set(classification_data_points)),
[str(classification_data) for classification_data in list(set(classification_data_points))],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider centralizing UUID-to-string conversion logic.

To improve maintainability and ensure consistent UUID handling, consider extracting the conversion logic into a utility function:

def ensure_string_id(id_value: str | uuid.UUID) -> str:
    """Ensure ID is in string format for database operations."""
    return str(id_value) if isinstance(id_value, uuid.UUID) else id_value

This could be used throughout the code:

-[str(classification_data) for classification_data in list(set(classification_data_points))]
+[ensure_string_id(classification_data) for classification_data in list(set(classification_data_points))]

) if len(classification_data_points) > 0 else []

existing_points_map = {point.id: True for point in existing_data_points}
Expand Down
Loading