Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/Microsoft.ML.LightGBM/LightGbmArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface ISupportBoosterParameterFactory : IComponentFactory<IBoosterPar
}
public interface IBoosterParameter
{
void UpdateParameters(Dictionary<string, string> res);
void UpdateParameters(Dictionary<string, object> res);
}

/// <summary>
Expand All @@ -54,7 +54,7 @@ protected BoosterParameter(TArgs args)
/// <summary>
/// Update the parameters by specific Booster, will update parameters into "res" directly.
/// </summary>
public virtual void UpdateParameters(Dictionary<string, string> res)
public virtual void UpdateParameters(Dictionary<string, object> res)
{
FieldInfo[] fields = Args.GetType().GetFields();
foreach (var field in fields)
Expand Down Expand Up @@ -163,7 +163,7 @@ public TreeBooster(Arguments args)
Contracts.CheckUserArg(Args.ScalePosWeight > 0 && Args.ScalePosWeight <= 1, nameof(Args.ScalePosWeight), "must be in (0,1].");
}

public override void UpdateParameters(Dictionary<string, string> res)
public override void UpdateParameters(Dictionary<string, object> res)
{
base.UpdateParameters(res);
res["boosting_type"] = Name;
Expand Down Expand Up @@ -207,7 +207,7 @@ public DartBooster(Arguments args)
Contracts.CheckUserArg(Args.SkipDrop >= 0 && Args.SkipDrop < 1, nameof(Args.SkipDrop), "must be in [0,1).");
}

public override void UpdateParameters(Dictionary<string, string> res)
public override void UpdateParameters(Dictionary<string, object> res)
{
base.UpdateParameters(res);
res["boosting_type"] = Name;
Expand Down Expand Up @@ -244,7 +244,7 @@ public GossBooster(Arguments args)
Contracts.Check(Args.TopRate + Args.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1.");
}

public override void UpdateParameters(Dictionary<string, string> res)
public override void UpdateParameters(Dictionary<string, object> res)
{
base.UpdateParameters(res);
res["boosting_type"] = Name;
Expand Down Expand Up @@ -355,11 +355,11 @@ public enum EvalMetricType
[Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
public ISupportParallel ParallelTrainer = new SingleTrainerFactory();

internal Dictionary<string, string> ToDictionary(IHost host)
internal Dictionary<string, object> ToDictionary(IHost host)
{
Contracts.CheckValue(host, nameof(host));
Contracts.CheckUserArg(MaxBin > 0, nameof(MaxBin), "must be > 0.");
Dictionary<string, string> res = new Dictionary<string, string>();
Dictionary<string, object> res = new Dictionary<string, object>();

var boosterParams = Booster.CreateComponent(host);
boosterParams.UpdateParameters(res);
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Globalization;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand Down Expand Up @@ -130,9 +131,9 @@ protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float
protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg=false)
{
base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
int numLeaves = int.Parse(Options["num_leaves"]);
int numLeaves = (int)Options["num_leaves"];
int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass);
Options["min_data_per_leaf"] = minDataPerLeaf.ToString();
Options["min_data_per_leaf"] = minDataPerLeaf;
if (!hiddenMsg)
{
if (!Args.LearningRate.HasValue)
Expand All @@ -149,7 +150,7 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role
Host.AssertValue(ch);
ch.Assert(PredictionKind == PredictionKind.MultiClassClassification);
ch.Assert(_numClass > 1);
Options["num_class"] = _numClass.ToString();
Options["num_class"] = _numClass;
bool useSoftmax = false;

if (Args.UseSoftmax.HasValue)
Expand Down
28 changes: 15 additions & 13 deletions src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.FastTree.Internal;

namespace Microsoft.ML.Runtime.LightGBM
{
Expand Down Expand Up @@ -49,7 +45,13 @@ private sealed class CategoricalMetaData

protected readonly IHost Host;
protected readonly LightGbmArguments Args;
protected readonly Dictionary<string, string> Options;

/// <summary>
/// Stores arguments as objects to convert them to an invariant string type in the end so that
/// the code is culture agnostic. When retrieving a key value from this dictionary as a string,
/// please convert it to an invariant string via string.Format(CultureInfo.InvariantCulture, "{0}", Options[key]).
/// </summary>
protected readonly Dictionary<string, object> Options;
protected readonly IParallel ParallelTraining;

// Store _featureCount and _trainedEnsemble to construct predictor.
Expand Down Expand Up @@ -159,9 +161,9 @@ protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCat
double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
int numLeaves = Args.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats);
int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1);
Options["learning_rate"] = learningRate.ToString();
Options["num_leaves"] = numLeaves.ToString();
Options["min_data_per_leaf"] = minDataPerLeaf.ToString();
Options["learning_rate"] = learningRate;
Options["num_leaves"] = numLeaves;
Options["min_data_per_leaf"] = minDataPerLeaf;
if (!hiddenMsg)
{
if (!Args.LearningRate.HasValue)
Expand Down Expand Up @@ -192,7 +194,7 @@ private static List<int> GetCategoricalBoundires(int[] categoricalFeatures, int
{
if (j < categoricalFeatures.Length && curFidx == categoricalFeatures[j])
{
if (curFidx > catBoundaries.Last())
if (curFidx > catBoundaries[catBoundaries.Count - 1])
catBoundaries.Add(curFidx);
if (categoricalFeatures[j + 1] - categoricalFeatures[j] >= 0)
{
Expand All @@ -219,7 +221,7 @@ private static List<int> GetCategoricalBoundires(int[] categoricalFeatures, int
private static List<string> ConstructCategoricalFeatureMetaData(int[] categoricalFeatures, int rawNumCol, ref CategoricalMetaData catMetaData)
{
List<int> catBoundaries = GetCategoricalBoundires(categoricalFeatures, rawNumCol);
catMetaData.NumCol = catBoundaries.Count() - 1;
Copy link
Contributor

@TomFinley TomFinley Jul 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Count() [](start = 47, length = 7)

Ah very good. Yes that is terrible. There really should be a warning or something on Count()... :P #Resolved

catMetaData.NumCol = catBoundaries.Count - 1;
catMetaData.CategoricalBoudaries = catBoundaries.ToArray();
catMetaData.IsCategoricalFeature = new bool[catMetaData.NumCol];
catMetaData.OnehotIndices = new int[rawNumCol];
Expand Down Expand Up @@ -279,7 +281,7 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t
{
var catIndices = ConstructCategoricalFeatureMetaData(categoricalFeatures, rawNumCol, ref catMetaData);
// Set categorical features
Options["categorical_feature"] = String.Join(",", catIndices);
Options["categorical_feature"] = string.Join(",", catIndices);
}
return catMetaData;
}
Expand Down Expand Up @@ -527,13 +529,13 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
++nhot;
var prob = rand.NextSingle();
if (prob < 1.0f / nhot)
values[values.Count() - 1] = fv;
values[values.Count - 1] = fv;
}
lastIdx = newColIdx;
}
indices = featureIndices.ToArray();
featureValues = values.ToArray();
cnt = featureIndices.Count();
cnt = featureIndices.Count;
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal sealed class Booster : IDisposable
public IntPtr Handle { get; private set; }
public int BestIteration { get; set; }

public Booster(Dictionary<string, string> parameters, Dataset trainset, Dataset validset = null)
public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
{
var param = LightGbmInterfaceUtils.JoinParameters(parameters);
var handle = IntPtr.Zero;
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.ML.LightGBM/WrappedLightGbmInterface.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.InteropServices;
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.InteropServices;

namespace Microsoft.ML.Runtime.LightGBM
{
Expand Down Expand Up @@ -199,13 +200,13 @@ public static void Check(int res)
/// <summary>
/// Join the parameters to key=value format.
/// </summary>
/// <summary>
/// Join the parameters to key=value format.
/// </summary>
public static string JoinParameters(Dictionary<string, object> parameters)
{
    // A missing dictionary contributes no parameter text at all.
    if (parameters == null)
        return "";
    // Render each entry as "key=value", forcing the invariant culture so numeric
    // values (e.g. learning rates) never pick up locale-specific separators.
    var pairs = new List<string>();
    foreach (var kvp in parameters)
        pairs.Add(string.Format(CultureInfo.InvariantCulture, "{0}={1}", kvp.Key, kvp.Value));
    return string.Join(" ", pairs);
}

Expand Down
9 changes: 3 additions & 6 deletions src/Microsoft.ML.LightGBM/WrappedLightGbmTraining.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ internal static class WrappedLightGbmTraining
/// Train and return a booster.
/// </summary>
public static Booster Train(IChannel ch, IProgressChannel pch,
Dictionary<string, string> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
Dictionary<string, object> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
bool verboseEval = true, int earlyStoppingRound = 0)
{
// create Booster.
Expand All @@ -33,12 +33,9 @@ public static Booster Train(IChannel ch, IProgressChannel pch,
double bestScore = double.MaxValue;
double factorToSmallerBetter = 1.0;

if (earlyStoppingRound > 0 && (parameters["metric"] == "auc"
|| parameters["metric"] == "ndcg"
|| parameters["metric"] == "map"))
{
var metric = (string)parameters["metric"];
if (earlyStoppingRound > 0 && (metric == "auc" || metric == "ndcg" || metric == "map"))
factorToSmallerBetter = -1.0;
}

const int evalFreq = 50;

Expand Down