1
+ // Licensed to the .NET Foundation under one or more agreements.
2
+ // The .NET Foundation licenses this file to you under the MIT license.
3
+ // See the LICENSE file in the project root for more information.
4
+
5
+ using Microsoft . ML . Core . Data ;
6
+ using Microsoft . ML . Runtime ;
7
+ using Microsoft . ML . Runtime . Data ;
8
+ using Microsoft . ML . StaticPipe . Runtime ;
9
+ using Microsoft . ML . Transforms . Text ;
10
+ using System ;
11
+ using System . Collections . Generic ;
12
+
13
+ namespace Microsoft . ML . StaticPipe
14
+ {
15
/// <summary>
/// Carries the outcome of fitting an LDA transform so that it can be handed to user callbacks.
/// </summary>
public sealed class LdaFitResult
{
    /// <summary>
    /// Signature for user-supplied callbacks that receive an <see cref="LdaFitResult"/>.
    /// </summary>
    /// <param name="result">The fit result passed to the callback.</param>
    public delegate void OnFit(LdaFitResult result);

    /// <summary>
    /// The LDA summary this result was constructed with.
    /// </summary>
    public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;

    /// <summary>
    /// Creates a fit result that exposes the given summary.
    /// </summary>
    /// <param name="ldaTopicSummary">The summary stored in <see cref="LdaTopicSummary"/>.</param>
    public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
    {
        LdaTopicSummary = ldaTopicSummary;
    }
}
32
+
33
/// <summary>
/// Static-pipeline extension methods for applying the LDA transform to a column.
/// </summary>
public static class LdaStaticExtensions
{
    /// <summary>
    /// Read-only bundle of the per-column LDA settings together with the optional fit callback.
    /// </summary>
    private struct Config
    {
        public readonly int NumTopic;
        public readonly float AlphaSum;
        public readonly float Beta;
        public readonly int MHStep;
        public readonly int NumIter;
        public readonly int LikelihoodInterval;
        public readonly int NumThread;
        public readonly int NumMaxDocToken;
        public readonly int NumSummaryTermPerTopic;
        public readonly int NumBurninIter;
        public readonly bool ResetRandomGenerator;

        // Null when the caller did not ask to be notified after fitting.
        public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;

        /// <summary>
        /// Stores every argument in the field of the corresponding name.
        /// </summary>
        public Config(
            int numTopic,
            float alphaSum,
            float beta,
            int mhStep,
            int numIter,
            int likelihoodInterval,
            int numThread,
            int numMaxDocToken,
            int numSummaryTermPerTopic,
            int numBurninIter,
            bool resetRandomGenerator,
            Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
        {
            NumTopic = numTopic;
            AlphaSum = alphaSum;
            Beta = beta;
            MHStep = mhStep;
            NumIter = numIter;
            LikelihoodInterval = likelihoodInterval;
            NumThread = numThread;
            NumMaxDocToken = numMaxDocToken;
            NumSummaryTermPerTopic = numSummaryTermPerTopic;
            NumBurninIter = numBurninIter;
            ResetRandomGenerator = resetRandomGenerator;
            OnFit = onFit;
        }
    }

    /// <summary>
    /// Adapts a user-facing <see cref="LdaFitResult.OnFit"/> callback into an action over the raw
    /// LDA summary. A null callback stays null so callers can tell that nothing was requested.
    /// </summary>
    private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
    {
        if (onFit != null)
            return summary => onFit(new LdaFitResult(summary));

        return null;
    }

    /// <summary>
    /// Lets the reconciler read an output column's input and settings after casting it.
    /// </summary>
    private interface ILdaCol
    {
        PipelineColumn Input { get; }
        Config Config { get; }
    }

    /// <summary>
    /// The column returned by <see cref="ToLdaTopicVector"/>; it records the column it was derived
    /// from and the settings to use for it.
    /// </summary>
    private sealed class ImplVector : Vector<float>, ILdaCol
    {
        public PipelineColumn Input { get; }
        public Config Config { get; }

        public ImplVector(PipelineColumn input, Config config)
            : base(Rec.Inst, input)
        {
            Input = input;
            Config = config;
        }
    }

    /// <summary>
    /// Reconciler that builds one <see cref="LatentDirichletAllocationEstimator"/> covering all
    /// of the requested LDA output columns.
    /// </summary>
    private sealed class Rec : EstimatorReconciler
    {
        public static readonly Rec Inst = new Rec();

        /// <summary>
        /// Creates one column description per entry of <paramref name="toOutput"/>, constructs the
        /// estimator from them, and attaches the combined fit callbacks if any column supplied one.
        /// </summary>
        public override IEstimator<ITransformer> Reconcile(
            IHostEnvironment env,
            PipelineColumn[] toOutput,
            IReadOnlyDictionary<PipelineColumn, string> inputNames,
            IReadOnlyDictionary<PipelineColumn, string> outputNames,
            IReadOnlyCollection<string> usedNames)
        {
            var columns = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length];
            Action<LatentDirichletAllocationTransformer> fitCallbacks = null;

            for (int index = 0; index < toOutput.Length; ++index)
            {
                var ldaCol = (ILdaCol)toOutput[index];
                var settings = ldaCol.Config;

                columns[index] = new LatentDirichletAllocationTransformer.ColumnInfo(
                    inputNames[ldaCol.Input],
                    outputNames[toOutput[index]],
                    settings.NumTopic,
                    settings.AlphaSum,
                    settings.Beta,
                    settings.MHStep,
                    settings.NumIter,
                    settings.LikelihoodInterval,
                    settings.NumThread,
                    settings.NumMaxDocToken,
                    settings.NumSummaryTermPerTopic,
                    settings.NumBurninIter,
                    settings.ResetRandomGenerator);

                var columnCallback = settings.OnFit;
                if (columnCallback == null)
                    continue;

                // Copy the counter: the lambda runs after the loop has finished, at which point
                // the counter itself would equal toOutput.Length.
                int columnIndex = index;
                fitCallbacks += transformer => columnCallback(transformer.GetLdaDetails(columnIndex));
            }

            var estimator = new LatentDirichletAllocationEstimator(env, columns);
            if (fitCallbacks != null)
                return estimator.WithOnFitDelegate(fitCallbacks);

            return estimator;
        }
    }

    /// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
    /// <param name="input">A vector of floats representing the document.</param>
    /// <param name="numTopic">The number of topics.</param>
    /// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
    /// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
    /// <param name="mhstep">Number of Metropolis-Hastings steps.</param>
    /// <param name="numIterations">Number of iterations.</param>
    /// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
    /// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param>
    /// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
    /// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
    /// <param name="numBurninIterations">The number of burn-in iterations.</param>
    /// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
    /// <param name="onFit">Optional callback invoked after fitting with an <see cref="LdaFitResult"/>
    /// wrapping the LDA details obtained for this column.</param>
    /// <returns>The output column produced by the LDA transform.</returns>
    public static Vector<float> ToLdaTopicVector(this Vector<float> input,
        int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
        Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
        Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
        int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
        int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations,
        int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
        int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
        int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
        int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
        int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
        bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
        LdaFitResult.OnFit onFit = null)
    {
        Contracts.CheckValue(input, nameof(input));

        var settings = new Config(
            numTopic,
            alphaSum,
            beta,
            mhstep,
            numIterations,
            likelihoodInterval,
            numThreads,
            numMaxDocToken,
            numSummaryTermPerTopic,
            numBurninIterations,
            resetRandomGenerator,
            Wrap(onFit));

        return new ImplVector(input, settings);
    }
}
174
+ }
0 commit comments