Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: get_source_code_chunks code graph pipeline task
  • Loading branch information
lxobr committed Dec 18, 2024
commit aea7382983b5250eb443c46adc6a387d2690e22e
112 changes: 112 additions & 0 deletions cognee/tasks/repo_processor/get_source_code_chunks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import AsyncGenerator, Generator
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodePart, SourceCodeChunk, CodeFile
import tiktoken
import parso

from cognee.tasks.repo_processor import logger


def _count_tokens(tokenizer: tiktoken.Encoding, source_code: str) -> int:
return len(tokenizer.encode(source_code))


def _get_subchunk_token_counts(
tokenizer: tiktoken.Encoding, source_code: str, max_subchunk_tokens: int = 8000
) -> list[tuple[str, int]]:
"""Splits source code into subchunk and counts tokens for each subchunk."""

try:
module = parso.parse(source_code)
except Exception as e:
logger.error(f"Error parsing source code: {e}")
return []

if not module.children:
logger.warning("Parsed module has no children (empty or invalid source code).")
return []

if len(module.children) <= 2:
module = module.children[0]

subchunk_token_counts = []
for child in module.children:
subchunk = child.get_code()
token_count = _count_tokens(tokenizer, subchunk)
if token_count <= max_subchunk_tokens:
subchunk_token_counts.append((subchunk, token_count))
continue

subchunk_token_counts.extend(_get_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens))

return subchunk_token_counts


def _get_chunk_source_code(
code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
) -> tuple[list[tuple[str, int]], str]:
"""Generates a chunk of source code from tokenized subchunks with overlap handling."""
current_count = 0
cumulative_counts = []
current_source_code = ''

for i, (child_code, token_count) in enumerate(code_token_counts):
current_count += token_count
cumulative_counts.append(current_count)
if current_count > max_tokens:
break
current_source_code += f"\n{child_code}"

if current_count<= max_tokens:
return [], current_source_code.strip()

cutoff = 1
for i, cum_count in enumerate(cumulative_counts):
if cum_count> (1 - overlap) * max_tokens:
break
cutoff = i

return code_token_counts[cutoff:], current_source_code.strip()


def get_source_code_chunks_from_code_part(
code_file_part: CodePart,
max_tokens: int = 8192,
overlap: float = 0.25,
granularity: float = 0.1,
model_name: str = "text-embedding-3-large"
) -> Generator[SourceCodeChunk, None, None]:
"""Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)

previous_chunk = None
while subchunk_token_counts:
subchunk_token_counts, chunk_source_code = _get_chunk_source_code(subchunk_token_counts, overlap, max_tokens)
if not chunk_source_code:
continue
current_chunk = SourceCodeChunk(
id=uuid5(NAMESPACE_OID, chunk_source_code),
code_chunk_of=code_file_part,
source_code=chunk_source_code,
previous_chunk=previous_chunk
)
yield current_chunk
previous_chunk = current_chunk


async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
for data_point in data_points:
yield data_point
if not isinstance(data_point, CodeFile):
continue
if not data_point.contains:
continue
for code_part in data_point.contains:
yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
yield source_code_chunk