Skip to content

Commit 1a9e7aa

Browse files
abgoswamshauheen
authored andcommitted
Convert LdaTransform to IEstimator/ITransformer API (dotnet#1410)
1 parent dafa30c commit 1a9e7aa

File tree

13 files changed

+931
-597
lines changed

13 files changed

+931
-597
lines changed

src/Microsoft.ML.Legacy/CSharpApi.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13997,10 +13997,10 @@ public LabelToFloatConverterPipelineStep(Output output)
1399713997
namespace Legacy.Transforms
1399813998
{
1399913999

14000-
public sealed partial class LdaTransformColumn : OneToOneColumn<LdaTransformColumn>, IOneToOneColumn
14000+
public sealed partial class LatentDirichletAllocationTransformerColumn : OneToOneColumn<LatentDirichletAllocationTransformerColumn>, IOneToOneColumn
1400114001
{
1400214002
/// <summary>
14003-
/// The number of topics in the LDA
14003+
/// The number of topics
1400414004
/// </summary>
1400514005
public int? NumTopic { get; set; }
1400614006

@@ -14099,26 +14099,26 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo
1409914099

1410014100
public void AddColumn(string inputColumn)
1410114101
{
14102-
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>() : new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>(Column);
14103-
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>.Create(inputColumn));
14102+
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>(Column);
14103+
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>.Create(inputColumn));
1410414104
Column = list.ToArray();
1410514105
}
1410614106

1410714107
public void AddColumn(string outputColumn, string inputColumn)
1410814108
{
14109-
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>() : new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>(Column);
14110-
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>.Create(outputColumn, inputColumn));
14109+
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>(Column);
14110+
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>.Create(outputColumn, inputColumn));
1411114111
Column = list.ToArray();
1411214112
}
1411314113

1411414114

1411514115
/// <summary>
1411614116
/// New column definition(s) (optional form: name:srcs)
1411714117
/// </summary>
14118-
public LdaTransformColumn[] Column { get; set; }
14118+
public LatentDirichletAllocationTransformerColumn[] Column { get; set; }
1411914119

1412014120
/// <summary>
14121-
/// The number of topics in the LDA
14121+
/// The number of topics
1412214122
/// </summary>
1412314123
[TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})]
1412414124
public int NumTopic { get; set; } = 100;
@@ -14153,14 +14153,14 @@ public void AddColumn(string outputColumn, string inputColumn)
1415314153
public int LikelihoodInterval { get; set; } = 5;
1415414154

1415514155
/// <summary>
14156-
/// The threshold of maximum count of tokens per doc
14156+
/// The number of training threads. Default value depends on number of logical processors.
1415714157
/// </summary>
14158-
public int NumMaxDocToken { get; set; } = 512;
14158+
public int NumThreads { get; set; }
1415914159

1416014160
/// <summary>
14161-
/// The number of training threads. Default value depends on number of logical processors.
14161+
/// The threshold of maximum count of tokens per doc
1416214162
/// </summary>
14163-
public int? NumThreads { get; set; }
14163+
public int NumMaxDocToken { get; set; } = 512;
1416414164

1416514165
/// <summary>
1416614166
/// The number of words to summarize the topic

src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.EntryPoints;
77
using Microsoft.ML.Transforms.Categorical;
88
using Microsoft.ML.Transforms.Text;
9+
using System.Linq;
910

1011
[assembly: LoadableClass(typeof(void), typeof(TextAnalytics), null, typeof(SignatureEntryPointModule), "TextAnalytics")]
1112

@@ -118,18 +119,21 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, T
118119
}
119120

120121
[TlcModule.EntryPoint(Name = "Transforms.LightLda",
121-
Desc = LdaTransform.Summary,
122-
UserName = LdaTransform.UserName,
123-
ShortName = LdaTransform.ShortName,
122+
Desc = LatentDirichletAllocationTransformer.Summary,
123+
UserName = LatentDirichletAllocationTransformer.UserName,
124+
ShortName = LatentDirichletAllocationTransformer.ShortName,
124125
XmlInclude = new[] { @"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/member[@name=""LightLDA""]/*' />",
125126
@"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/example[@name=""LightLDA""]/*' />" })]
126-
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input)
127+
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input)
127128
{
128129
Contracts.CheckValue(env, nameof(env));
129130
env.CheckValue(input, nameof(input));
130131

131132
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
132-
var view = new LdaTransform(h, input, input.Data);
133+
var cols = input.Column.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray();
134+
var est = new LatentDirichletAllocationEstimator(h, cols);
135+
var view = est.Fit(input.Data).Transform(input.Data);
136+
133137
return new CommonOutputs.TransformOutput()
134138
{
135139
Model = new TransformModel(h, view, input.Data),
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.StaticPipe.Runtime;
9+
using Microsoft.ML.Transforms.Text;
10+
using System;
11+
using System.Collections.Generic;
12+
13+
namespace Microsoft.ML.StaticPipe
14+
{
15+
/// <summary>
16+
/// Information on the result of fitting a LDA transform.
17+
/// </summary>
18+
public sealed class LdaFitResult
19+
{
20+
/// <summary>
21+
/// For user defined delegates that accept instances of the containing type.
22+
/// </summary>
23+
/// <param name="result"></param>
24+
public delegate void OnFit(LdaFitResult result);
25+
26+
public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;
27+
public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
28+
{
29+
LdaTopicSummary = ldaTopicSummary;
30+
}
31+
}
32+
33+
public static class LdaStaticExtensions
34+
{
35+
private struct Config
36+
{
37+
public readonly int NumTopic;
38+
public readonly Single AlphaSum;
39+
public readonly Single Beta;
40+
public readonly int MHStep;
41+
public readonly int NumIter;
42+
public readonly int LikelihoodInterval;
43+
public readonly int NumThread;
44+
public readonly int NumMaxDocToken;
45+
public readonly int NumSummaryTermPerTopic;
46+
public readonly int NumBurninIter;
47+
public readonly bool ResetRandomGenerator;
48+
49+
public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;
50+
51+
public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval,
52+
int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator,
53+
Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
54+
{
55+
NumTopic = numTopic;
56+
AlphaSum = alphaSum;
57+
Beta = beta;
58+
MHStep = mhStep;
59+
NumIter = numIter;
60+
LikelihoodInterval = likelihoodInterval;
61+
NumThread = numThread;
62+
NumMaxDocToken = numMaxDocToken;
63+
NumSummaryTermPerTopic = numSummaryTermPerTopic;
64+
NumBurninIter = numBurninIter;
65+
ResetRandomGenerator = resetRandomGenerator;
66+
67+
OnFit = onFit;
68+
}
69+
}
70+
71+
private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
72+
{
73+
if (onFit == null)
74+
return null;
75+
76+
return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary));
77+
}
78+
79+
private interface ILdaCol
80+
{
81+
PipelineColumn Input { get; }
82+
Config Config { get; }
83+
}
84+
85+
private sealed class ImplVector : Vector<float>, ILdaCol
86+
{
87+
public PipelineColumn Input { get; }
88+
public Config Config { get; }
89+
public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
90+
{
91+
Input = input;
92+
Config = config;
93+
}
94+
}
95+
96+
private sealed class Rec : EstimatorReconciler
97+
{
98+
public static readonly Rec Inst = new Rec();
99+
100+
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
101+
PipelineColumn[] toOutput,
102+
IReadOnlyDictionary<PipelineColumn, string> inputNames,
103+
IReadOnlyDictionary<PipelineColumn, string> outputNames,
104+
IReadOnlyCollection<string> usedNames)
105+
{
106+
var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length];
107+
Action<LatentDirichletAllocationTransformer> onFit = null;
108+
for (int i = 0; i < toOutput.Length; ++i)
109+
{
110+
var tcol = (ILdaCol)toOutput[i];
111+
112+
infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]],
113+
tcol.Config.NumTopic,
114+
tcol.Config.AlphaSum,
115+
tcol.Config.Beta,
116+
tcol.Config.MHStep,
117+
tcol.Config.NumIter,
118+
tcol.Config.LikelihoodInterval,
119+
tcol.Config.NumThread,
120+
tcol.Config.NumMaxDocToken,
121+
tcol.Config.NumSummaryTermPerTopic,
122+
tcol.Config.NumBurninIter,
123+
tcol.Config.ResetRandomGenerator);
124+
125+
if (tcol.Config.OnFit != null)
126+
{
127+
int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call.
128+
onFit += tt => tcol.Config.OnFit(tt.GetLdaDetails(ii));
129+
}
130+
}
131+
132+
var est = new LatentDirichletAllocationEstimator(env, infos);
133+
if (onFit == null)
134+
return est;
135+
136+
return est.WithOnFitDelegate(onFit);
137+
}
138+
}
139+
140+
/// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
141+
/// <param name="input">A vector of floats representing the document.</param>
142+
/// <param name="numTopic">The number of topics.</param>
143+
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
144+
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
145+
/// <param name="mhstep">Number of Metropolis Hasting step.</param>
146+
/// <param name="numIterations">Number of iterations.</param>
147+
/// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
148+
/// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param>
149+
/// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
150+
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
151+
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
152+
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
153+
/// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param>
154+
public static Vector<float> ToLdaTopicVector(this Vector<float> input,
155+
int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
156+
Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
157+
Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
158+
int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
159+
int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations,
160+
int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
161+
int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
162+
int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
163+
int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
164+
int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
165+
bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
166+
LdaFitResult.OnFit onFit = null)
167+
{
168+
Contracts.CheckValue(input, nameof(input));
169+
return new ImplVector(input,
170+
new Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic,
171+
numBurninIterations, resetRandomGenerator, Wrap(onFit)));
172+
}
173+
}
174+
}

0 commit comments

Comments
 (0)