Skip to content

Commit b6f94bc

Browse files
author
Shahab Moradi
authored
Updated xml docs for tree-based trainers. (dotnet#2970)
* Updated xml docs for tree-based trainers. * Addressed PR comments.
1 parent 9cd9a8c commit b6f94bc

File tree

17 files changed

+362
-247
lines changed

17 files changed

+362
-247
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs

Lines changed: 0 additions & 40 deletions
This file was deleted.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression
{
    public static class FastTree
    {
        // This example requires installation of additional NuGet package
        // <a href="https://www.nuget.org/packages/Microsoft.ML.FastTree/">Microsoft.ML.FastTree</a>.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training examples.
            var examples = GenerateRandomDataPoints(1000);

            // Convert the examples list to an IDataView object, which is consumable by ML.NET API.
            var data = mlContext.Data.LoadFromEnumerable(examples);

            // Define the trainer. This sample lives in the Regression namespace and its
            // label is a continuous float, so the regression catalog is the correct one.
            // (The original called BinaryClassification.Trainers.FastTree(), which expects
            // a Boolean label and would fail at Fit time on this data.)
            var pipeline = mlContext.Regression.Trainers.FastTree();

            // Train the model.
            var model = pipeline.Fit(data);
        }

        // Generates 'count' synthetic examples with a random float label and
        // 50 features correlated with that label (each feature = label + noise).
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count)
        {
            // Fixed seed keeps the generated data deterministic across runs.
            var random = new Random(0);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat();
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with label.
                    Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray()
                };
            }
        }

        // In-memory example type consumed by LoadFromEnumerable.
        private class DataPoint
        {
            public float Label { get; set; }
            // Fixed-size vector annotation is required so ML.NET knows the feature column width.
            [VectorType(50)]
            public float[] Features { get; set; }
        }
    }
}

src/Microsoft.ML.Data/Training/TrainerInputBase.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,17 @@ private protected TrainerInputBase() { }
4141
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
4242

4343
/// <summary>
44-
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
44+
/// Whether trainer should cache input training data. Used only in entry-points, since the intended API mechanism
4545
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
4646
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
4747
/// </summary>
4848
[BestFriend]
49-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
49+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether trainer should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
5050
internal CachingOptions Caching = CachingOptions.Auto;
5151
}
5252

5353
/// <summary>
54-
/// The base class for all learner inputs that support a Label column.
54+
/// The base class for all trainer inputs that support a Label column.
5555
/// </summary>
5656
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
5757
public abstract class TrainerInputBaseWithLabel : TrainerInputBase
@@ -67,7 +67,7 @@ private protected TrainerInputBaseWithLabel() { }
6767

6868
// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
6969
/// <summary>
70-
/// The base class for all learner inputs that support a weight column.
70+
/// The base class for all trainer inputs that support a weight column.
7171
/// </summary>
7272
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
7373
public abstract class TrainerInputBaseWithWeight : TrainerInputBaseWithLabel
@@ -82,7 +82,7 @@ private protected TrainerInputBaseWithWeight() { }
8282
}
8383

8484
/// <summary>
85-
/// The base class for all unsupervised learner inputs that support a weight column.
85+
/// The base class for all unsupervised trainer inputs that support a weight column.
8686
/// </summary>
8787
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
8888
public abstract class UnsupervisedTrainerInputBaseWithWeight : TrainerInputBase
@@ -96,6 +96,9 @@ private protected UnsupervisedTrainerInputBaseWithWeight() { }
9696
public string ExampleWeightColumnName = null;
9797
}
9898

99+
/// <summary>
100+
/// The base class for all trainer inputs that support a group column.
101+
/// </summary>
99102
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
100103
public abstract class TrainerInputBaseWithGroupId : TrainerInputBaseWithWeight
101104
{

src/Microsoft.ML.FastTree/FastTreeArguments.cs

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,18 @@ public enum EarlyStoppingRankingMetric
5151
NdcgAt3 = 3
5252
}
5353

54-
/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
54+
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
5555
public sealed partial class FastTreeBinaryClassificationTrainer
5656
{
57+
/// <summary>
58+
/// Options for the <see cref="FastTreeBinaryClassificationTrainer"/>.
59+
/// </summary>
5760
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
5861
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
5962
{
6063

6164
/// <summary>
62-
/// Option for using derivatives optimized for unbalanced sets.
65+
/// Whether to use derivatives optimized for unbalanced training data.
6366
/// </summary>
6467
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Option for using derivatives optimized for unbalanced sets", ShortName = "us")]
6568
[TGUI(Label = "Optimize for unbalanced")]
@@ -90,6 +93,9 @@ public EarlyStoppingMetric EarlyStoppingMetric
9093
}
9194
}
9295

96+
/// <summary>
97+
/// Create a new <see cref="Options"/> object with default values.
98+
/// </summary>
9399
public Options()
94100
{
95101
// Use L1 by default.
@@ -100,8 +106,12 @@ public Options()
100106
}
101107
}
102108

109+
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
103110
public sealed partial class FastTreeRegressionTrainer
104111
{
112+
/// <summary>
113+
/// Options for the <see cref="FastTreeRegressionTrainer"/>.
114+
/// </summary>
105115
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
106116
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
107117
{
@@ -127,6 +137,9 @@ public EarlyStoppingMetric EarlyStoppingMetric
127137
}
128138
}
129139

140+
/// <summary>
141+
/// Create a new <see cref="Options"/> object with default values.
142+
/// </summary>
130143
public Options()
131144
{
132145
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm; // Use L1 by default.
@@ -136,14 +149,22 @@ public Options()
136149
}
137150
}
138151

152+
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
139153
public sealed partial class FastTreeTweedieTrainer
140154
{
155+
/// <summary>
156+
/// Options for the <see cref="FastTreeTweedieTrainer"/>.
157+
/// </summary>
141158
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
142159
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
143160
{
144161
// REVIEW: It is possible to estimate this index parameter from the distribution of data, using
145162
// a combination of univariate optimization and grid search, following section 4.2 of the paper. However
146163
// it is probably not worth doing unless and until explicitly asked for.
164+
/// <summary>
165+
/// The index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss,
166+
/// and intermediate values are compound Poisson loss.
167+
/// </summary>
147168
[Argument(ArgumentType.LastOccurenceWins, HelpText =
148169
"Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " +
149170
"and intermediate values are compound Poisson loss.")]
@@ -174,6 +195,9 @@ public EarlyStoppingMetric EarlyStoppingMetric
174195
}
175196
}
176197

198+
/// <summary>
199+
/// Create a new <see cref="Options"/> object with default values.
200+
/// </summary>
177201
public Options()
178202
{
179203
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm; // Use L1 by default.
@@ -183,15 +207,25 @@ public Options()
183207
}
184208
}
185209

210+
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
186211
public sealed partial class FastTreeRankingTrainer
187212
{
213+
/// <summary>
214+
/// Options for the <see cref="FastTreeRankingTrainer"/>.
215+
/// </summary>
188216
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
189217
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
190218
{
191-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")]
219+
/// <summary>
220+
/// Comma-separated list of gains associated with each relevance label.
221+
/// </summary>
222+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma-separated list of gains associated to each relevance label.", ShortName = "gains")]
192223
[TGUI(NoSweep = true)]
193224
public double[] CustomGains = new double[] { 0, 3, 7, 15, 31 };
194225

226+
/// <summary>
227+
/// Whether to train using discounted cumulative gain (DCG) instead of normalized DCG (NDCG).
228+
/// </summary>
195229
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Train DCG instead of NDCG", ShortName = "dcg")]
196230
public bool UseDcg;
197231

@@ -204,7 +238,11 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
204238
[TGUI(NotGui = true)]
205239
internal string SortingAlgorithm = "DescendingStablePessimistic";
206240

207-
[Argument(ArgumentType.AtMostOnce, HelpText = "max-NDCG truncation to use in the Lambda Mart algorithm", ShortName = "n", Hide = true)]
241+
/// <summary>
242+
/// The maximum NDCG truncation to use in the
243+
/// <a href="https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf">LambdaMART algorithm</a>.
244+
/// </summary>
245+
[Argument(ArgumentType.AtMostOnce, HelpText = "max-NDCG truncation to use in the LambdaMART algorithm", ShortName = "n", Hide = true)]
208246
[TGUI(NotGui = true)]
209247
public int NdcgTruncationLevel = 100;
210248

@@ -253,6 +291,9 @@ public EarlyStoppingRankingMetric EarlyStoppingMetric
253291
}
254292
}
255293

294+
/// <summary>
295+
/// Create a new <see cref="Options"/> object with default values.
296+
/// </summary>
256297
public Options()
257298
{
258299
EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt1; // Use NDCG@1 by default.
@@ -295,6 +336,9 @@ internal static class Defaults
295336
public const double LearningRate = 0.2;
296337
}
297338

339+
/// <summary>
340+
/// Options for tree trainers.
341+
/// </summary>
298342
public abstract class TreeOptions : TrainerInputBaseWithGroupId
299343
{
300344
/// <summary>
@@ -428,11 +472,13 @@ public abstract class TreeOptions : TrainerInputBaseWithGroupId
428472
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The feature re-use penalty (regularization) coefficient", ShortName = "frup")]
429473
public Double FeatureReusePenalty;
430474

431-
/// Only consider a gain if its likelihood versus a random choice gain is above a certain value.
432-
/// So 0.95 would mean restricting to gains that have less than a 0.05 change of being generated randomly through choice of a random split.
433475
/// <summary>
434-
/// Tree fitting gain confidence requirement (should be in the range [0,1) ).
476+
/// Tree fitting gain confidence requirement. Only consider a gain if its likelihood versus a random choice gain is above this value.
435477
/// </summary>
478+
/// <value>
479+
/// Value of 0.95 would mean restricting to gains that have less than a 0.05 chance of being generated randomly through choice of a random split.
480+
/// Valid range is [0,1).
481+
/// </value>
436482
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Tree fitting gain confidence requirement (should be in the range [0,1) ).", ShortName = "gainconf")]
437483
public Double GainConfidenceLevel;
438484

@@ -458,7 +504,7 @@ public abstract class TreeOptions : TrainerInputBaseWithGroupId
458504
public int NumberOfLeaves = Defaults.NumberOfLeaves;
459505

460506
/// <summary>
461-
/// The minimal number of examples allowed in a leaf of a regression tree, out of the subsampled data.
507+
/// The minimal number of data points required to form a new tree leaf.
462508
/// </summary>
463509
// REVIEW: Arrays not supported in GUI
464510
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
@@ -582,6 +628,9 @@ internal virtual void Check(IExceptionContext ectx)
582628
}
583629
}
584630

631+
/// <summary>
632+
/// Options for boosting tree trainers.
633+
/// </summary>
585634
public abstract class BoostedTreeOptions : TreeOptions
586635
{
587636
// REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate.
@@ -594,7 +643,7 @@ public abstract class BoostedTreeOptions : TreeOptions
594643
public bool BestStepRankingRegressionTrees = false;
595644

596645
/// <summary>
597-
/// Should we use line search for a step size.
646+
/// Determines whether to use line search for a step size.
598647
/// </summary>
599648
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use line search for a step size", ShortName = "ls")]
600649
public bool UseLineSearch;
@@ -611,11 +660,17 @@ public abstract class BoostedTreeOptions : TreeOptions
611660
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Minimum line search step size", ShortName = "minstep")]
612661
public Double MinimumStepSize;
613662

663+
/// <summary>
664+
/// Types of optimization algorithms.
665+
/// </summary>
614666
public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDescent, ConjugateGradientDescent };
615667

616668
/// <summary>
617-
/// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent).
669+
/// Optimization algorithm to be used.
618670
/// </summary>
671+
/// <value>
672+
/// See <see cref="OptimizationAlgorithmType"/> for available optimizers.
673+
/// </value>
619674
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent)", ShortName = "oa")]
620675
public OptimizationAlgorithmType OptimizationAlgorithm = OptimizationAlgorithmType.GradientDescent;
621676

@@ -655,7 +710,7 @@ public EarlyStoppingRuleBase EarlyStoppingRule
655710
internal int EarlyStoppingMetrics;
656711

657712
/// <summary>
658-
/// Enable post-training pruning to avoid overfitting. (a validation set is required).
713+
/// Enable post-training tree pruning to avoid overfitting. It requires a validation set.
659714
/// </summary>
660715
[Argument(ArgumentType.AtMostOnce, HelpText = "Enable post-training pruning to avoid overfitting. (a validation set is required)", ShortName = "pruning")]
661716
public bool EnablePruning;

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
9898
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
9999
}
100100

101-
/// <include file = 'doc.xml' path='doc/members/member[@name="FastTree"]/*' />
101+
/// <summary>
102+
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using FastTree.
103+
/// </summary>
104+
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
102105
public sealed partial class FastTreeBinaryClassificationTrainer :
103106
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options,
104107
BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@
3939

4040
namespace Microsoft.ML.Trainers.FastTree
4141
{
42-
/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
42+
/// <summary>
43+
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree ranking model using FastTree.
44+
/// </summary>
45+
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
4346
public sealed partial class FastTreeRankingTrainer
4447
: BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Options, RankingPredictionTransformer<FastTreeRankingModelParameters>, FastTreeRankingModelParameters>
4548
{

0 commit comments

Comments
 (0)