Skip to content

Commit 48083fa

Browse files
committed
matrix multiplication
1 parent 9d3d279 commit 48083fa

File tree

4 files changed

+68
-28
lines changed

4 files changed

+68
-28
lines changed

src/ConvNetSharp.Volume.Tests/VolumeTests.cs

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -792,23 +792,51 @@ public void MultiplyCommutative()
792792
Assert.IsTrue(result1.ToArray().SequenceEqual(result2.ToArray()));
793793
}
794794

795+
[TestMethod]
796+
public void MatMultiplyByVector()
797+
{
798+
var matrixA = new[] { 1.0, 2.0, 3.0, 4.0 };
799+
var a = NewVolume(matrixA, Shape.From(2, 2));
800+
801+
var matrixB = new[] { 1.0, 2.0 };
802+
var b = NewVolume(matrixB, Shape.From(1, 2));
803+
804+
var result = BuilderInstance<T>.Volume.SameAs(new Shape(1, 2));
805+
806+
a.MatMultiply(b, result);
807+
808+
AssertNumber.AreEqual(5.0, result.Get(0, 0));
809+
AssertNumber.AreEqual(11.0, result.Get(0, 1));
810+
}
811+
795812
[TestMethod]
796813
public void MatMultiply()
797814
{
798815
var matrixA = new[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };
799-
var a = NewVolume(matrixA, Shape.From(4, 2));
816+
var a = NewVolume(matrixA, Shape.From(2, 4));
800817

801818
var matrixB = new[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
802-
var b = NewVolume(matrixB, Shape.From(2, 3));
819+
var b = NewVolume(matrixB, Shape.From(3, 2));
803820

804-
var result = BuilderInstance<T>.Volume.SameAs(new Shape(4, 3));
821+
var result = BuilderInstance<T>.Volume.SameAs(new Shape(3, 4));
805822

806823
a.MatMultiply(b, result);
807824

808-
AssertNumber.AreEqual(11.0, result.Get(0, 0));
809-
AssertNumber.AreEqual(68.0, result.Get(3, 2));
810-
AssertNumber.AreEqual(20.0, result.Get(3, 0));
811-
AssertNumber.AreEqual(35.0, result.Get(0, 2));
825+
AssertNumber.AreEqual(9.0, result.Get(0, 0));
826+
AssertNumber.AreEqual(12.0, result.Get(1, 0));
827+
AssertNumber.AreEqual(15.0, result.Get(2, 0));
828+
829+
AssertNumber.AreEqual(19.0, result.Get(0, 1));
830+
AssertNumber.AreEqual(26.0, result.Get(1, 1));
831+
AssertNumber.AreEqual(33.0, result.Get(2, 1));
832+
833+
AssertNumber.AreEqual(29.0, result.Get(0, 2));
834+
AssertNumber.AreEqual(40.0, result.Get(1, 2));
835+
AssertNumber.AreEqual(51.0, result.Get(2, 2));
836+
837+
AssertNumber.AreEqual(39.0, result.Get(0, 3));
838+
AssertNumber.AreEqual(54.0, result.Get(1, 3));
839+
AssertNumber.AreEqual(69.0, result.Get(2, 3));
812840
}
813841

814842
[TestMethod]
@@ -828,19 +856,19 @@ public void MatMultiplyBatch()
828856
};
829857
var b = NewVolume(matrixB, Shape.From(2, 3, 1, 2));
830858

831-
var result = BuilderInstance<T>.Volume.SameAs(new Shape(3, 3, 1, 2));
859+
var result = BuilderInstance<T>.Volume.SameAs(new Shape(2, 2, 1, 2));
832860

833861
a.MatMultiply(b, result);
834862

835-
AssertNumber.AreEqual(5.0, result.Get(0, 0));
836-
AssertNumber.AreEqual(5.0, result.Get(2, 0));
837-
AssertNumber.AreEqual(17.0, result.Get(0, 2));
838-
AssertNumber.AreEqual(17.0, result.Get(2, 2));
839-
840-
AssertNumber.AreEqual(11.0, result.Get(0, 0, 0, 1));
841-
AssertNumber.AreEqual(11.0, result.Get(2, 0, 0, 1));
842-
AssertNumber.AreEqual(39.0, result.Get(0, 2, 0, 1));
843-
AssertNumber.AreEqual(39.0, result.Get(2, 2, 0, 1));
863+
AssertNumber.AreEqual(9.0, result.Get(0, 0));
864+
AssertNumber.AreEqual(12.0, result.Get(1, 0));
865+
AssertNumber.AreEqual(18.0, result.Get(0, 1));
866+
AssertNumber.AreEqual(24.0, result.Get(1, 1));
867+
868+
AssertNumber.AreEqual(27.0, result.Get(0, 0, 0, 1));
869+
AssertNumber.AreEqual(36.0, result.Get(1, 0, 0, 1));
870+
AssertNumber.AreEqual(36.0, result.Get(0, 1, 0, 1));
871+
AssertNumber.AreEqual(48.0, result.Get(1, 1, 0, 1));
844872
}
845873

846874
[TestMethod]

src/ConvNetSharp.Volume/Double/Volume.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ public override void MatMultiply(Volume<double> right, Volume<double> result)
362362
throw new ArgumentException($"Left and right volumes should have the same batch size. left = {this.Shape.Dimensions[3]} right = {right.Shape.Dimensions[3]}");
363363
}
364364

365-
var expectedShape = new Shape(this.Shape.Dimensions[0], right.Shape.Dimensions[1], 1, this.Shape.Dimensions[3]);
365+
var expectedShape = ComputeMatMultiplyShape(this.Shape, right.Shape);
366366

367367
if (!result.Shape.Equals(expectedShape))
368368
{
@@ -376,9 +376,9 @@ public override void MatMultiply(Volume<double> right, Volume<double> result)
376376
for (var j = 0; j < expectedShape.Dimensions[1]; j++)
377377
{
378378
var cell = 0.0;
379-
for (var k = 0; k < this.Shape.Dimensions[1]; k++)
379+
for (var k = 0; k < this.Shape.Dimensions[0]; k++)
380380
{
381-
cell = cell + Get(i, k, 0, n) * right.Get(k, j, 0, n);
381+
cell = cell + Get(k, j, 0, n) * right.Get(i, k, 0, n);
382382
}
383383

384384
result.Set(i, j, 0, n, cell);

src/ConvNetSharp.Volume/Single/Volume.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ public override void MatMultiply(Volume<float> right, Volume<float> result)
362362
throw new ArgumentException($"Left and right volumes should have the same batch size. left = {this.Shape.Dimensions[3]} right = {right.Shape.Dimensions[3]}");
363363
}
364364

365-
var expectedShape = new Shape(this.Shape.Dimensions[0], right.Shape.Dimensions[1], 1, this.Shape.Dimensions[3]);
365+
var expectedShape = ComputeMatMultiplyShape(this.Shape, right.Shape);
366366

367367
if (!result.Shape.Equals(expectedShape))
368368
{
@@ -376,9 +376,9 @@ public override void MatMultiply(Volume<float> right, Volume<float> result)
376376
for (var j = 0; j < expectedShape.Dimensions[1]; j++)
377377
{
378378
var cell = 0.0f;
379-
for (var k = 0; k < this.Shape.Dimensions[1]; k++)
379+
for (var k = 0; k < this.Shape.Dimensions[0]; k++)
380380
{
381-
cell = cell + Get(i, k, 0, n) * right.Get(k, j, 0, n);
381+
cell = cell + Get(k, j, 0, n) * right.Get(i, k, 0, n);
382382
}
383383

384384
result.Set(i, j, 0, n, cell);

src/ConvNetSharp.Volume/Volume.cs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ public static Shape ComputePoolShape(Shape inputShape, int windowWidth, int wind
7979
return new Shape(outputWidth, outputHeight, outputDepth, outputN);
8080
}
8181

82+
/// <summary>
83+
/// Compute expected 2D matrix multiplication result shape
84+
/// [K, M, 1, BatchSize] x [N, K, 1, BatchSize] => [N, M, 1, BatchSize]
85+
/// </summary>
86+
/// <param name="leftShape">left 2D matrix / volume</param>
87+
/// <param name="rightShape">right 2D matrix / volume</param>
88+
/// <returns></returns>
89+
public static Shape ComputeMatMultiplyShape(Shape leftShape, Shape rightShape)
90+
{
91+
return new Shape(rightShape.Dimensions[0], leftShape.Dimensions[1], 1, leftShape.Dimensions[3]);
92+
}
93+
8294
public abstract void Concat(Volume<T> right, Volume<T> result);
8395

8496
public abstract void Convolution(Volume<T> filters, int pad, int stride, Volume<T> result);
@@ -148,12 +160,12 @@ public void MapInplace(Func<T, T, T> f, Volume<T> other)
148160
/// <summary>
149161
/// Matrix multiplication
150162
/// left (this) x right = result
151-
/// Where left is a 2D volume of shape [M, K, 1, batchsize]
152-
/// right is a 2D volume of shape [K, N, 1, batchsize]
153-
/// and result is a 2D volume of shape [M, N, 1, batchsize]
163+
/// Where left is a 2D volume of shape [K, M, 1, batchsize]
164+
/// right is a 2D volume of shape [N, K, 1, batchsize]
165+
/// and result is a 2D volume of shape [N, M, 1, batchsize]
154166
/// </summary>
155-
/// <param name="right">2D volume of shape [K, N, 1, batchsize]</param>
156-
/// <param name="result">2D volume of shape [M, N, 1, batchsize]</param>
167+
/// <param name="right">2D volume of shape [N, K, 1, batchsize]</param>
168+
/// <param name="result">2D volume of shape [N, M, 1, batchsize]</param>
157169
public abstract void MatMultiply(Volume<T> right, Volume<T> result);
158170

159171
public abstract void Max(Volume<T> result);

0 commit comments

Comments
 (0)