diff --git a/go.mod b/go.mod index be59da595..f1ac9c484 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/mattn/goveralls v0.0.2 // indirect github.com/opencensus-integrations/ocsql v0.1.1 github.com/pkg/errors v0.8.0 // indirect + github.com/pmezard/go-difflib v1.0.0 github.com/prometheus/client_golang v0.9.0 // indirect github.com/prometheus/common v0.0.0-20181015124227-bcb74de08d37 // indirect github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect diff --git a/wire/cmd/wire/main.go b/wire/cmd/wire/main.go index e58b61528..18224f67b 100644 --- a/wire/cmd/wire/main.go +++ b/wire/cmd/wire/main.go @@ -23,6 +23,7 @@ import ( "fmt" "go/token" "go/types" + "io/ioutil" "os" "reflect" "sort" @@ -30,10 +31,11 @@ import ( "strings" "github.com/google/go-cloud/wire/internal/wire" + "github.com/pmezard/go-difflib/difflib" "golang.org/x/tools/go/types/typeutil" ) -const usage = "usage: wire [gen] [PKG] | wire show [...] | wire check [...]" +const usage = "usage: wire [gen|diff|show|check] [...]" func main() { var err error @@ -49,6 +51,10 @@ func main() { err = check(".") case len(os.Args) > 2 && os.Args[1] == "check": err = check(os.Args[2:]...) + case len(os.Args) == 2 && os.Args[1] == "diff": + err = diff(".") + case len(os.Args) > 2 && os.Args[1] == "diff": + err = diff(os.Args[2:]...) case len(os.Args) == 2 && os.Args[1] == "gen": err = generate(".") case len(os.Args) > 2 && os.Args[1] == "gen": @@ -108,6 +114,54 @@ func generate(pkgs ...string) error { return nil } +// diff runs the diff subcommand. +// +// Given one or more packages, diff will generate the content for the +// wire_gen.go file, and output the diff against the existing file. +func diff(pkgs ...string) error { + wd, err := os.Getwd() + if err != nil { + return err + } + outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs) + if len(errs) > 0 { + logErrors(errs) + return errors.New("generate failed") + } + if len(outs) == 0 { + return nil + } + success := true + for _, out := range outs { + if len(out.Errs) > 0 { + fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath) + logErrors(out.Errs) + success = false + } + if len(out.Content) == 0 { + // No Wire output. Maybe errors, maybe no Wire directives. + continue + } + // Assumes the current file is empty if we can't read it. + cur, _ := ioutil.ReadFile(out.OutputPath) + if diff, err := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(string(cur)), + B: difflib.SplitLines(string(out.Content)), + }); err == nil { + if diff != "" { + fmt.Fprintf(os.Stderr, "%s: diff from %s:\n%s", out.PkgPath, out.OutputPath, diff) + } + } else { + fmt.Fprintf(os.Stderr, "%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err) + success = false + } + } + if !success { + return errors.New("at least one generate failure") + } + return nil +} + // show runs the show subcommand. // // Given one or more packages, show will find all the provider sets