Skip to content

Commit fca602d

Browse files
committed
Fix Issue #5 - Variable Logging to io.Writers
Switched all models that did print some sort of logging to using any io.Writer. This gives the user control over whether to log to stdout (which is default), log to some file or API or something, or not log at all. Every model that did print has a public struct field such as below (only the comment might differ slightly. The name is always `Output io.Writer` ```go // Output is the io.Writer used for logging // and printing. Defaults to os.Stdout. Output io.Writer ```
2 parents 49cbc63 + c948131 commit fca602d

File tree

12 files changed

+116
-68
lines changed

12 files changed

+116
-68
lines changed

base/data.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ func LoadDataFromCSV(filepath string) ([][]float64, []float64, error) {
4949
return nil, nil, err
5050
}
5151

52-
fmt.Printf("Loading Data From CSV <%v>\n", filepath)
5352
// parse until the end of the file
5453
for err != io.EOF {
5554
var row []float64
@@ -76,7 +75,6 @@ func LoadDataFromCSV(filepath string) ([][]float64, []float64, error) {
7675
return nil, nil, fmt.Errorf("ERROR: Training set has no valid examples (either for x or y or both)")
7776
}
7877

79-
fmt.Printf("Finished Loading Data From <%v>\n\tTraining Examples: %v\n\tFeatures: %v\n", filepath, len(y), len(x[0]))
8078
return x, y, nil
8179
}
8280

@@ -122,7 +120,6 @@ func LoadDataFromCSVToStream(filepath string, data chan Datapoint, errors chan e
122120
return
123121
}
124122

125-
fmt.Printf("Loading Data From CSV <%v> Into Data Channel\n", filepath)
126123
// parse until the end of the file
127124
for err != io.EOF {
128125
var row []float64
@@ -149,8 +146,6 @@ func LoadDataFromCSVToStream(filepath string, data chan Datapoint, errors chan e
149146
record, err = reader.Read()
150147
}
151148

152-
fmt.Printf("Finished Loading Data From <%v> Into Data Channel\n\tClosing error channel\n\tClosing data channel\n", filepath)
153-
154149
close(errors)
155150
close(data)
156151
return
@@ -201,7 +196,6 @@ func SaveDataToCSV(filepath string, x [][]float64, y []float64, highPrecision bo
201196
writer := csv.NewWriter(file)
202197
records := [][]string{}
203198

204-
fmt.Printf("Writing Data To <%v>\n\tTraining Examples: %v\n\tFeatures: %v\n", filepath, len(x), len(x[0]))
205199
// parse until the end of the file
206200
for i := range x {
207201
record := []string{}
@@ -221,6 +215,5 @@ func SaveDataToCSV(filepath string, x [][]float64, y []float64, highPrecision bo
221215
return err
222216
}
223217

224-
fmt.Printf("Finished Writing Data To <%v>\n\n", filepath)
225218
return nil
226219
}

base/munge.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package base
22

33
import (
4-
"fmt"
54
"math"
65
)
76

@@ -16,8 +15,6 @@ func Normalize(x [][]float64) {
1615
for i := range x {
1716
NormalizePoint(x[i])
1817
}
19-
20-
fmt.Printf("Normalized < %v > data points by dividing by unit length\n", len(x))
2118
}
2219

2320
// NormalizePoint is the same as Normalize,

base/optimize.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ func GradientAscent(d Ascendable) error {
5252
}
5353
}
5454

55-
fmt.Printf("Went through %v iterations.\n", iter+1)
56-
5755
return nil
5856
}
5957

@@ -114,7 +112,5 @@ func StochasticGradientAscent(d StochasticAscendable) error {
114112
}
115113
}
116114

117-
fmt.Printf("Went through %v iterations.\n", iter+1)
118-
119115
return nil
120116
}

cluster/kmeans.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cluster
33
import (
44
"encoding/json"
55
"fmt"
6+
"io"
67
"io/ioutil"
78
"math/rand"
89
"os"
@@ -133,6 +134,11 @@ type KMeans struct {
133134
guesses []int
134135

135136
Centroids [][]float64 `json:"centroids"`
137+
138+
// Output is the io.Writer to write
139+
// logging to. Defaults to os.Stdout
140+
// but can be changed to any io.Writer
141+
Output io.Writer
136142
}
137143

138144
// OnlineParams is used to pass optional
@@ -190,6 +196,7 @@ func NewKMeans(k, maxIterations int, trainingSet [][]float64, params ...OnlinePa
190196
guesses: guesses,
191197

192198
Centroids: centroids,
199+
Output: os.Stdout,
193200
}
194201
}
195202

@@ -275,21 +282,21 @@ func (k *KMeans) Predict(x []float64, normalize ...bool) ([]float64, error) {
275282
func (k *KMeans) Learn() error {
276283
if k.trainingSet == nil {
277284
err := fmt.Errorf("ERROR: Attempting to learn with no training examples!\n")
278-
fmt.Printf(err.Error())
285+
fmt.Fprintf(k.Output, err.Error())
279286
return err
280287
}
281288

282289
examples := len(k.trainingSet)
283290
if examples == 0 || len(k.trainingSet[0]) == 0 {
284291
err := fmt.Errorf("ERROR: Attempting to learn with no training examples!\n")
285-
fmt.Printf(err.Error())
292+
fmt.Fprintf(k.Output, err.Error())
286293
return err
287294
}
288295

289296
centroids := len(k.Centroids)
290297
features := len(k.trainingSet[0])
291298

292-
fmt.Printf("Training:\n\tModel: K-Means++ Classification\n\tTraining Examples: %v\n\tFeatures: %v\n\tClasses: %v\n...\n\n", examples, features, centroids)
299+
fmt.Fprintf(k.Output, "Training:\n\tModel: K-Means++ Classification\n\tTraining Examples: %v\n\tFeatures: %v\n\tClasses: %v\n...\n\n", examples, features, centroids)
293300

294301
// instantiate the centroids using k-means++
295302
k.Centroids[0] = k.trainingSet[rand.Intn(len(k.trainingSet))]
@@ -372,7 +379,7 @@ func (k *KMeans) Learn() error {
372379
}
373380
}
374381

375-
fmt.Printf("Training Completed in %v iterations.\n%v\n", iter, k)
382+
fmt.Fprintf(k.Output, "Training Completed in %v iterations.\n%v\n", iter, k)
376383

377384
return nil
378385
}
@@ -499,7 +506,7 @@ func (k *KMeans) OnlineLearn(errors chan error, dataset chan base.Datapoint, onU
499506
centroids := len(k.Centroids)
500507
features := len(k.Centroids[0])
501508

502-
fmt.Printf("Training:\n\tModel: Online K-Means Classification\n\tFeatures: %v\n\tClasses: %v\n...\n\n", features, centroids)
509+
fmt.Fprintf(k.Output, "Training:\n\tModel: Online K-Means Classification\n\tFeatures: %v\n\tClasses: %v\n...\n\n", features, centroids)
503510

504511
var point base.Datapoint
505512
var more bool
@@ -531,7 +538,7 @@ func (k *KMeans) OnlineLearn(errors chan error, dataset chan base.Datapoint, onU
531538
go onUpdate([][]float64{[]float64{float64(c)}, k.Centroids[c]})
532539

533540
} else {
534-
fmt.Printf("Training Completed.\n%v\n\n", k)
541+
fmt.Fprintf(k.Output, "Training Completed.\n%v\n\n", k)
535542
close(errors)
536543
return
537544
}

cluster/triangle_kmeans.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cluster
33
import (
44
"encoding/json"
55
"fmt"
6+
"io"
67
"io/ioutil"
78
"math/rand"
89
"os"
@@ -146,6 +147,10 @@ type TriangleKMeans struct {
146147
// calculations
147148
centroidDist [][]float64
148149
minCentroidDist []float64
150+
151+
// Output is the io.Writer to write logs
152+
// and output from training to
153+
Output io.Writer
149154
}
150155

151156
// pointInfo stores information needed to use
@@ -214,6 +219,8 @@ func NewTriangleKMeans(k, maxIterations int, trainingSet [][]float64) *TriangleK
214219
Centroids: centroids,
215220
centroidDist: centroidDist,
216221
minCentroidDist: minCentroidDist,
222+
223+
Output: os.Stdout,
217224
}
218225
}
219226

@@ -383,21 +390,21 @@ func (k *TriangleKMeans) recalculateCentroids() [][]float64 {
383390
func (k *TriangleKMeans) Learn() error {
384391
if k.trainingSet == nil {
385392
err := fmt.Errorf("ERROR: Attempting to learn with no training examples!\n")
386-
fmt.Printf(err.Error())
393+
fmt.Fprintf(k.Output, err.Error())
387394
return err
388395
}
389396

390397
examples := len(k.trainingSet)
391398
if examples == 0 || len(k.trainingSet[0]) == 0 {
392399
err := fmt.Errorf("ERROR: Attempting to learn with no training examples!\n")
393-
fmt.Printf(err.Error())
400+
fmt.Fprintf(k.Output, err.Error())
394401
return err
395402
}
396403

397404
centroids := len(k.Centroids)
398405
features := len(k.trainingSet[0])
399406

400-
fmt.Printf("Training:\n\tModel: Triangle Inequality Accelerated K-Means++ Classification\n\tTraining Examples: %v\n\tFeatures: %v\n\tClasses: %v\n...\n\n", examples, features, centroids)
407+
fmt.Fprintf(k.Output, "Training:\n\tModel: Triangle Inequality Accelerated K-Means++ Classification\n\tTraining Examples: %v\n\tFeatures: %v\n\tClasses: %v\n...\n\n", examples, features, centroids)
401408

402409
/* Step 0 */
403410

@@ -535,7 +542,7 @@ func (k *TriangleKMeans) Learn() error {
535542
k.Centroids = newCentroids
536543
}
537544

538-
fmt.Printf("Training Completed in %v iterations.\n%v\n", iter, k)
545+
fmt.Fprintf(k.Output, "Training Completed in %v iterations.\n%v\n", iter, k)
539546

540547
return nil
541548
}

linear/linear.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ import (
5353
"bytes"
5454
"encoding/json"
5555
"fmt"
56+
"io"
5657
"io/ioutil"
5758
"math"
5859
"os"
@@ -92,6 +93,10 @@ type LeastSquares struct {
9293
expectedResults []float64
9394

9495
Parameters []float64 `json:"theta"`
96+
97+
// Output is the io.Writer used for logging
98+
// and printing. Defaults to os.Stdout.
99+
Output io.Writer
95100
}
96101

97102
// NewLeastSquares returns a pointer to the linear model
@@ -148,6 +153,8 @@ func NewLeastSquares(method base.OptimizationMethod, alpha, regularization float
148153
// initialize θ as the zero vector (that is,
149154
// the vector of all zeros)
150155
Parameters: params,
156+
157+
Output: os.Stdout,
151158
}
152159
}
153160

@@ -229,23 +236,23 @@ func (l *LeastSquares) Predict(x []float64, normalize ...bool) ([]float64, error
229236
func (l *LeastSquares) Learn() error {
230237
if l.trainingSet == nil || l.expectedResults == nil {
231238
err := fmt.Errorf("ERROR: Attempting to learn with no training examples!\n")
232-
fmt.Printf(err.Error())
239+
fmt.Fprintf(l.Output, err.Error())
233240
return err
234241
}
235242

236243
examples := len(l.trainingSet)
237244
if examples == 0 || len(l.trainingSet[0]) == 0 {
238245
err := fmt.Errorf("ERROR: Attempting to learn with no training examples!\n")
239-
fmt.Printf(err.Error())
246+
fmt.Fprintf(l.Output, err.Error())
240247
return err
241248
}
242249
if len(l.expectedResults) == 0 {
243250
err := fmt.Errorf("ERROR: Attempting to learn with no expected results! This isn't an unsupervised model!! You'll need to include data before you learn :)\n")
244-
fmt.Printf(err.Error())
251+
fmt.Fprintf(l.Output, err.Error())
245252
return err
246253
}
247254

248-
fmt.Printf("Training:\n\tModel: Logistic (Binary) Classification\n\tOptimization Method: %v\n\tTraining Examples: %v\n\tFeatures: %v\n\tLearning Rate α: %v\n\tRegularization Parameter λ: %v\n...\n\n", l.method, examples, len(l.trainingSet[0]), l.alpha, l.regularization)
255+
fmt.Fprintf(l.Output, "Training:\n\tModel: Logistic (Binary) Classification\n\tOptimization Method: %v\n\tTraining Examples: %v\n\tFeatures: %v\n\tLearning Rate α: %v\n\tRegularization Parameter λ: %v\n...\n\n", l.method, examples, len(l.trainingSet[0]), l.alpha, l.regularization)
249256

250257
var err error
251258
if l.method == base.BatchGA {
@@ -257,11 +264,11 @@ func (l *LeastSquares) Learn() error {
257264
}
258265

259266
if err != nil {
260-
fmt.Printf("\nERROR: Error while learning –\n\t%v\n\n", err)
267+
fmt.Fprintf(l.Output, "\nERROR: Error while learning –\n\t%v\n\n", err)
261268
return err
262269
}
263270

264-
fmt.Printf("Training Completed.\n%v\n\n", l)
271+
fmt.Fprintf(l.Output, "Training Completed.\n%v\n\n", l)
265272
return nil
266273
}
267274

@@ -368,7 +375,7 @@ func (l *LeastSquares) OnlineLearn(errors chan error, dataset chan base.Datapoin
368375
return
369376
}
370377

371-
fmt.Printf("Training:\n\tModel: Ordinary Least Squares Regression\n\tOptimization Method: Online Stochastic Gradient Descent\n\tFeatures: %v\n\tLearning Rate α: %v\n...\n\n", len(l.Parameters), l.alpha)
378+
fmt.Fprintf(l.Output, "Training:\n\tModel: Ordinary Least Squares Regression\n\tOptimization Method: Online Stochastic Gradient Descent\n\tFeatures: %v\n\tLearning Rate α: %v\n...\n\n", len(l.Parameters), l.alpha)
372379

373380
var point base.Datapoint
374381
var more bool
@@ -439,7 +446,7 @@ func (l *LeastSquares) OnlineLearn(errors chan error, dataset chan base.Datapoin
439446
go onUpdate([][]float64{l.Parameters})
440447

441448
} else {
442-
fmt.Printf("Training Completed.\n%v\n\n", l)
449+
fmt.Fprintf(l.Output, "Training Completed.\n%v\n\n", l)
443450
close(errors)
444451
return
445452
}
@@ -452,7 +459,7 @@ func (l *LeastSquares) OnlineLearn(errors chan error, dataset chan base.Datapoin
452459
func (l *LeastSquares) String() string {
453460
features := len(l.Parameters) - 1
454461
if len(l.Parameters) == 0 {
455-
fmt.Printf("ERROR: Attempting to print model with the 0 vector as it's parameter vector! Train first!\n")
462+
fmt.Fprintf(l.Output, "ERROR: Attempting to print model with the 0 vector as it's parameter vector! Train first!\n")
456463
}
457464
var buffer bytes.Buffer
458465

linear/local_linear.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package linear
33
import (
44
"bytes"
55
"fmt"
6+
"io"
67
"math"
8+
"os"
79

810
"github.com/cdipaolo/goml/base"
911
)
@@ -100,6 +102,10 @@ type LocalLinear struct {
100102
expectedResults []float64
101103

102104
Parameters []float64 `json:"theta"`
105+
106+
// Output is the io.Writer used for logging
107+
// and printing. Defaults to os.Stdout.
108+
Output io.Writer
103109
}
104110

105111
// NewLocalLinear returns a pointer to the linear model
@@ -156,6 +162,8 @@ func NewLocalLinear(method base.OptimizationMethod, alpha, regularization, bandw
156162
// initialize θ as the zero vector (that is,
157163
// the vector of all zeros)
158164
Parameters: params,
165+
166+
Output: os.Stdout,
159167
}
160168
}
161169

@@ -242,7 +250,7 @@ func (l *LocalLinear) Predict(x []float64, normalize ...bool) ([]float64, error)
242250
return nil, err
243251
}
244252

245-
fmt.Printf("Training:\n\tModel: Locally Weighted Linear Regression\n\tOptimization Method: %v\n\tCenter Point: %v\n\tTraining Examples: %v\n\tFeatures: %v\n\tLearning Rate α: %v\n\tRegularization Parameter λ: %v\n...\n\n", l.method, x, examples, len(l.trainingSet[0]), l.alpha, l.regularization)
253+
fmt.Fprintf(l.Output, "Training:\n\tModel: Locally Weighted Linear Regression\n\tOptimization Method: %v\n\tCenter Point: %v\n\tTraining Examples: %v\n\tFeatures: %v\n\tLearning Rate α: %v\n\tRegularization Parameter λ: %v\n...\n\n", l.method, x, examples, len(l.trainingSet[0]), l.alpha, l.regularization)
246254

247255
var iter int
248256
features := len(l.Parameters)
@@ -295,7 +303,7 @@ func (l *LocalLinear) Predict(x []float64, normalize ...bool) ([]float64, error)
295303
return nil, fmt.Errorf("Chose a training method not implemented for LocalLinear regression")
296304
}
297305

298-
fmt.Printf("Training Completed. Went through %v iterations.\n%v\n\n", iter, l)
306+
fmt.Fprintf(l.Output, "Training Completed. Went through %v iterations.\n%v\n\n", iter, l)
299307

300308
// include constant term in sum
301309
sum := l.Parameters[0]
@@ -313,7 +321,7 @@ func (l *LocalLinear) Predict(x []float64, normalize ...bool) ([]float64, error)
313321
func (l *LocalLinear) String() string {
314322
features := len(l.Parameters) - 1
315323
if len(l.Parameters) == 0 {
316-
fmt.Printf("ERROR: Attempting to print model with the 0 vector as it's parameter vector! Train first!\n")
324+
fmt.Fprintf(l.Output, "ERROR: Attempting to print model with the 0 vector as it's parameter vector! Train first!\n")
317325
}
318326
var buffer bytes.Buffer
319327

0 commit comments

Comments
 (0)