Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions Algorithms.Tests/MachineLearning/DecisionTreeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
using NUnit.Framework;
using Algorithms.MachineLearning;
using System;

namespace Algorithms.Tests.MachineLearning;

[TestFixture]
public class DecisionTreeTests
{
    [Test]
    public void Fit_ThrowsOnEmptyInput()
    {
        var tree = new DecisionTree();
        Assert.Throws<ArgumentException>(() => tree.Fit(Array.Empty<int[]>(), Array.Empty<int>()));
    }

    [Test]
    public void Fit_ThrowsOnMismatchedLabels()
    {
        var tree = new DecisionTree();
        int[][] X = { new[] { 1, 2 } };
        int[] y = { 1, 0 };
        Assert.Throws<ArgumentException>(() => tree.Fit(X, y));
    }

    [Test]
    public void Predict_ThrowsIfNotTrained()
    {
        var tree = new DecisionTree();
        Assert.Throws<InvalidOperationException>(() => tree.Predict(new[] { 1, 2 }));
    }

    [Test]
    public void Predict_ThrowsOnFeatureMismatch()
    {
        var tree = new DecisionTree();
        int[][] X = { new[] { 1, 2 } };
        int[] y = { 1 };
        tree.Fit(X, y);
        Assert.Throws<ArgumentException>(() => tree.Predict(new[] { 1 }));
    }

    [Test]
    public void FitAndPredict_WorksOnSimpleData()
    {
        // Simple OR logic
        int[][] X =
        {
            new[] { 0, 0 },
            new[] { 0, 1 },
            new[] { 1, 0 },
            new[] { 1, 1 }
        };
        int[] y = { 0, 1, 1, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        Assert.That(tree.Predict(new[] { 0, 0 }), Is.EqualTo(0));
        Assert.That(tree.Predict(new[] { 0, 1 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 1, 0 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 1, 1 }), Is.EqualTo(1));
    }

    [Test]
    public void FeatureCount_ReturnsCorrectValue()
    {
        var tree = new DecisionTree();
        int[][] X = { new[] { 1, 2, 3 } };
        int[] y = { 1 };
        tree.Fit(X, y);
        Assert.That(tree.FeatureCount, Is.EqualTo(3));
    }

    [Test]
    public void Predict_FallbacksToZeroForUnseenValue()
    {
        int[][] X = { new[] { 0, 0 }, new[] { 1, 1 } };
        int[] y = { 0, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        // Value 2 is unseen in feature 0, so Traverse falls back to 0.
        Assert.That(tree.Predict(new[] { 2, 0 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft()
    {
        int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 } };
        int[] y = { 1, 0, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        // Value 3 was never seen during training: the root split on feature 0
        // has no matching child, so prediction falls back to 0.
        Assert.That(tree.Predict(new[] { 3 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft_MultipleLabels()
    {
        int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 }, new[] { 3 } };
        int[] y = { 1, 0, 1, 0 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        // Value 4 is unseen at the root split, so prediction falls back to 0.
        Assert.That(tree.Predict(new[] { 4 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsZero()
    {
        int[][] X = { new[] { 0 }, new[] { 1 } };
        int[] y = { 0, 0 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        Assert.That(tree.Predict(new[] { 0 }), Is.EqualTo(0));
        Assert.That(tree.Predict(new[] { 1 }), Is.EqualTo(0));
    }

    [Test]
    public void Entropy_ReturnsZero_WhenAllZeroOrAllOne()
    {
        var method = typeof(DecisionTree).GetMethod("Entropy", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

        // Wrap the single int[] argument in an explicit object[]: passing an
        // inferred int[][] via array covariance to Invoke's object?[] parameter
        // is a potentially unsafe conversion (flagged by static analysis).
        Assert.That(method!.Invoke(null, new object[] { new[] { 0, 0, 0 } }), Is.EqualTo(0d));
        Assert.That(method!.Invoke(null, new object[] { new[] { 1, 1, 1 } }), Is.EqualTo(0d));
    }

    [Test]
    public void MostCommon_ReturnsCorrectLabel()
    {
        var method = typeof(DecisionTree).GetMethod("MostCommon", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        Assert.That(method!.Invoke(null, new object[] { new[] { 1, 0, 1, 1, 0, 0, 0 } }), Is.EqualTo(0));
        Assert.That(method!.Invoke(null, new object[] { new[] { 1, 1, 1, 0 } }), Is.EqualTo(1));
    }

    [Test]
    public void Traverse_FallbacksToZero_WhenChildrenIsNull()
    {
        // Create a node with Children = null and Label = null
        var nodeType = typeof(DecisionTree).GetNestedType("Node", System.Reflection.BindingFlags.NonPublic);
        var node = Activator.CreateInstance(nodeType!);
        nodeType!.GetProperty("Feature")!.SetValue(node, 0);
        nodeType!.GetProperty("Label")!.SetValue(node, null);
        nodeType!.GetProperty("Children")!.SetValue(node, null);
        var method = typeof(DecisionTree).GetMethod("Traverse", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        Assert.That(method!.Invoke(null, new object[] { node!, new[] { 99 } }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsSame()
    {
        int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 } };
        int[] y = { 1, 1, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        Assert.That(tree.Predict(new[] { 0 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 1 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 2 }), Is.EqualTo(1));
    }

    [Test]
    public void Entropy_ReturnsZero_WhenEmptyLabels()
    {
        // Use reflection to call private static Entropy
        var method = typeof(DecisionTree).GetMethod("Entropy", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        Assert.That(method!.Invoke(null, new object[] { Array.Empty<int>() }), Is.EqualTo(0d));
    }

    [Test]
    public void BestFeature_SkipsEmptyIdxBranch()
    {
        // Feature 1 has value 2 which is never present, triggers idx.Length == 0 branch
        int[][] X = { new[] { 0, 1 }, new[] { 1, 1 } };
        int[] y = { 0, 1 };
        var method = typeof(DecisionTree).GetMethod("BestFeature", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        var features = new System.Collections.Generic.List<int> { 0, 1 };
        var resultObj = method!.Invoke(null, new object[] { X, y, features });
        Assert.That(resultObj, Is.Not.Null);
        Assert.That((int)resultObj!, Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_MostCommonLabelBranch_IsCovered()
    {
        int[][] X = { new[] { 0 }, new[] { 1 } };
        int[] y = { 0, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);
        Assert.That(tree.Predict(new[] { 2 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ContinueBranch_IsCovered()
    {
        int[][] X = { new[] { 0 }, new[] { 1 } };
        int[] y = { 0, 1 };
        var method = typeof(DecisionTree).GetMethod("BuildTree", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        var features = new System.Collections.Generic.List<int> { 0 };
        Assert.DoesNotThrow(() => method!.Invoke(null, new object[] { X, y, features }));
    }

    [Test]
    public void BestFeature_ContinueBranch_IsCovered()
    {
        int[][] X = { new[] { 0, 1 }, new[] { 1, 1 } };
        int[] y = { 0, 1 };
        var method = typeof(DecisionTree).GetMethod("BestFeature", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        var features = new System.Collections.Generic.List<int> { 0, 1 };
        Assert.DoesNotThrow(() => method!.Invoke(null, new object[] { X, y, features }));
    }
}
176 changes: 176 additions & 0 deletions Algorithms/MachineLearning/DecisionTree.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Algorithms.MachineLearning;

/// <summary>
/// Simple Decision Tree for binary classification using the ID3 algorithm.
/// Supports categorical features (int values).
/// </summary>
public class DecisionTree
{
    // Root of the trained tree; null until Fit has been called.
    private Node? root;

    /// <summary>
    /// Trains the decision tree using the ID3 algorithm.
    /// </summary>
    /// <param name="x">2D array of features (samples x features), categorical (int).</param>
    /// <param name="y">Array of labels (0 or 1).</param>
    /// <exception cref="ArgumentException">
    /// Thrown when the input is empty, when the number of samples and labels differ,
    /// or when the samples do not all have the same number of features.
    /// </exception>
    public void Fit(int[][] x, int[] y)
    {
        if (x.Length == 0 || x[0].Length == 0)
        {
            throw new ArgumentException("Input features cannot be empty.");
        }

        if (x.Length != y.Length)
        {
            throw new ArgumentException("Number of samples and labels must match.");
        }

        // A jagged input (rows of different lengths) would otherwise surface as an
        // IndexOutOfRangeException deep inside tree construction; fail fast instead.
        if (x.Any(row => row.Length != x[0].Length))
        {
            throw new ArgumentException("All samples must have the same number of features.");
        }

        root = BuildTree(x, y, Enumerable.Range(0, x[0].Length).ToList());
    }

    /// <summary>
    /// Predicts the class label (0 or 1) for a single sample.
    /// Unseen feature values fall back to label 0.
    /// </summary>
    /// <param name="x">Feature vector with the same length used in training.</param>
    /// <returns>The predicted label, 0 or 1.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the model has not been trained.</exception>
    /// <exception cref="ArgumentException">Thrown when the feature count does not match training data.</exception>
    public int Predict(int[] x)
    {
        if (root is null)
        {
            throw new InvalidOperationException("Model not trained.");
        }

        if (x.Length != FeatureCount)
        {
            throw new ArgumentException("Feature count mismatch.");
        }

        return Traverse(root, x);
    }

    /// <summary>
    /// Gets the number of features used in training. Returns 0 before training.
    /// </summary>
    public int FeatureCount => root?.FeatureCount ?? 0;

    /// <summary>
    /// Recursively builds the ID3 tree: stops on pure labels or exhausted features,
    /// otherwise splits on the feature with the highest information gain.
    /// </summary>
    private static Node BuildTree(int[][] x, int[] y, List<int> features)
    {
        // Pure subset: every remaining label is identical — emit a leaf.
        if (y.All(l => l == y[0]))
        {
            return new Node { Label = y[0], FeatureCount = x[0].Length };
        }

        // No features left to split on — emit a leaf with the majority label.
        if (features.Count == 0)
        {
            return new Node { Label = MostCommon(y), FeatureCount = x[0].Length };
        }

        int bestFeature = BestFeature(x, y, features);
        var node = new Node { Feature = bestFeature, FeatureCount = x[0].Length };
        var values = x.Select(row => row[bestFeature]).Distinct();
        node.Children = new();
        foreach (var v in values)
        {
            var idx = x.Select((row, i) => (row, i)).Where(t => t.row[bestFeature] == v).Select(t => t.i).ToArray();
            if (idx.Length == 0)
            {
                continue;
            }

            var subX = idx.Select(i => x[i]).ToArray();
            var subY = idx.Select(i => y[i]).ToArray();

            // ID3 never reuses a feature along a path, so drop the chosen one.
            var subFeatures = features.Where(f => f != bestFeature).ToList();
            node.Children[v] = BuildTree(subX, subY, subFeatures);
        }

        return node;
    }

    /// <summary>
    /// Walks the tree for one sample; returns 0 when a feature value was never
    /// seen in training (or the node has no children).
    /// </summary>
    private static int Traverse(Node node, int[] x)
    {
        if (node.Label is not null)
        {
            return node.Label.Value;
        }

        int v = x[node.Feature!.Value];
        if (node.Children != null && node.Children.TryGetValue(v, out var child))
        {
            return Traverse(child, x);
        }

        // fallback to 0 if unseen value or Children is null
        return 0;
    }

    /// <summary>Returns the most frequent label (first-seen wins on ties).</summary>
    private static int MostCommon(int[] y) => y.GroupBy(l => l).OrderByDescending(g => g.Count()).First().Key;

    /// <summary>
    /// Selects the feature with the highest information gain
    /// (base entropy minus the weighted entropy of the split).
    /// </summary>
    private static int BestFeature(int[][] x, int[] y, List<int> features)
    {
        double baseEntropy = Entropy(y);
        double bestGain = double.MinValue;
        int bestFeature = features[0];
        foreach (var f in features)
        {
            var values = x.Select(row => row[f]).Distinct();
            double splitEntropy = 0;
            foreach (var v in values)
            {
                var idx = x.Select((row, i) => (row, i)).Where(t => t.row[f] == v).Select(t => t.i).ToArray();
                if (idx.Length == 0)
                {
                    continue;
                }

                var subY = idx.Select(i => y[i]).ToArray();

                // Weight each partition's entropy by its share of the samples.
                splitEntropy += (double)subY.Length / y.Length * Entropy(subY);
            }

            double gain = baseEntropy - splitEntropy;
            if (gain > bestGain)
            {
                bestGain = gain;
                bestFeature = f;
            }
        }

        return bestFeature;
    }

    /// <summary>
    /// Shannon entropy (base 2) over binary labels; labels other than 0/1 are
    /// ignored. Returns 0 for an empty array or a pure set.
    /// </summary>
    private static double Entropy(int[] y)
    {
        int n = y.Length;
        if (n == 0)
        {
            return 0;
        }

        double p0 = y.Count(l => l == 0) / (double)n;
        double p1 = y.Count(l => l == 1) / (double)n;
        double e = 0;
        if (p0 > 0)
        {
            e -= p0 * Math.Log2(p0);
        }

        if (p1 > 0)
        {
            e -= p1 * Math.Log2(p1);
        }

        return e;
    }

    /// <summary>
    /// Internal tree node: either a leaf (Label set) or a split (Feature set,
    /// with one child per observed feature value).
    /// </summary>
    private class Node
    {
        // Index of the feature this node splits on; null for leaves.
        public int? Feature { get; set; }

        // Predicted label at a leaf; null for internal nodes.
        public int? Label { get; set; }

        // Total feature count of the training data (used by Predict validation).
        public int FeatureCount { get; set; }

        // Child subtree per observed value of Feature; null for leaves.
        public Dictionary<int, Node>? Children { get; set; }
    }
}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ find more than one implementation for the same objective but using different alg
* [CollaborativeFiltering](./Algorithms/RecommenderSystem/CollaborativeFiltering)
* [Machine Learning](./Algorithms/MachineLearning)
* [Linear Regression](./Algorithms/MachineLearning/LinearRegression.cs)
* [Decision Tree](./Algorithms/MachineLearning/DecisionTree.cs)
* [Searches](./Algorithms/Search)
* [A-Star](./Algorithms/Search/AStar/)
* [Binary Search](./Algorithms/Search/BinarySearcher.cs)
Expand Down