11use super :: * ;
22
3+ // TODO: Could probably just assign a valence to the tensors and use N0, N1, N2,
4+ // etc. as dims
5+
36#[ macro_export]
47macro_rules! tensor {
5- ( $name: ident, $( $const: ident) ,+) => {
6- pub struct $name<$( const $const: usize ) ,+, F >
7- where F : Default + Copy ,
8+ ( $name: ident, $( $consts: ident) ,+) => {
9+ #[ derive( tensor_macros:: MultilinearMap ) ]
10+ pub struct $name<$( const $consts: usize ) ,+, F >
11+ where F : Default + Copy + AddAssign + Mul <F , Output = F >,
812 {
9- pub coefficients: coeff_builder!( $( $const ) ,+; F ) ,
13+ pub coefficients: coeff_builder!( $( $consts ) ,+; F ) ,
1014 }
1115
12- impl <$( const $const : usize ) ,+, F : Default + Copy > Default for $name<$( $const ) ,+, F > {
16+ impl <$( const $consts : usize ) ,+, F : Default + Copy + AddAssign + Mul < F , Output = F >> Default for $name<$( $consts ) ,+, F > {
1317 fn default ( ) -> Self {
14- let coefficients = <def_builder!( $( $const ) ,+; F ) >:: default ( ) ;
18+ let coefficients = <def_builder!( $( $consts ) ,+; F ) >:: default ( ) ;
1519 $name { coefficients }
1620 }
1721
1822 }
1923
20- impl <$( const $const : usize ) ,+, F > Debug for $name<$( $const ) ,+, F >
24+ impl <$( const $consts : usize ) ,+, F > Debug for $name<$( $consts ) ,+, F >
2125 where
22- F : Default + Copy + Debug ,
26+ F : Default + Copy + Debug + AddAssign + Mul < F , Output = F > ,
2327 {
2428 fn fmt( & self , f: & mut Formatter <' _>) -> Result {
2529 f. debug_struct( stringify!( $name) )
@@ -28,34 +32,31 @@ macro_rules! tensor {
2832 }
2933 }
3034
31- impl <$( const $const : usize ) ,+, F > Add for $name<$( $const ) ,+, F >
35+ impl <$( const $consts : usize ) ,+, F > Add for $name<$( $consts ) ,+, F >
3236 where
33- F : Add <Output = F > + Copy + Default ,
37+ F : Add <Output = F > + Copy + Default + AddAssign + Mul < F , Output = F > ,
3438 {
3539 type Output = Self ;
3640
3741 fn add( self , other: Self ) -> Self :: Output {
3842 let mut result = Self :: default ( ) ;
39- add_tensors!( result. coefficients, self . coefficients, other. coefficients; $( $const ) ,+) ;
43+ add_tensors!( result. coefficients, self . coefficients, other. coefficients; $( $consts ) ,+) ;
4044 result
4145 }
4246 }
4347
44- impl <$( const $const : usize ) ,+, F > Mul <F > for $name<$( $const ) ,+, F >
48+ impl <$( const $consts : usize ) ,+, F > Mul <F > for $name<$( $consts ) ,+, F >
4549 where
46- F : Mul <Output = F > + Copy + Default ,
50+ F : Mul <Output = F > + Copy + Default + AddAssign ,
4751 {
4852 type Output = Self ;
4953
5054 fn mul( self , scalar: F ) -> Self :: Output {
5155 let mut result = Self :: default ( ) ;
52- scalar_mul_tensor!( result. coefficients, self . coefficients, scalar; $( $const ) ,+) ;
56+ scalar_mul_tensor!( result. coefficients, self . coefficients, scalar; $( $consts ) ,+) ;
5357 result
5458 }
5559 }
56-
57-
58-
5960 }
6061}
6162
@@ -103,11 +104,14 @@ macro_rules! scalar_mul_tensor {
103104 } ;
104105}
105106
107+ tensor ! ( TensorTester , M , N , P ) ;
108+
106109#[ cfg( test) ]
107110mod tests {
108111
109112 use super :: * ;
110113 tensor ! ( Tensor2 , M , N ) ;
114+
111115 tensor ! ( Tensor3 , M , N , P ) ;
112116
113117 use log:: { debug, info} ;
@@ -161,4 +165,41 @@ mod tests {
161165 let tensor2 = tensor1 * scalar;
162166 info ! ( "output: {:?}" , tensor2. coefficients) ;
163167 }
168+
169+ #[ test]
170+ fn multilinear_map ( ) {
171+ log ( ) ;
172+ // / 1 0 0 \
173+ // tensor = \ 0 1 0 /
174+ let mut tensor = Tensor2 :: < 2 , 3 , f64 > :: default ( ) ;
175+ tensor. coefficients . 0 [ 0 ] . 0 [ 0 ] = 1.0 ;
176+ tensor. coefficients . 0 [ 1 ] . 0 [ 1 ] = 1.0 ;
177+ debug ! ( "tensor: {:?}" , tensor) ;
178+
179+ // / -1 \
180+ // v_0 = \ 1 /
181+ let mut v_0 = V :: default ( ) ;
182+ v_0. 0 [ 0 ] = -1.0 ;
183+ v_0. 0 [ 1 ] = 1.0 ;
184+ debug ! ( "v_0: {:?}" , v_0) ;
185+
186+ // / 1 \
187+ // | 2 |
188+ // v_1 = \ 3 /
189+ let mut v_1 = V :: default ( ) ;
190+ v_1. 0 [ 0 ] = 1.0 ;
191+ v_1. 0 [ 1 ] = 2.0 ;
192+ v_1. 0 [ 2 ] = 3.0 ;
193+ debug ! ( "v_1: {:?}" , v_1) ;
194+
195+ // / 1 \
196+ // tensor.map(_,v_1) = \ 2 /
197+ //
198+ // then the next is:
199+ // / 1 \
200+ // tensor.map(v_0, v_1) = < -1 1 > \ 2 / = -1 + 2 = 1
201+ let output = tensor. multilinear_map ( v_0, v_1) ;
202+ info ! ( "output: {:?}" , output) ;
203+ assert_eq ! ( output, 1.0 ) ;
204+ }
164205}
0 commit comments