Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
0160469
grammars: x{min,max} repetition operator + tweak +/*/? to avoid dupli…
Apr 12, 2024
f2030e3
grammars: handle `x{n}` and fix `x{n,n}`
Apr 12, 2024
de0fd3f
grammars: document new repetition operators
Apr 12, 2024
9d9b5a3
grammars: nit
Apr 12, 2024
6b5518c
grammars: uniform use of int for min & max
Apr 12, 2024
0ceb69a
grammars: refactor parser test
Apr 12, 2024
8938a05
grammar: parsing tests w/ natural pretty print of updated expectations
Apr 12, 2024
0d7347f
grammars: much prettier print of expectations (+ TEST_GRAMMAR_PARSER_…
Apr 12, 2024
2e2df72
grammars: improve test pretty print again
Apr 12, 2024
ffe321d
grammars: pretty print rules and chars
Apr 12, 2024
a9351b8
grammars: fix copy rule skipping
Apr 12, 2024
9d8efa5
grammars: disallow `a{,}` (not allowed in regexps)
Apr 12, 2024
2d98ebf
Update common/grammar-parser.cpp
ochafik Apr 12, 2024
ec91342
grammars: fix copy rule skipping (again) & display of expectations
Apr 12, 2024
22faba6
grammars: more test cases
Apr 12, 2024
1fb7787
Merge remote-tracking branch 'origin/master' into grammar-reps
Apr 15, 2024
15585e0
grammars: update reps parsing to bring ? / * / + closer to before
Apr 19, 2024
93b754e
json: use new GBNF repetitions{m,n} syntax
Apr 19, 2024
2ecc2ae
grammars: update performance gotchas w/ repetition advice
Apr 20, 2024
a9a2983
Merge remote-tracking branch 'origin/master' into grammar-reps
Apr 21, 2024
d47f537
Update examples/json_schema_to_grammar.py
ochafik Apr 24, 2024
724f879
Update examples/server/public/json-schema-to-grammar.mjs
ochafik Apr 24, 2024
a61281f
grammars: comment on rule repetitions
Apr 24, 2024
d03c98e
grammars: ensure unambiguous number alternatives
Apr 24, 2024
21bac1e
grammar: nit typo switched error msgs
Apr 24, 2024
0c74ad3
grammar: nit numbering in comment
Apr 24, 2024
218f41f
json: update numeric rule to be unambiguous
Apr 24, 2024
2813835
Apply suggestions from code review
ochafik Apr 24, 2024
46fe648
Update examples/server/public/json-schema-to-grammar.mjs
ochafik Apr 24, 2024
eb7ccd8
json: fix integral-part
Apr 24, 2024
3c02508
Merge branch 'grammar-reps' of https://github.com/ochafik/llama.cpp i…
Apr 24, 2024
476c97d
Merge remote-tracking branch 'origin/master' into grammar-reps
Apr 30, 2024
990bf57
grammar: add repetition tests
Apr 30, 2024
d070aee
Merge remote-tracking branch 'origin/master' into grammar-reps
May 18, 2024
8266b7c
Merge remote-tracking branch 'origin/master' into grammar-reps
May 21, 2024
2b79d47
Merge remote-tracking branch 'origin/master' into grammar-reps
Jun 4, 2024
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
grammars: refactor parser test
  • Loading branch information
Olivier Chafik committed Apr 12, 2024
commit 0ceb69afbc04a5c522751bdc1c960b09f1a57ccf
157 changes: 54 additions & 103 deletions tests/test-grammar-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,10 @@

#include <cassert>

int main()
{
grammar_parser::parse_state parsed_grammar;

const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+)""";

parsed_grammar = grammar_parser::parse(grammar_bytes);

std::vector<std::pair<std::string, uint32_t>> expected = {
{"expr", 2},
{"expr_5", 5},
{"expr_6", 6},
{"root", 0},
{"root_1", 1},
{"root_4", 4},
{"term", 3},
{"term_7", 7},
};

static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
uint32_t index = 0;
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
{
std::string key = it->first;
Expand All @@ -47,7 +29,47 @@ term ::= [0-9]+)""";

index++;
}
std::vector<llama_grammar_element> expected_rules = {

index = 0;
for (auto rule : parsed_grammar.rules)
{
// compare rule to expected rule
for (uint32_t i = 0; i < rule.size(); i++)
{
llama_grammar_element element = rule[i];
llama_grammar_element expected_element = expected_rules[index];

// pretty print error message before asserting
if (expected_element.type != element.type || expected_element.value != element.value)
{
fprintf(stderr, "index: %u\n", index);
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value);
fprintf(stderr, "expected_element != actual_element\n");
}

assert(expected_element.type == element.type && expected_element.value == element.value);
index++;
}
}
}

int main()
{
verify_parsing(R"""(
root ::= (expr "=" term "\n")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+
)""", {
{"expr", 2},
{"expr_5", 5},
{"expr_6", 6},
{"root", 0},
{"root_1", 1},
{"root_4", 4},
{"term", 3},
{"term_7", 7},
}, {
{LLAMA_GRETYPE_RULE_REF, 4},
{LLAMA_GRETYPE_END, 0},
{LLAMA_GRETYPE_RULE_REF, 2},
Expand Down Expand Up @@ -82,43 +104,16 @@ term ::= [0-9]+)""";
{LLAMA_GRETYPE_CHAR, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_END, 0},
};

index = 0;
for (auto rule : parsed_grammar.rules)
{
// compare rule to expected rule
for (uint32_t i = 0; i < rule.size(); i++)
{
llama_grammar_element element = rule[i];
llama_grammar_element expected_element = expected_rules[index];

// pretty print error message before asserting
if (expected_element.type != element.type || expected_element.value != element.value)
{
fprintf(stderr, "index: %u\n", index);
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value);
fprintf(stderr, "expected_element != actual_element\n");
}

assert(expected_element.type == element.type && expected_element.value == element.value);
index++;
}
}

const char *longer_grammar_bytes = R"""(
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
)""";

parsed_grammar = grammar_parser::parse(longer_grammar_bytes);
});

expected = {
verify_parsing(R"""(
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
)""", {
Comment thread
HanClinto marked this conversation as resolved.
{"expr", 2},
{"expr_6", 6},
{"expr_7", 7},
Expand All @@ -132,28 +127,7 @@ term ::= [0-9]+)""";
{"term", 4},
{"ws", 3},
{"ws_12", 12},
};

index = 0;
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
{
std::string key = it->first;
uint32_t value = it->second;
std::pair<std::string, uint32_t> expected_pair = expected[index];

// pretty print error message before asserting
if (expected_pair.first != key || expected_pair.second != value)
{
fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second);
fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value);
fprintf(stderr, "expected_pair != actual_pair\n");
}

assert(expected_pair.first == key && expected_pair.second == value);

index++;
}
expected_rules = {
}, {
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_END, 0},
{LLAMA_GRETYPE_RULE_REF, 2},
Expand Down Expand Up @@ -221,30 +195,7 @@ term ::= [0-9]+)""";
{LLAMA_GRETYPE_RULE_REF, 12},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0},
};

index = 0;
for (auto rule : parsed_grammar.rules)
{
// compare rule to expected rule
for (uint32_t i = 0; i < rule.size(); i++)
{
llama_grammar_element element = rule[i];
llama_grammar_element expected_element = expected_rules[index];

// pretty print error message before asserting
if (expected_element.type != element.type || expected_element.value != element.value)
{
fprintf(stderr, "index: %u\n", index);
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value);
fprintf(stderr, "expected_element != actual_element\n");
}

assert(expected_element.type == element.type && expected_element.value == element.value);
index++;
}
}
});

return 0;
}