Skip to content

Commit 73d894b

Browse files
authored
CV macro with stratification column doesn't work (dotnet#213)
* Reduce number of hash bits in stratification column and add a unit test.
* Address PR comments.
1 parent 2207a27 commit 73d894b

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

src/Microsoft.ML/Runtime/EntryPoints/TrainTestSplit.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ public static string CreateStratificationColumn(IHost host, ref IDataView data,
9393
new HashJoinTransform.Arguments
9494
{
9595
Column = new[] { new HashJoinTransform.Column { Name = stratCol, Source = stratificationColumn } },
96-
Join = true
96+
Join = true,
97+
HashBits = 30
9798
}, data);
9899
}
99100

test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,5 +330,73 @@ public void TestCrossValidationMacro()
330330
}
331331
}
332332
}
333+
334+
/// <summary>
/// Runs the cross-validation macro on breast-cancer.txt with <c>StratificationColumn</c> set to
/// the "Strat" column, then checks that the first overall-metrics output contains exactly one
/// row whose "AUC" value equals 0.99 to two decimal places.
/// </summary>
[Fact]
public void TestCrossValidationMacroWithStratification()
{
    var dataPath = GetDataPath(@"breast-cancer.txt");
    using (var env = new TlcEnvironment())
    {
        // Sub-graph handed to the cross validator through Nodes:
        // no-op transform -> SDCA binary classifier -> combined model.
        var foldGraph = env.CreateExperiment();

        var passThrough = new ML.Transforms.NoOperation();
        var passThroughOutput = foldGraph.Add(passThrough);

        var trainer = new ML.Trainers.StochasticDualCoordinateAscentBinaryClassifier();
        trainer.TrainingData = passThroughOutput.OutputData;
        trainer.NumThreads = 1;
        var trainerOutput = foldGraph.Add(trainer);

        var combiner = new ML.Transforms.ManyHeterogeneousModelCombiner();
        combiner.TransformModels = new ArrayVar<ITransformModel>(passThroughOutput.Model);
        combiner.PredictorModel = trainerOutput.PredictorModel;
        var combinerOutput = foldGraph.Add(combiner);

        // Outer experiment: load the text file, then cross-validate the sub-graph over it.
        var experiment = env.CreateExperiment();

        var labelColumn = new ML.Data.TextLoaderColumn { Name = "Label", Source = new[] { new ML.Data.TextLoaderRange(0) } };
        var stratColumn = new ML.Data.TextLoaderColumn { Name = "Strat", Source = new[] { new ML.Data.TextLoaderRange(1) } };
        var featuresColumn = new ML.Data.TextLoaderColumn { Name = "Features", Source = new[] { new ML.Data.TextLoaderRange(2, 9) } };

        var importInput = new ML.Data.TextLoader(dataPath);
        importInput.Arguments.Column = new ML.Data.TextLoaderColumn[] { labelColumn, stratColumn, featuresColumn };
        var importOutput = experiment.Add(importInput);

        var crossValidate = new ML.Models.CrossValidator
        {
            Data = importOutput.Data,
            Nodes = foldGraph,
            TransformModel = null,
            StratificationColumn = "Strat"
        };
        // Wire the macro's input/output to the sub-graph's entry data and final model.
        crossValidate.Inputs.Data = passThrough.Data;
        crossValidate.Outputs.Model = combinerOutput.PredictorModel;
        var crossValidateOutput = experiment.Add(crossValidate);

        experiment.Compile();
        experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
        experiment.Run();

        var overallMetrics = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]);

        Assert.True(overallMetrics.Schema.TryGetColumnIndex("AUC", out int aucIndex));

        using (var cursor = overallMetrics.GetRowCursor(c => c == aucIndex))
        {
            var readAuc = cursor.GetGetter<double>(aucIndex);

            // Exactly one row is expected.
            Assert.True(cursor.MoveNext());

            double auc = 0;
            readAuc(ref auc);
            Assert.Equal(0.99, auc, 2);

            Assert.False(cursor.MoveNext());
        }
    }
}
333401
}
334402
}

0 commit comments

Comments
 (0)