diff --git a/cmd/count.go b/cmd/count.go index 5489d58..57c8752 100644 --- a/cmd/count.go +++ b/cmd/count.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + ldsview "github.com/kgoins/ldsview/pkg" "github.com/spf13/cobra" ) @@ -12,24 +13,20 @@ var countCmd = &cobra.Command{ Short: "Counts the number of entities in an ldif file", Run: func(cmd *cobra.Command, args []string) { dumpFile, _ := cmd.Flags().GetString("file") - builder := ldsview.NewLdifParser(dumpFile) - - entities := make(chan ldsview.Entity) - done := make(chan bool) + parser := ldsview.NewLdifParser(dumpFile) - // Start the printing goroutine - go ChannelPrinter(entities, done, cmd) - - err := builder.BuildEntities(entities, done) + count, err := parser.CountEntities() if err != nil { fmt.Printf("Unable to parse file: %s\n", err.Error()) return } + + fmt.Println("Entities: ", count) }, } func init() { rootCmd.AddCommand(countCmd) - countCmd.Flags().Bool( "count", true, "" ) + countCmd.Flags().Bool("count", true, "") countCmd.Flags().MarkHidden("count") } diff --git a/cmd/ui.go b/cmd/ui.go index 769d6cc..684f4c2 100644 --- a/cmd/ui.go +++ b/cmd/ui.go @@ -44,7 +44,6 @@ func PrintEntity(entity ldsview.Entity, decodeTS bool) { // when finished func ChannelPrinter(entities chan ldsview.Entity, done chan bool, cmd *cobra.Command) { - count, _ := cmd.Flags().GetBool("count") tdc, _ := cmd.Flags().GetBool("tdc") printLimit, intParseErr := cmd.Flags().GetInt("first") @@ -59,17 +58,12 @@ func ChannelPrinter(entities chan ldsview.Entity, done chan bool, cmd *cobra.Com for entity := range entities { entCount = entCount + 1 - if !count { - PrintEntity(entity, tdc) - } + PrintEntity(entity, tdc) if entCount == printLimit { break } } - if count { - fmt.Println("Entities: ", entCount) - } done <- true } diff --git a/pkg/counter.go b/pkg/counter.go new file mode 100644 index 0000000..bcf5549 --- /dev/null +++ b/pkg/counter.go @@ -0,0 +1,31 @@ +package ldsview + +import ( + "errors" + "os" +) + +// CountEntities returns the number of entities in the input file +func (parser LdifParser) CountEntities() (count int, err error) { + Logger.Info("Opening ldif file: " + parser.filename) + dumpFile, err := os.Open(parser.filename) + if err != nil { + return + } + defer dumpFile.Close() + + Logger.Info("Finding first entity block") + entityScanner := parser.findFirstEntityBlock(dumpFile) + if entityScanner == nil { + return count, errors.New("Unable to find first entity block") + } + + for entityScanner.Scan() { + titleLine := entityScanner.Text() + if parser.isEntityTitle(titleLine) { + count++ + } + } + + return +} diff --git a/pkg/counter_test.go b/pkg/counter_test.go new file mode 100644 index 0000000..978907b --- /dev/null +++ b/pkg/counter_test.go @@ -0,0 +1,18 @@ +package ldsview + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLdifParser_CountEntities(t *testing.T) { + a := assert.New(t) + + parser := NewLdifParser(TESTFILE) + a.NotNil(parser) + + count, err := parser.CountEntities() + a.NoError(err) + a.Equal(count, NUMENTITIES) +} diff --git a/pkg/ldif_parser_test.go b/pkg/ldif_parser_test.go index 9118834..3181f10 100644 --- a/pkg/ldif_parser_test.go +++ b/pkg/ldif_parser_test.go @@ -28,7 +28,7 @@ func TestBuildEntities(t *testing.T) { err := parser.BuildEntities(entities, done) assert.Nil(t, err) - assert.Equal(t, 3, counter.c) + assert.Equal(t, NUMENTITIES, counter.c) }) } diff --git a/pkg/pkg_main_test.go b/pkg/pkg_main_test.go index 065b60c..f9609e3 100644 --- a/pkg/pkg_main_test.go +++ b/pkg/pkg_main_test.go @@ -1,3 +1,4 @@ package ldsview const TESTFILE = "../testdata/test_users.ldif" +const NUMENTITIES = 3