Skip to content

Commit d0522f3

Browse files
authored
Remove the uses of CreatePredictionEngine (dotnet#1637)
* One test reborns * Remove two mores uses and clean copy-and-paste * Clean the rest TF tests * Clean the last CreatePredictionEngine * Further clean up redundant ctor's of PredictionEngine and their relatives * Use paths in TestDatasets * Rollback baselines * Fix
1 parent 3f43c12 commit d0522f3

File tree

4 files changed

+272
-568
lines changed

4 files changed

+272
-568
lines changed

src/Microsoft.ML.Api/ComponentCreation.cs

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -150,46 +150,6 @@ public static BatchPredictionEngine<TSrc, TDst> CreateBatchPredictionEngine<TSrc
150150
return new BatchPredictionEngine<TSrc, TDst>(env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
151151
}
152152

153-
/// <summary>
154-
/// Create an on-demand prediction engine.
155-
/// </summary>
156-
/// <param name="env">The host environment to use.</param>
157-
/// <param name="modelStream">The stream to deserialize the pipeline (transforms and predictor) from.</param>
158-
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
159-
/// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
160-
/// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
161-
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, Stream modelStream,
162-
bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
163-
where TSrc : class
164-
where TDst : class, new()
165-
{
166-
Contracts.CheckValue(env, nameof(env));
167-
env.CheckValue(modelStream, nameof(modelStream));
168-
env.CheckValueOrNull(inputSchemaDefinition);
169-
env.CheckValueOrNull(outputSchemaDefinition);
170-
return new PredictionEngine<TSrc, TDst>(env, modelStream, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
171-
}
172-
173-
/// <summary>
174-
/// Create an on-demand prediction engine.
175-
/// </summary>
176-
/// <param name="env">The host environment to use.</param>
177-
/// <param name="dataPipe">The transformation pipe that may or may not include a scorer.</param>
178-
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
179-
/// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
180-
/// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
181-
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, IDataView dataPipe,
182-
bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
183-
where TSrc : class
184-
where TDst : class, new()
185-
{
186-
Contracts.CheckValue(env, nameof(env));
187-
env.CheckValue(dataPipe, nameof(dataPipe));
188-
env.CheckValueOrNull(inputSchemaDefinition);
189-
env.CheckValueOrNull(outputSchemaDefinition);
190-
return new PredictionEngine<TSrc, TDst>(env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
191-
}
192-
193153
/// <summary>
194154
/// Create an on-demand prediction engine.
195155
/// </summary>
@@ -198,7 +158,7 @@ public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(th
198158
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
199159
/// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
200160
/// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
201-
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, ITransformer transformer,
161+
internal static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, ITransformer transformer,
202162
bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
203163
where TSrc : class
204164
where TDst : class, new()
@@ -210,23 +170,6 @@ public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(th
210170
return new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
211171
}
212172

213-
/// <summary>
214-
/// Create a prediction engine.
215-
/// This encapsulates the 'classic' prediction problem, where the input is denoted by the float array of features,
216-
/// and the output is a float score. For binary classification predictors that can output probability, there are output
217-
/// fields that report the predicted label and probability.
218-
/// </summary>
219-
/// <param name="env">The host environment to use.</param>
220-
/// <param name="modelStream">The model stream to load pipeline from.</param>
221-
/// <param name="nFeatures">Number of features.</param>
222-
public static SimplePredictionEngine CreateSimplePredictionEngine(this IHostEnvironment env, Stream modelStream, int nFeatures)
223-
{
224-
Contracts.CheckValue(env, nameof(env));
225-
env.CheckValue(modelStream, nameof(modelStream));
226-
env.CheckParam(nFeatures > 0, nameof(nFeatures), "Number of features must be positive.");
227-
return new SimplePredictionEngine(env, modelStream, nFeatures);
228-
}
229-
230173
/// <summary>
231174
/// Load the transforms (but not loader) from the model steram and apply them to the specified data.
232175
/// It is acceptable to have no transforms in the model stream: in this case the original

src/Microsoft.ML.Api/PredictionEngine.cs

Lines changed: 9 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,6 @@ public sealed class PredictionEngine<TSrc, TDst>
140140
private readonly IRowReadableAs<TDst> _outputRow;
141141
private readonly Action _disposer;
142142

143-
internal PredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
144-
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
145-
: this(env, StreamChecker(env, modelStream), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
146-
{
147-
}
148-
149143
private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
150144
{
151145
env.CheckValue(modelStream, nameof(modelStream));
@@ -158,29 +152,12 @@ private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env,
158152
};
159153
}
160154

161-
internal PredictionEngine(IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns,
162-
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
163-
: this(env, new TransformWrapper(env, env.CheckRef(dataPipe, nameof(dataPipe))), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
164-
{
165-
}
166-
167155
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
168156
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
169-
: this(env, TransformerChecker(env, transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
170-
{
171-
}
172-
173-
private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
174-
{
175-
ectx.CheckValue(transformer, nameof(transformer));
176-
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
177-
return transformer.GetRowToRowMapper;
178-
}
179-
180-
private PredictionEngine(IHostEnvironment env, Func<Schema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
181-
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
182157
{
183158
Contracts.CheckValue(env, nameof(env));
159+
env.AssertValue(transformer);
160+
var makeMapper = TransformerChecker(env, transformer);
184161
env.AssertValue(makeMapper);
185162

186163
_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
@@ -190,6 +167,13 @@ private PredictionEngine(IHostEnvironment env, Func<Schema, IRowToRowMapper> mak
190167
_outputRow = cursorable.GetRow(outputRow);
191168
}
192169

170+
private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
171+
{
172+
ectx.CheckValue(transformer, nameof(transformer));
173+
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
174+
return transformer.GetRowToRowMapper;
175+
}
176+
193177
~PredictionEngine()
194178
{
195179
_disposer?.Invoke();
@@ -222,76 +206,4 @@ public void Predict(TSrc example, ref TDst prediction)
222206
_outputRow.FillValues(prediction);
223207
}
224208
}
225-
226-
/// <summary>
227-
/// This class encapsulates the 'classic' prediction problem, where the input is denoted by the float array of features,
228-
/// and the output is a float score. For binary classification predictors that can output probability, there are output
229-
/// fields that report the predicted label and probability.
230-
/// </summary>
231-
public sealed class SimplePredictionEngine
232-
{
233-
private class Example
234-
{
235-
// REVIEW: convert to VBuffer once we have support for them.
236-
public Float[] Features;
237-
}
238-
239-
/// <summary>
240-
/// The prediction output. For every field, if there are no column with the matched name in the scoring pipeline,
241-
/// the field will be left intact by the engine (and keep 0 as value unless the user code changes it).
242-
/// </summary>
243-
public class Prediction
244-
{
245-
public Float Score;
246-
public Float Probability;
247-
}
248-
249-
private readonly PredictionEngine<Example, Prediction> _engine;
250-
private readonly int _nFeatures;
251-
252-
/// <summary>
253-
/// Create a prediction engine.
254-
/// </summary>
255-
/// <param name="env">The host environment to use.</param>
256-
/// <param name="modelStream">The model stream to load pipeline from.</param>
257-
/// <param name="nFeatures">Number of features.</param>
258-
/// <param name="featureColumnName">Name of the features column.</param>
259-
internal SimplePredictionEngine(IHostEnvironment env, Stream modelStream, int nFeatures, string featureColumnName = "Features")
260-
{
261-
Contracts.AssertValue(env);
262-
Contracts.AssertValue(modelStream);
263-
Contracts.Assert(nFeatures > 0);
264-
265-
_nFeatures = nFeatures;
266-
var schema =
267-
new SchemaDefinition
268-
{
269-
new SchemaDefinition.Column
270-
{
271-
MemberName = featureColumnName,
272-
ColumnType = new VectorType(NumberType.Float, nFeatures)
273-
}
274-
};
275-
_engine = new PredictionEngine<Example, Prediction>(env, modelStream, true, schema);
276-
}
277-
278-
/// <summary>
279-
/// Score an example.
280-
/// </summary>
281-
/// <param name="features">The feature array of the example.</param>
282-
/// <returns>The prediction object. New object is created on every call.</returns>
283-
public Prediction Predict(Float[] features)
284-
{
285-
Contracts.CheckValue(features, nameof(features));
286-
if (features.Length != _nFeatures)
287-
throw Contracts.ExceptParam(nameof(features), "Number of features should be {0}, but it is {1}", _nFeatures, features.Length);
288-
289-
var example = new Example { Features = features };
290-
return _engine.Predict(example);
291-
}
292-
public Prediction Predict(VBuffer<Float> features)
293-
{
294-
throw Contracts.ExceptNotImpl("VBuffers aren't supported yet.");
295-
}
296-
}
297209
}

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs

Lines changed: 34 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.ML.Runtime.Data;
1010
using Microsoft.ML.Runtime.Learners;
1111
using Microsoft.ML.Runtime.Model;
12+
using Microsoft.ML.Runtime.RunTests;
1213
using Microsoft.ML.Trainers;
1314
using Microsoft.ML.Transforms.Normalizers;
1415
using System;
@@ -22,54 +23,43 @@ public partial class ScenariosTests
2223
[Fact]
2324
public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
2425
{
25-
string dataPath = GetDataPath("iris.txt");
26-
string testDataPath = dataPath;
26+
var mlContext = new MLContext(seed: 1, conc: 1);
2727

28-
var env = new MLContext(seed: 1, conc: 1);
29-
// Pipeline
30-
var loader = TextLoader.ReadFile(env,
31-
new TextLoader.Arguments()
28+
var reader = mlContext.Data.TextReader(new TextLoader.Arguments()
29+
{
30+
HasHeader = false,
31+
Column = new[]
3232
{
33-
HasHeader = false,
34-
Column = new[]
35-
{
36-
new TextLoader.Column("Label", DataKind.R4, 0),
37-
new TextLoader.Column("SepalLength", DataKind.R4, 1),
38-
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
39-
new TextLoader.Column("PetalLength", DataKind.R4, 3),
40-
new TextLoader.Column("PetalWidth", DataKind.R4, 4)
41-
}
42-
}, new MultiFileSource(dataPath));
43-
44-
IDataView pipeline = new ColumnConcatenatingTransformer(env, "Features",
45-
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(loader);
46-
47-
// NormalizingEstimator is not automatically added though the trainer has 'NormalizeFeatures' On/Auto
48-
pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features");
49-
50-
// Train
51-
var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: s => s.NumThreads = 1);
52-
53-
// Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
54-
var cached = new CacheDataView(env, pipeline, prefetch: null);
55-
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
56-
var pred = trainer.Train(trainRoles);
57-
58-
// Get scorer and evaluate the predictions from test data
59-
IDataScorerTransform testDataScorer = GetScorer(env, pipeline, pred, testDataPath);
60-
var metrics = Evaluate(env, testDataScorer);
61-
CompareMatrics(metrics);
33+
new TextLoader.Column("Label", DataKind.R4, 0),
34+
new TextLoader.Column("SepalLength", DataKind.R4, 1),
35+
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
36+
new TextLoader.Column("PetalLength", DataKind.R4, 3),
37+
new TextLoader.Column("PetalWidth", DataKind.R4, 4)
38+
}
39+
});
6240

63-
// Create prediction engine and test predictions
64-
var model = env.CreatePredictionEngine<IrisData, IrisPrediction>(testDataScorer);
65-
ComparePredictions(model);
41+
var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
42+
.Append(mlContext.Transforms.Normalize("Features"))
43+
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1));
6644

67-
// Get feature importance i.e. weight vector
68-
var summary = pred.GetSummaryInKeyValuePairs(trainRoles.Schema);
69-
Assert.Equal(7.76443, Convert.ToDouble(summary[0].Value), 5);
45+
// Read training and test data sets
46+
string dataPath = GetDataPath(TestDatasets.iris.trainFilename);
47+
string testDataPath = dataPath;
48+
var trainData = reader.Read(dataPath);
49+
var testData = reader.Read(testDataPath);
50+
51+
// Train the pipeline
52+
var trainedModel = pipe.Fit(trainData);
53+
54+
// Make prediction and then evaluate the trained pipeline
55+
var predicted = trainedModel.Transform(testData);
56+
var metrics = mlContext.MulticlassClassification.Evaluate(predicted);
57+
CompareMatrics(metrics);
58+
var predictFunction = trainedModel.MakePredictionFunction<IrisData, IrisPrediction>(mlContext);
59+
ComparePredictions(predictFunction);
7060
}
7161

72-
private void ComparePredictions(PredictionEngine<IrisData, IrisPrediction> model)
62+
private void ComparePredictions(PredictionFunction<IrisData, IrisPrediction> model)
7363
{
7464
IrisPrediction prediction = model.Predict(new IrisData()
7565
{
@@ -108,46 +98,17 @@ private void ComparePredictions(PredictionEngine<IrisData, IrisPrediction> model
10898
Assert.Equal(0, prediction.PredictedLabels[2], 2);
10999
}
110100

111-
private void CompareMatrics(ClassificationMetrics metrics)
101+
private void CompareMatrics(MultiClassClassifierEvaluator.Result metrics)
112102
{
113103
Assert.Equal(.98, metrics.AccuracyMacro);
114104
Assert.Equal(.98, metrics.AccuracyMicro, 2);
115-
Assert.Equal(.06, metrics.LogLoss, 2);
105+
Assert.InRange(metrics.LogLoss, .05, .06);
116106
Assert.InRange(metrics.LogLossReduction, 94, 96);
117-
Assert.Equal(1, metrics.TopKAccuracy);
118107

119108
Assert.Equal(3, metrics.PerClassLogLoss.Length);
120109
Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
121110
Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
122111
Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);
123-
124-
ConfusionMatrix matrix = metrics.ConfusionMatrix;
125-
Assert.Equal(3, matrix.Order);
126-
Assert.Equal(3, matrix.ClassNames.Count);
127-
Assert.Equal("0", matrix.ClassNames[0]);
128-
Assert.Equal("1", matrix.ClassNames[1]);
129-
Assert.Equal("2", matrix.ClassNames[2]);
130-
131-
Assert.Equal(50, matrix[0, 0]);
132-
Assert.Equal(50, matrix["0", "0"]);
133-
Assert.Equal(0, matrix[0, 1]);
134-
Assert.Equal(0, matrix["0", "1"]);
135-
Assert.Equal(0, matrix[0, 2]);
136-
Assert.Equal(0, matrix["0", "2"]);
137-
138-
Assert.Equal(0, matrix[1, 0]);
139-
Assert.Equal(0, matrix["1", "0"]);
140-
Assert.Equal(48, matrix[1, 1]);
141-
Assert.Equal(48, matrix["1", "1"]);
142-
Assert.Equal(2, matrix[1, 2]);
143-
Assert.Equal(2, matrix["1", "2"]);
144-
145-
Assert.Equal(0, matrix[2, 0]);
146-
Assert.Equal(0, matrix["2", "0"]);
147-
Assert.Equal(1, matrix[2, 1]);
148-
Assert.Equal(1, matrix["2", "1"]);
149-
Assert.Equal(49, matrix[2, 2]);
150-
Assert.Equal(49, matrix["2", "2"]);
151112
}
152113

153114
private ClassificationMetrics Evaluate(IHostEnvironment env, IDataView scoredData)

0 commit comments

Comments
 (0)