@@ -330,5 +330,73 @@ public void TestCrossValidationMacro()
330330 }
331331 }
332332 }
333+
/// <summary>
/// Runs the cross-validation macro over the breast-cancer data set with a
/// stratification column ("Strat") and checks that the first overall-metrics
/// output holds exactly one row whose "AUC" value equals 0.99 to two decimal places.
/// </summary>
[Fact]
public void TestCrossValidationMacroWithStratification()
{
    var dataPath = GetDataPath(@"breast-cancer.txt");
    using (var env = new TlcEnvironment())
    {
        // Sub-graph handed to the cross-validator: no-op transform -> SDCA learner -> model combiner.
        var foldGraph = env.CreateExperiment();

        var noOp = new ML.Transforms.NoOperation();
        var noOpOutput = foldGraph.Add(noOp);

        var trainer = new ML.Trainers.StochasticDualCoordinateAscentBinaryClassifier
        {
            TrainingData = noOpOutput.OutputData,
            // Single thread; presumably to keep the asserted metric reproducible - confirm.
            NumThreads = 1
        };
        var trainerOutput = foldGraph.Add(trainer);

        var combiner = new ML.Transforms.ManyHeterogeneousModelCombiner
        {
            TransformModels = new ArrayVar<ITransformModel>(noOpOutput.Model),
            PredictorModel = trainerOutput.PredictorModel
        };
        var combinerOutput = foldGraph.Add(combiner);

        // Outer experiment: load the text file, then cross-validate using the sub-graph above.
        var mainGraph = env.CreateExperiment();
        var loader = new ML.Data.TextLoader(dataPath);
        loader.Arguments.Column = new ML.Data.TextLoaderColumn[]
        {
            new ML.Data.TextLoaderColumn
            {
                Name = "Label",
                Source = new[] { new ML.Data.TextLoaderRange(0) }
            },
            new ML.Data.TextLoaderColumn
            {
                Name = "Strat",
                Source = new[] { new ML.Data.TextLoaderRange(1) }
            },
            new ML.Data.TextLoaderColumn
            {
                Name = "Features",
                Source = new[] { new ML.Data.TextLoaderRange(2, 9) }
            }
        };
        var loaderOutput = mainGraph.Add(loader);

        var crossValidator = new ML.Models.CrossValidator
        {
            Data = loaderOutput.Data,
            Nodes = foldGraph,
            TransformModel = null,
            StratificationColumn = "Strat"
        };
        // Bind the sub-graph's entry (no-op input data) and exit (combined predictor model).
        crossValidator.Inputs.Data = noOp.Data;
        crossValidator.Outputs.Model = combinerOutput.PredictorModel;
        var crossValidatorOutput = mainGraph.Add(crossValidator);

        mainGraph.Compile();
        mainGraph.SetInput(loader.InputFile, new SimpleFileHandle(env, dataPath, false, false));
        mainGraph.Run();
        var overallMetrics = mainGraph.GetOutput(crossValidatorOutput.OverallMetrics[0]);

        // The metrics view must expose an "AUC" column.
        Assert.True(overallMetrics.Schema.TryGetColumnIndex("AUC", out int aucColumn));

        using (var rows = overallMetrics.GetRowCursor(column => column == aucColumn))
        {
            var readAuc = rows.GetGetter<double>(aucColumn);

            // Exactly one row is expected: the first MoveNext succeeds, the second does not.
            Assert.True(rows.MoveNext());

            double auc = 0;
            readAuc(ref auc);
            Assert.Equal(0.99, auc, 2);

            Assert.False(rows.MoveNext());
        }
    }
}
333401 }
334402}
0 commit comments