@@ -792,23 +792,51 @@ public void MultiplyCommutative()
792
792
Assert . IsTrue ( result1 . ToArray ( ) . SequenceEqual ( result2 . ToArray ( ) ) ) ;
793
793
}
794
794
795
+ [ TestMethod ]
796
+ public void MatMultiplyByVector ( )
797
+ {
798
+ var matrixA = new [ ] { 1.0 , 2.0 , 3.0 , 4.0 } ;
799
+ var a = NewVolume ( matrixA , Shape . From ( 2 , 2 ) ) ;
800
+
801
+ var matrixB = new [ ] { 1.0 , 2.0 } ;
802
+ var b = NewVolume ( matrixB , Shape . From ( 1 , 2 ) ) ;
803
+
804
+ var result = BuilderInstance < T > . Volume . SameAs ( new Shape ( 1 , 2 ) ) ;
805
+
806
+ a . MatMultiply ( b , result ) ;
807
+
808
+ AssertNumber . AreEqual ( 5.0 , result . Get ( 0 , 0 ) ) ;
809
+ AssertNumber . AreEqual ( 11.0 , result . Get ( 0 , 1 ) ) ;
810
+ }
811
+
795
812
[ TestMethod ]
796
813
public void MatMultiply ( )
797
814
{
798
815
var matrixA = new [ ] { 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 } ;
799
- var a = NewVolume ( matrixA , Shape . From ( 4 , 2 ) ) ;
816
+ var a = NewVolume ( matrixA , Shape . From ( 2 , 4 ) ) ;
800
817
801
818
var matrixB = new [ ] { 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 } ;
802
- var b = NewVolume ( matrixB , Shape . From ( 2 , 3 ) ) ;
819
+ var b = NewVolume ( matrixB , Shape . From ( 3 , 2 ) ) ;
803
820
804
- var result = BuilderInstance < T > . Volume . SameAs ( new Shape ( 4 , 3 ) ) ;
821
+ var result = BuilderInstance < T > . Volume . SameAs ( new Shape ( 3 , 4 ) ) ;
805
822
806
823
a . MatMultiply ( b , result ) ;
807
824
808
- AssertNumber . AreEqual ( 11.0 , result . Get ( 0 , 0 ) ) ;
809
- AssertNumber . AreEqual ( 68.0 , result . Get ( 3 , 2 ) ) ;
810
- AssertNumber . AreEqual ( 20.0 , result . Get ( 3 , 0 ) ) ;
811
- AssertNumber . AreEqual ( 35.0 , result . Get ( 0 , 2 ) ) ;
825
+ AssertNumber . AreEqual ( 9.0 , result . Get ( 0 , 0 ) ) ;
826
+ AssertNumber . AreEqual ( 12.0 , result . Get ( 1 , 0 ) ) ;
827
+ AssertNumber . AreEqual ( 15.0 , result . Get ( 2 , 0 ) ) ;
828
+
829
+ AssertNumber . AreEqual ( 19.0 , result . Get ( 0 , 1 ) ) ;
830
+ AssertNumber . AreEqual ( 26.0 , result . Get ( 1 , 1 ) ) ;
831
+ AssertNumber . AreEqual ( 33.0 , result . Get ( 2 , 1 ) ) ;
832
+
833
+ AssertNumber . AreEqual ( 29.0 , result . Get ( 0 , 2 ) ) ;
834
+ AssertNumber . AreEqual ( 40.0 , result . Get ( 1 , 2 ) ) ;
835
+ AssertNumber . AreEqual ( 51.0 , result . Get ( 2 , 2 ) ) ;
836
+
837
+ AssertNumber . AreEqual ( 39.0 , result . Get ( 0 , 3 ) ) ;
838
+ AssertNumber . AreEqual ( 54.0 , result . Get ( 1 , 3 ) ) ;
839
+ AssertNumber . AreEqual ( 69.0 , result . Get ( 2 , 3 ) ) ;
812
840
}
813
841
814
842
[ TestMethod ]
@@ -828,19 +856,19 @@ public void MatMultiplyBatch()
828
856
} ;
829
857
var b = NewVolume ( matrixB , Shape . From ( 2 , 3 , 1 , 2 ) ) ;
830
858
831
- var result = BuilderInstance < T > . Volume . SameAs ( new Shape ( 3 , 3 , 1 , 2 ) ) ;
859
+ var result = BuilderInstance < T > . Volume . SameAs ( new Shape ( 2 , 2 , 1 , 2 ) ) ;
832
860
833
861
a . MatMultiply ( b , result ) ;
834
862
835
- AssertNumber . AreEqual ( 5 .0, result . Get ( 0 , 0 ) ) ;
836
- AssertNumber . AreEqual ( 5 .0, result . Get ( 2 , 0 ) ) ;
837
- AssertNumber . AreEqual ( 17 .0, result . Get ( 0 , 2 ) ) ;
838
- AssertNumber . AreEqual ( 17 .0, result . Get ( 2 , 2 ) ) ;
839
-
840
- AssertNumber . AreEqual ( 11 .0, result . Get ( 0 , 0 , 0 , 1 ) ) ;
841
- AssertNumber . AreEqual ( 11 .0, result . Get ( 2 , 0 , 0 , 1 ) ) ;
842
- AssertNumber . AreEqual ( 39 .0, result . Get ( 0 , 2 , 0 , 1 ) ) ;
843
- AssertNumber . AreEqual ( 39 .0, result . Get ( 2 , 2 , 0 , 1 ) ) ;
863
+ AssertNumber . AreEqual ( 9 .0, result . Get ( 0 , 0 ) ) ;
864
+ AssertNumber . AreEqual ( 12 .0, result . Get ( 1 , 0 ) ) ;
865
+ AssertNumber . AreEqual ( 18 .0, result . Get ( 0 , 1 ) ) ;
866
+ AssertNumber . AreEqual ( 24 .0, result . Get ( 1 , 1 ) ) ;
867
+
868
+ AssertNumber . AreEqual ( 27 .0, result . Get ( 0 , 0 , 0 , 1 ) ) ;
869
+ AssertNumber . AreEqual ( 36 .0, result . Get ( 1 , 0 , 0 , 1 ) ) ;
870
+ AssertNumber . AreEqual ( 36 .0, result . Get ( 0 , 1 , 0 , 1 ) ) ;
871
+ AssertNumber . AreEqual ( 48 .0, result . Get ( 1 , 1 , 0 , 1 ) ) ;
844
872
}
845
873
846
874
[ TestMethod ]
0 commit comments