-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Convert LdaTransform to IEstimator/ITransformer API #1410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
acd8964
cd2f20c
c289ed1
8dfb527
bcb3b0d
1d39408
9ff60f7
ad36d2f
b0e0375
b0422e4
7bc6e2b
e42c5e4
c099d4a
a1d14ed
3f39a04
57cd1c5
d4a4283
c7fb50a
d7660ca
e0d501b
c91afbb
4238fa1
34bb2e9
8b70ab1
b3c1284
b6e4028
5397de5
edd60af
b073038
49da3ee
0724290
5073baa
d1481f8
65125d4
b869d7f
62955a8
40333a7
850856b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…next iteration (i.e. make training a private static method; removed _types as field)
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -641,7 +641,10 @@ public override Schema.Column[] GetOutputColumns() | |
| { | ||
| var result = new Schema.Column[_parent.ColumnPairs.Length]; | ||
| for (int i = 0; i < _parent.ColumnPairs.Length; i++) | ||
| result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._types[i], null); | ||
| { | ||
| var info = _parent._columns[i]; | ||
| result[i] = new Schema.Column(_parent.ColumnPairs[i].output, new VectorType(NumberType.Float, info.NumTopic), null); | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
|
|
@@ -684,9 +687,8 @@ private static VersionInfo GetVersionInfo() | |
| loaderAssemblyName: typeof(LdaTransformer).Assembly.FullName); | ||
| } | ||
|
|
||
| private readonly ColumnInfo[] _exes; | ||
| private readonly ColumnInfo[] _columns; | ||
| private readonly LdaState[] _ldas; | ||
| private readonly ColumnType[] _types; | ||
|
|
||
| private const string RegistrationName = "LightLda"; | ||
| private const string WordTopicModelFilename = "word_topic_summary.txt"; | ||
|
|
@@ -703,14 +705,9 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum | |
| internal LdaTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransformer)), GetColumnPairs(columns)) | ||
| { | ||
| _exes = columns; | ||
| _types = new ColumnType[columns.Length]; | ||
| _columns = columns; | ||
| _ldas = new LdaState[columns.Length]; | ||
|
|
||
| for (int i = 0; i < columns.Length; i++) | ||
| { | ||
| _types[i] = new VectorType(NumberType.Float, _exes[i].NumTopic); | ||
| } | ||
| using (var ch = Host.Start("Train")) | ||
| { | ||
| Train(ch, input, _ldas); | ||
|
||
|
|
@@ -728,14 +725,12 @@ private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) | |
|
|
||
| // Note: columnsLength would be just one in most cases. | ||
| var columnsLength = ColumnPairs.Length; | ||
| _exes = new ColumnInfo[columnsLength]; | ||
| _columns = new ColumnInfo[columnsLength]; | ||
| _ldas = new LdaState[columnsLength]; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
are these guys reentrant? #Closed
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you kindly clarify what you mean by 're-entrant'? In any case, I have made LdaState internal, so hopefully it's not an issue anymore. In reply to: 232437876 [](ancestors = 232437876)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Re-entrant means that any methods can be called in parallel from multiple threads. In reply to: 232470000 [](ancestors = 232470000,232437876)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am going by the existing comments here. We have locks in place to make it thread-safe. In reply to: 233585350 [](ancestors = 233585350,232470000,232437876) |
||
| _types = new ColumnType[columnsLength]; | ||
| for (int i = 0; i < _ldas.Length; i++) | ||
| { | ||
| _ldas[i] = new LdaState(Host, ctx); | ||
| _exes[i] = _ldas[i].InfoEx; | ||
| _types[i] = new VectorType(NumberType.Float, _ldas[i].InfoEx.NumTopic); | ||
| _columns[i] = _ldas[i].InfoEx; | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -861,13 +856,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
| Host.AssertValue(ch); | ||
| ch.AssertValue(trainingData); | ||
| ch.AssertValue(states); | ||
| ch.Assert(states.Length == _exes.Length); | ||
| ch.Assert(states.Length == _columns.Length); | ||
|
|
||
| bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; | ||
| int[] numVocabs = new int[_exes.Length]; | ||
| int[] srcCols = new int[_exes.Length]; | ||
| int[] numVocabs = new int[_columns.Length]; | ||
| int[] srcCols = new int[_columns.Length]; | ||
|
|
||
| for (int i = 0; i < _exes.Length; i++) | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| if (!trainingData.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) | ||
| throw Host.ExceptSchemaMismatch(nameof(trainingData), "input", ColumnPairs[i].input); | ||
|
|
@@ -880,13 +875,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
| //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, | ||
| //one for the pre-calc memory, one for feedin data really | ||
| //another solution can be prepare these two value externally and put them in the beginning of the input file. | ||
| long[] corpusSize = new long[_exes.Length]; | ||
| int[] numDocArray = new int[_exes.Length]; | ||
| long[] corpusSize = new long[_columns.Length]; | ||
| int[] numDocArray = new int[_columns.Length]; | ||
|
|
||
| using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) | ||
| { | ||
| var getters = new ValueGetter<VBuffer<Double>>[_exes.Length]; | ||
| for (int i = 0; i < _exes.Length; i++) | ||
| var getters = new ValueGetter<VBuffer<Double>>[_columns.Length]; | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| corpusSize[i] = 0; | ||
| numDocArray[i] = 0; | ||
|
|
@@ -898,7 +893,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
| while (cursor.MoveNext()) | ||
| { | ||
| ++rowCount; | ||
| for (int i = 0; i < _exes.Length; i++) | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| int docSize = 0; | ||
| getters[i](ref src); | ||
|
|
@@ -914,7 +909,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
| break; | ||
| } | ||
|
|
||
| if (docSize >= _exes[i].NumMaxDocToken - termFreq) | ||
| if (docSize >= _columns[i].NumMaxDocToken - termFreq) | ||
| break; //control the document length | ||
|
|
||
| //if legal then add the term | ||
|
|
@@ -934,7 +929,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
| } | ||
| } | ||
|
|
||
| for (int i = 0; i < _exes.Length; ++i) | ||
| for (int i = 0; i < _columns.Length; ++i) | ||
| { | ||
| if (numDocArray[i] != rowCount) | ||
| { | ||
|
|
@@ -945,9 +940,9 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
| } | ||
|
|
||
| // Initialize all LDA states | ||
| for (int i = 0; i < _exes.Length; i++) | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| var state = new LdaState(Host, _exes[i], numVocabs[i]); | ||
| var state = new LdaState(Host, _columns[i], numVocabs[i]); | ||
| if (numDocArray[i] == 0 || corpusSize[i] == 0) | ||
| throw ch.Except("The specified documents are all empty in column '{0}'.", ColumnPairs[i].input); | ||
|
|
||
|
|
@@ -957,11 +952,11 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
|
|
||
| using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) | ||
| { | ||
| int[] docSizeCheck = new int[_exes.Length]; | ||
| int[] docSizeCheck = new int[_columns.Length]; | ||
| // This could be optimized so that if multiple trainers consume the same column, it is | ||
| // fed into the train method once. | ||
| var getters = new ValueGetter<VBuffer<Double>>[_exes.Length]; | ||
| for (int i = 0; i < _exes.Length; i++) | ||
| var getters = new ValueGetter<VBuffer<Double>>[_columns.Length]; | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| docSizeCheck[i] = 0; | ||
| getters[i] = RowCursorUtils.GetVecGetterAs<Double>(NumberType.R8, cursor, srcCols[i]); | ||
|
|
@@ -971,13 +966,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) | |
|
|
||
| while (cursor.MoveNext()) | ||
| { | ||
| for (int i = 0; i < _exes.Length; i++) | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| getters[i](ref src); | ||
| docSizeCheck[i] += states[i].FeedTrain(Host, in src); | ||
| } | ||
| } | ||
| for (int i = 0; i < _exes.Length; i++) | ||
| for (int i = 0; i < _columns.Length; i++) | ||
| { | ||
| Host.Assert(corpusSize[i] == docSizeCheck[i]); | ||
| states[i].CompleteTrain(); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this equal to `info.Input`? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is equal to info.Output. If the output column name is not specified, then info.Output is the same as info.Input.
Is there any concern here?
In reply to: 233584111 [](ancestors = 233584111)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, `info.Output` is shorter, so I'd rather use it.
In reply to: 233740244 [](ancestors = 233740244,233584111)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keeping this as-is
We are using SaveColumns from OneToOneTransformer base. So when loading up models, we should be using _parent.ColumnPairs[i].output (info.Output may be null in those cases)
In reply to: 234857190 [](ancestors = 234857190,233740244,233584111)