Skip to content

Commit 584a1cc

Browse files
committed
feat: add context-awareness for Java projects
1 parent 91dd7bd commit 584a1cc

File tree

22 files changed

+757
-13
lines changed

22 files changed

+757
-13
lines changed

codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParser.java renamed to codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/completion/CodeCompletionParser.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package ee.carlrobert.codegpt.treesitter;
1+
package ee.carlrobert.codegpt.treesitter.completion;
22

33
import org.treesitter.TSLanguage;
44
import org.treesitter.TSNode;

codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserFactory.java renamed to codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/completion/CodeCompletionParserFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package ee.carlrobert.codegpt.treesitter;
1+
package ee.carlrobert.codegpt.treesitter.completion;
22

33
import org.treesitter.TSLanguage;
44
import org.treesitter.TreeSitterCSharp;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package ee.carlrobert.codegpt.treesitter.repository;
2+
3+
public class ProcessedTag {
4+
5+
private final String filePath;
6+
private final String modifiedContent;
7+
8+
public ProcessedTag(String filePath, String modifiedContent) {
9+
this.filePath = filePath;
10+
this.modifiedContent = modifiedContent;
11+
}
12+
13+
public String getFilePath() {
14+
return filePath;
15+
}
16+
17+
public String getModifiedContent() {
18+
return modifiedContent;
19+
}
20+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package ee.carlrobert.codegpt.treesitter.repository;
2+
3+
import static java.util.stream.Collectors.toSet;
4+
5+
import ee.carlrobert.codegpt.treesitter.repository.parser.RepositoryParser.FileMapping;
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
import java.util.Set;
9+
import java.util.function.Function;
10+
import org.treesitter.TSLanguage;
11+
import org.treesitter.TSNode;
12+
import org.treesitter.TSParser;
13+
import org.treesitter.TSQuery;
14+
import org.treesitter.TSQueryCursor;
15+
import org.treesitter.TSQueryMatch;
16+
import org.treesitter.TSTree;
17+
18+
public class QueryUtil {
19+
20+
public static Set<Tag> extractTagsFromFile(
21+
TSLanguage language,
22+
FileMapping fileMapping,
23+
List<String> queries,
24+
TagType tagType) {
25+
var fileContent = fileMapping.getContent();
26+
var rootNode = QueryUtil.getTree(language, fileContent).getRootNode();
27+
return queries.stream()
28+
.map(query -> new TSQuery(language, query))
29+
.flatMap(query ->
30+
query(query, rootNode, node -> new Tag(
31+
fileMapping.getPath(),
32+
fileMapping.getContent().substring(node.getStartByte(), node.getEndByte()),
33+
language.symbolName(node.getSymbol()),
34+
tagType)).stream())
35+
.collect(toSet());
36+
}
37+
38+
public static List<QueryResult> query(TSLanguage language, String code, String query) {
39+
return query(
40+
new TSQuery(language, query),
41+
getTree(language, code).getRootNode(),
42+
node -> new QueryResult(
43+
code.substring(node.getStartByte(), node.getEndByte()),
44+
language.symbolName(node.getSymbol())));
45+
}
46+
47+
private static TSTree getTree(TSLanguage language, String input) {
48+
var parser = new TSParser();
49+
parser.setLanguage(language);
50+
return parser.parseString(null, input);
51+
}
52+
53+
private static <T> List<T> query(TSQuery query, TSNode rootNode, Function<TSNode, T> onCapture) {
54+
var cursor = new TSQueryCursor();
55+
cursor.exec(query, rootNode);
56+
var matches = new ArrayList<T>();
57+
var match = new TSQueryMatch();
58+
while (cursor.nextMatch(match)) {
59+
for (var capture : match.getCaptures()) {
60+
try {
61+
matches.add(onCapture.apply(capture.getNode()));
62+
} catch (Throwable t) {
63+
// todo: log
64+
}
65+
}
66+
}
67+
return matches;
68+
}
69+
70+
public static class QueryResult {
71+
72+
private final String name;
73+
private final String symbolName;
74+
75+
public QueryResult(String name, String symbolName) {
76+
this.name = name;
77+
this.symbolName = symbolName;
78+
}
79+
80+
public String getName() {
81+
return name;
82+
}
83+
84+
public String getSymbolName() {
85+
return symbolName;
86+
}
87+
}
88+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package ee.carlrobert.codegpt.treesitter.repository;
2+
3+
import com.fasterxml.jackson.annotation.JsonCreator;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import java.nio.file.Path;
6+
import java.util.Objects;
7+
8+
public class Tag {
9+
10+
private final String filePath;
11+
private final String name;
12+
private final String symbolName;
13+
private final TagType type;
14+
15+
@JsonCreator
16+
public Tag(
17+
@JsonProperty("filePath") String filePath,
18+
@JsonProperty("name") String name,
19+
@JsonProperty("symbolName") String symbolName,
20+
@JsonProperty("type") TagType type) {
21+
this.filePath = filePath;
22+
this.name = name;
23+
this.symbolName = symbolName;
24+
this.type = type;
25+
}
26+
27+
public String getFilePath() {
28+
return filePath;
29+
}
30+
31+
public String getName() {
32+
return name;
33+
}
34+
35+
public String getSymbolName() {
36+
return symbolName;
37+
}
38+
39+
public TagType getType() {
40+
return type;
41+
}
42+
43+
@Override
44+
public String toString() {
45+
return "Tag{" +
46+
"fileName='" + Path.of(filePath).getFileName().toString() + '"' +
47+
"name='" + name + '"' +
48+
", symbolName='" + symbolName + '"' +
49+
'}' + "\n";
50+
}
51+
52+
53+
@Override
54+
public boolean equals(Object o) {
55+
if (this == o) {
56+
return true;
57+
}
58+
if (!(o instanceof Tag)) {
59+
return false;
60+
}
61+
Tag tag = (Tag) o;
62+
return Objects.equals(filePath, tag.filePath)
63+
&& Objects.equals(name, tag.name)
64+
&& Objects.equals(symbolName, tag.symbolName);
65+
}
66+
67+
@Override
68+
public int hashCode() {
69+
return Objects.hash(filePath, name, symbolName);
70+
}
71+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package ee.carlrobert.codegpt.treesitter.repository;
2+
3+
public enum TagType {
4+
DEFINITION, REFERENCE
5+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package ee.carlrobert.codegpt.treesitter.repository.parser;
2+
3+
import static java.util.stream.Collectors.toList;
4+
import static java.util.stream.Collectors.toSet;
5+
6+
import com.intellij.openapi.project.Project;
7+
import ee.carlrobert.codegpt.treesitter.repository.ProcessedTag;
8+
import ee.carlrobert.codegpt.treesitter.repository.Tag;
9+
import ee.carlrobert.codegpt.treesitter.repository.TagType;
10+
import java.io.IOException;
11+
import java.nio.file.Files;
12+
import java.nio.file.Path;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.Map.Entry;
16+
import java.util.Set;
17+
import java.util.regex.Pattern;
18+
import javax.annotation.Nullable;
19+
import org.treesitter.TreeSitterJava;
20+
21+
public class JavaRepositoryParser extends RepositoryParser {
22+
23+
JavaRepositoryParser(Project project) {
24+
super(project, new TreeSitterJava(), "java");
25+
}
26+
27+
@Override
28+
protected List<String> getDefinitionQueries() {
29+
return List.of(
30+
"(class_declaration name: (identifier) @name.definition.class) @definition.class",
31+
"(method_declaration name: (identifier) @name.definition.method) @definition.method",
32+
"(interface_declaration name: (identifier) @name.definition.interface) @definition.interface",
33+
"(package_declaration) @definition.package");
34+
}
35+
36+
@Override
37+
protected List<String> getReferenceQueries() {
38+
return List.of(
39+
"(method_invocation name: (identifier) @name.reference.call arguments: (argument_list) @reference.call)",
40+
"(type_list (type_identifier) @name.reference.implementation) @reference.implementation",
41+
"(object_creation_expression type: (type_identifier) @name.reference.class) @reference.class",
42+
"(superclass (type_identifier) @name.reference.class) @reference.class");
43+
}
44+
45+
@Override
46+
protected Set<ProcessedTag> processMatchedTags(
47+
Map<String, List<Tag>> matchedTags,
48+
Set<Tag> definitionTags,
49+
String codeSnippetPath) {
50+
var dependencies = findDependencies(codeSnippetPath, definitionTags);
51+
return matchedTags.entrySet().stream()
52+
.filter(entry -> {
53+
var className = Path.of(entry.getKey()).getFileName().toString().replace(".java", "");
54+
return dependencies.stream().anyMatch(dep -> dep.endsWith(className));
55+
})
56+
.map(this::processMatchEntry)
57+
.limit(10)
58+
.collect(toSet());
59+
}
60+
61+
private @Nullable String getPackageDeclaration(String code) {
62+
var result = query(code, "(package_declaration) @definition.package");
63+
if (result.isEmpty()) {
64+
return null;
65+
}
66+
return result.get(0).getName();
67+
}
68+
69+
private Set<String> getImports(String code) {
70+
return query(code, "(import_declaration) @definition.import").stream()
71+
.map(it -> extractPackageName(it.getName()))
72+
.collect(toSet());
73+
}
74+
75+
private Set<String> findDependencies(String codeSnippetPath, Set<Tag> definitionTags) {
76+
var fileContent = readFileContent(Path.of(codeSnippetPath));
77+
var packageDeclaration = getPackageDeclaration(fileContent);
78+
var packageLevelDependencies = definitionTags.stream()
79+
.filter(
80+
defTag -> "package_declaration".equals(defTag.getSymbolName())
81+
&& packageDeclaration != null
82+
&& defTag.getName().contentEquals(packageDeclaration))
83+
.map(declaration -> {
84+
var fileName = Path.of(declaration.getFilePath()).getFileName().toString();
85+
return String.format(
86+
"%s.%s",
87+
extractPackageName(declaration.getName()),
88+
fileName.replace(".java", ""));
89+
})
90+
.collect(toSet());
91+
packageLevelDependencies.addAll(getImports(fileContent));
92+
return packageLevelDependencies;
93+
}
94+
95+
private ProcessedTag processMatchEntry(Entry<String, List<Tag>> entry) {
96+
var fileContent = readFileContent(Path.of(entry.getKey()));
97+
var tags = getTags(
98+
List.of(new FileMapping(entry.getKey(), fileContent)),
99+
getDefinitionQueries(),
100+
TagType.DEFINITION);
101+
var methodDeclarations = tags
102+
.stream()
103+
.filter(
104+
tag -> "method_declaration".equals(tag.getSymbolName())
105+
&& entry.getValue().stream().anyMatch(target ->
106+
tag.getName().contains(target.getName())))
107+
.map(Tag::getName)
108+
.collect(toList());
109+
return new ProcessedTag(entry.getKey(), String.join("\n\n", methodDeclarations));
110+
}
111+
112+
private static String extractPackageName(String line) {
113+
var regex = "(?:import|package)\\s+([\\w.]+)\\.*";
114+
var pattern = Pattern.compile(regex);
115+
var matcher = pattern.matcher(line);
116+
if (matcher.find()) {
117+
return matcher.group(1);
118+
}
119+
return null;
120+
}
121+
122+
private static String readFileContent(Path path) {
123+
try {
124+
return Files.readString(path);
125+
} catch (IOException e) {
126+
return "";
127+
}
128+
}
129+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package ee.carlrobert.codegpt.treesitter.repository.parser;
2+
3+
import com.intellij.openapi.project.Project;
4+
import ee.carlrobert.codegpt.treesitter.repository.ProcessedTag;
5+
import ee.carlrobert.codegpt.treesitter.repository.Tag;
6+
import java.util.List;
7+
import java.util.Map;
8+
import java.util.Set;
9+
import org.treesitter.TreeSitterPhp;
10+
11+
public class PhpRepositoryParser extends RepositoryParser {
12+
13+
PhpRepositoryParser(Project project) {
14+
super(project, new TreeSitterPhp(), "php");
15+
}
16+
17+
@Override
18+
protected List<String> getDefinitionQueries() {
19+
return List.of(
20+
"(namespace_definition\n"
21+
+ " name: (namespace_name) @name) @module",
22+
"(interface_declaration\n"
23+
+ " name: (name) @name) @definition.interface",
24+
"(trait_declaration\n"
25+
+ " name: (name) @name) @definition.interface",
26+
"(class_declaration\n"
27+
+ " name: (name) @name) @definition.class",
28+
"(class_interface_clause [(name) (qualified_name)] @name) @impl",
29+
"(property_declaration\n"
30+
+ " (property_element (variable_name (name) @name))) @definition.field",
31+
"(function_definition\n"
32+
+ " name: (name) @name) @definition.function",
33+
"(method_declaration\n"
34+
+ " name: (name) @name) @definition.function");
35+
}
36+
37+
@Override
38+
protected List<String> getReferenceQueries() {
39+
return List.of(
40+
"(object_creation_expression\n"
41+
+ " [\n"
42+
+ " (qualified_name (name) @name)\n"
43+
+ " (variable_name (name) @name)\n"
44+
+ " ]) @reference.class",
45+
"(function_call_expression\n"
46+
+ " function: [\n"
47+
+ " (qualified_name (name) @name)\n"
48+
+ " (variable_name (name)) @name\n"
49+
+ " ]) @reference.call",
50+
"(scoped_call_expression\n"
51+
+ " name: (name) @name) @reference.call",
52+
"(member_call_expression\n"
53+
+ " name: (name) @name) @reference.call");
54+
}
55+
56+
@Override
57+
protected Set<ProcessedTag> processMatchedTags(
58+
Map<String, List<Tag>> matchedTags,
59+
Set<Tag> definitionTags,
60+
String codeSnippetPath) {
61+
return Set.of();
62+
}
63+
}

0 commit comments

Comments
 (0)