Skip to content
88 changes: 88 additions & 0 deletions src/Microsoft.ML/LearningPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,113 @@ public ScorerPipelineStep(Var<IDataView> data, Var<ITransformModel> model)
public Var<ITransformModel> Model { get; }
}


/// <summary>
/// The <see cref="LearningPipeline"/> class is used to define the steps needed to perform a desired machine learning task.<para/>
/// The steps are defined by adding a data loader (e.g. <see cref="TextLoader"/>) followed by zero or more transforms (e.g. <see cref="Microsoft.ML.Transforms.TextFeaturizer"/>)
/// and at most one trainer/learner (e.g. <see cref="Microsoft.ML.Trainers.FastTreeBinaryClassifier"/>) in the pipeline.
///
/// </summary>
/// <example>
/// <para/>
/// For example,<para/>
/// <code>
/// var pipeline = new LearningPipeline();
/// pipeline.Add(new TextLoader&lt;SentimentData&gt;(dataPath, separator: ","));
/// pipeline.Add(new TextFeaturizer("Features", "SentimentText"));
/// pipeline.Add(new FastTreeBinaryClassifier());
///
/// var model = pipeline.Train&lt;SentimentData, SentimentPrediction&gt;();
/// </code>
/// </example>
[DebuggerTypeProxy(typeof(LearningPipelineDebugProxy))]
public class LearningPipeline : ICollection<ILearningPipelineItem>
{
// Backing list holding the pipeline's components; the ICollection<ILearningPipelineItem> members of this class delegate to it.
private List<ILearningPipelineItem> Items { get; } = new List<ILearningPipelineItem>();

/// <summary>
/// Initializes a new, empty <see cref="LearningPipeline"/>. Components are added afterwards with <see cref="Add(ILearningPipelineItem)"/>.
/// </summary>
public LearningPipeline()
{
}

/// <summary>
/// Gets the number of ML components (data loaders, transforms and trainers) currently held by the <see cref="LearningPipeline"/>.
/// </summary>
public int Count
{
    get { return Items.Count; }
}
/// <summary>
/// Gets a value indicating whether the pipeline is read-only. Always <c>false</c>.
/// </summary>
public bool IsReadOnly
{
    get { return false; }
}

/// <summary>
/// Add a data loader, transform or trainer into the pipeline.
/// Possible data loader(s), transforms and trainers options are
/// <para>
/// Data Loader:
/// <see cref="Microsoft.ML.TextLoader{TInput}" />
/// etc.
/// </para>
/// <para>
/// Transforms:
/// <see cref="Microsoft.ML.Transforms.Dictionarizer"/>,
/// <see cref="Microsoft.ML.Transforms.CategoricalOneHotVectorizer"/>,
/// <see cref="Microsoft.ML.Transforms.MinMaxNormalizer"/>,
/// <see cref="Microsoft.ML.Transforms.ColumnCopier"/>,
/// <see cref="Microsoft.ML.Transforms.ColumnConcatenator"/>,
/// <see cref="Microsoft.ML.Transforms.TextFeaturizer"/>,
/// etc.
/// </para>
/// <para>
/// Trainers:
/// <see cref="Microsoft.ML.Trainers.AveragedPerceptronBinaryClassifier"/>,
/// <see cref="Microsoft.ML.Trainers.LogisticRegressor"/>,
/// <see cref="Microsoft.ML.Trainers.StochasticDualCoordinateAscentClassifier"/>,
/// <see cref="Microsoft.ML.Trainers.FastTreeRegressor"/>,
/// etc.
/// </para>
/// For a complete list of transforms and trainers, please see "Microsoft.ML.Transforms" and "Microsoft.ML.Trainers" namespaces.
/// </summary>
/// <param name="item">Any ML component (data loader, transform or trainer) defined as <see cref="ILearningPipelineItem"/>.</param>
public void Add(ILearningPipelineItem item)
{
    // Appended to the end of the backing list; no validation is performed here.
    Items.Add(item);
}

/// <summary>
/// Remove all the loaders/transforms/trainers from the pipeline.
/// </summary>
public void Clear()
{
    Items.Clear();
}

/// <summary>
/// Check whether a specific loader/transform/trainer is in the pipeline.
/// </summary>
/// <param name="item">Any ML component (data loader, transform or trainer) defined as <see cref="ILearningPipelineItem"/>.</param>
/// <returns>true if item is found in the pipeline; otherwise, false.</returns>
public bool Contains(ILearningPipelineItem item)
{
    return Items.Contains(item);
}

/// <summary>
/// Copy the pipeline items into an array.
/// </summary>
/// <param name="array">The one-dimensional array that is the destination of the items copied from the pipeline.</param>
/// <param name="arrayIndex">The zero-based index in <paramref name="array" /> at which copying begins.</param>
public void CopyTo(ILearningPipelineItem[] array, int arrayIndex)
{
    Items.CopyTo(array, arrayIndex);
}
/// <summary>
/// Returns an enumerator that iterates through the items currently in the pipeline.
/// </summary>
/// <returns>An enumerator over the pipeline's <see cref="ILearningPipelineItem"/> entries.</returns>
public IEnumerator<ILearningPipelineItem> GetEnumerator()
{
    return Items.GetEnumerator();
}

/// <summary>
/// Remove an item from the pipeline.
/// </summary>
/// <param name="item"><see cref="ILearningPipelineItem"/> to remove.</param>
/// <returns>true if item was removed from the pipeline; otherwise, false.</returns>
public bool Remove(ILearningPipelineItem item)
{
    // List<T>.Remove drops only the first matching occurrence.
    return Items.Remove(item);
}
/// <summary>
/// Non-generic enumerator; forwards to the generic <see cref="GetEnumerator"/>.
/// </summary>
IEnumerator IEnumerable.GetEnumerator()
{
    return GetEnumerator();
}

/// <summary>
/// Train the model using the ML components in the pipeline.
/// </summary>
/// <typeparam name="TInput">Type of data instances the model will be trained on. It's a custom type defined by the user according to the structure of data.
/// <para/>
/// Please see https://www.microsoft.com/net/learn/apps/machine-learning-and-ai/ml-dotnet/get-started/windows for more details on input type.
/// </typeparam>
/// <typeparam name="TOutput">Output type. The prediction will be returned based on this type.
/// Please see https://www.microsoft.com/net/learn/apps/machine-learning-and-ai/ml-dotnet/get-started/windows for more details on output type.
/// </typeparam>
/// <returns>PredictionModel object. This is the model object used for prediction on new instances. </returns>
public PredictionModel<TInput, TOutput> Train<TInput, TOutput>()
where TInput : class
where TOutput : class, new()
Expand Down