diff --git a/parser/yaml/yaml.go b/parser/yaml/yaml.go index 671eacd18..31617a61a 100644 --- a/parser/yaml/yaml.go +++ b/parser/yaml/yaml.go @@ -3,6 +3,7 @@ package yaml import ( "bytes" "fmt" + "slices" "sigs.k8s.io/yaml" ) @@ -10,6 +11,12 @@ import ( // Parser is a YAML parser. type Parser struct{} +var ( + lf = []byte{'\n'} + crlf = []byte{'\r', '\n'} + sep = []byte{'-', '-', '-'} +) + // Unmarshal unmarshals YAML files. func (yp *Parser) Unmarshal(p []byte, v interface{}) error { subDocuments := separateSubDocuments(p) @@ -29,12 +36,36 @@ func (yp *Parser) Unmarshal(p []byte, v interface{}) error { } func separateSubDocuments(data []byte) [][]byte { - linebreak := "\n" - if bytes.Contains(data, []byte("\r\n---\r\n")) { - linebreak = "\r\n" + // Determine line ending style + linebreak := lf + if bytes.Contains(data, crlf) { + linebreak = crlf + } + + separator := slices.Concat(linebreak, sep, linebreak) + + // Count actual document separators + parts := bytes.Split(data, separator) + + // If we have a directive, first part is not a separate document + if bytes.HasPrefix(data, []byte("%")) { + if len(parts) <= 2 { + // Single document with directive + return [][]byte{data} + } + // Multiple documents - combine directive with first real document + firstDoc := append(parts[0], append(separator, parts[1]...)...) + result := [][]byte{firstDoc} + result = append(result, parts[2:]...) + return result } - return bytes.Split(data, []byte(linebreak+"---"+linebreak)) + // No directive case + if len(parts) <= 1 { + // Single document + return [][]byte{data} + } + return parts } func unmarshalMultipleDocuments(subDocuments [][]byte, v interface{}) error { diff --git a/parser/yaml/yaml_test.go b/parser/yaml/yaml_test.go index 5d1012607..775023dd6 100644 --- a/parser/yaml/yaml_test.go +++ b/parser/yaml/yaml_test.go @@ -2,6 +2,7 @@ package yaml_test import ( "reflect" + "strings" "testing" "github.com/open-policy-agent/conftest/parser/yaml" @@ -15,6 +16,12 @@ func TestYAMLParser(t *testing.T) { expectedResult interface{} shouldError bool }{ + { + name: "empty config", + controlConfigs: []byte(``), + expectedResult: nil, + shouldError: false, + }, { name: "a single config", controlConfigs: []byte(`sample: true`), @@ -44,6 +51,72 @@ nice: true`), }, shouldError: false, }, + { + name: "a single config with multiple yaml subdocs with crlf line endings", + controlConfigs: []byte(strings.ReplaceAll(`--- +sample: true +--- +hello: true +--- +nice: true`, "\n", "\r\n")), + expectedResult: []interface{}{ + map[string]interface{}{ + "sample": true, + }, + map[string]interface{}{ + "hello": true, + }, + map[string]interface{}{ + "nice": true, + }, + }, + shouldError: false, + }, + { + name: "multiple documents with one invalid yaml", + controlConfigs: []byte(`--- +valid: true +--- +invalid: + - not closed +[ +--- +also_valid: true`), + expectedResult: nil, + shouldError: true, + }, + { + name: "yaml with version directive", + controlConfigs: []byte(`%YAML 1.1 +--- +group_id: 1234`), + expectedResult: map[string]interface{}{ + "group_id": float64(1234), + }, + shouldError: false, + }, + { + name: "yaml with version directive and multiple documents", + controlConfigs: []byte(`%YAML 1.1 +--- +group_id: 1234 +--- +other_id: 5678 +--- +third_id: 9012`), + expectedResult: []interface{}{ + map[string]interface{}{ + "group_id": float64(1234), + }, + map[string]interface{}{ + "other_id": float64(5678), + }, + map[string]interface{}{ + "third_id": float64(9012), + }, + }, + shouldError: false, + }, } for _, test := range testTable { @@ -51,14 +124,13 @@ nice: true`), var unmarshalledConfigs interface{} yamlParser := new(yaml.Parser) - if err := yamlParser.Unmarshal(test.controlConfigs, &unmarshalledConfigs); err != nil { + err := yamlParser.Unmarshal(test.controlConfigs, &unmarshalledConfigs) + if test.shouldError && err == nil { + t.Error("expected error but got none") + } else if !test.shouldError && err != nil { t.Errorf("errors unmarshalling: %v", err) } - if unmarshalledConfigs == nil { - t.Error("error seeing actual value in object, received nil") - } - if !reflect.DeepEqual(test.expectedResult, unmarshalledConfigs) { t.Errorf("Expected\n%T : %v\n to equal\n%T : %v\n", unmarshalledConfigs, unmarshalledConfigs, test.expectedResult, test.expectedResult) }