Skip to content

Commit e571ef1

Browse files
authored
feat: (bugged) #[derive(MultilinearMap)] ` ( (#3)
* feat: (bugged) `#[derive(MultilinearMap)]`
1 parent e60f764 commit e571ef1

File tree

5 files changed

+190
-18
lines changed

5 files changed

+190
-18
lines changed

Cargo.lock

Lines changed: 45 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ name = "tensor"
33
version = "0.1.0"
44
edition = "2021"
55

6-
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
6+
[dependencies]
7+
tensor_macros = { path = "macros/" }
8+
79

810
[dev-dependencies]
911
# tracing-subscriber = { version = "0.3.18", default-features = false, features = [

macros/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[package]
2+
name = "tensor_macros"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[lib]
7+
proc-macro = true
8+
9+
[dependencies]
10+
syn = { version = "2.0.60", features = ["full"] }
11+
quote = "1.0.36"
12+
proc-macro2 = "1.0.81"

macros/src/lib.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use proc_macro::TokenStream;
2+
use quote::quote;
3+
use syn::{parse_macro_input, DeriveInput, GenericParam, Ident};
4+
5+
#[proc_macro_derive(MultilinearMap)]
6+
pub fn multilinear_map_derive(input: TokenStream) -> TokenStream {
7+
let ast = parse_macro_input!(input as DeriveInput);
8+
let struct_name = &ast.ident;
9+
10+
let generics = &ast.generics;
11+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
12+
13+
let const_generics: Vec<_> = generics
14+
.params
15+
.iter()
16+
.filter_map(|param| match param {
17+
GenericParam::Const(const_param) => Some(&const_param.ident),
18+
_ => None,
19+
})
20+
.collect();
21+
22+
let input_params = const_generics.iter().enumerate().map(|(i, ident)| {
23+
let param_name = Ident::new(&format!("v_{}", i), ident.span());
24+
quote! { #param_name: V<#ident, F> }
25+
});
26+
27+
let loop_indices: Vec<_> = (0..const_generics.len())
28+
.map(|i| Ident::new(&format!("i_{}", i), proc_macro2::Span::call_site()))
29+
.collect();
30+
31+
let component_product = loop_indices.iter().zip(0..).map(|(index, i)| {
32+
let param_name = Ident::new(&format!("v_{}", i), index.span());
33+
quote! { * #param_name.0[#index] }
34+
});
35+
36+
// Add the calculation to the innermost loop
37+
let coefficient_access =
38+
loop_indices
39+
.iter()
40+
.fold(quote! { self.coefficients }, |acc, index| {
41+
quote! { #acc.0[#index] }
42+
});
43+
44+
let mut loop_nest = quote! {
45+
sum += #coefficient_access #(#component_product)*;
46+
};
47+
48+
for (index, ident) in loop_indices.iter().rev().zip(const_generics.iter().rev()) {
49+
loop_nest = quote! {
50+
for #index in 0..#ident {
51+
#loop_nest
52+
}
53+
};
54+
}
55+
56+
loop_nest = quote! {
57+
#loop_nest
58+
59+
};
60+
61+
let expanded = quote! {
62+
impl #impl_generics #struct_name #ty_generics #where_clause {
63+
pub fn multilinear_map(&self, #(#input_params),*) -> F {
64+
let mut sum = F::default();
65+
#loop_nest
66+
sum
67+
}
68+
}
69+
};
70+
71+
TokenStream::from(expanded)
72+
}

src/tensor/macros.rs

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
use super::*;
22

3+
// TODO: Could probably just assign a valence to the tensors and use N0, N1, N2,
4+
// etc. as dims
5+
36
#[macro_export]
47
macro_rules! tensor {
5-
($name:ident, $($const:ident),+) => {
6-
pub struct $name<$(const $const: usize),+, F>
7-
where F: Default + Copy,
8+
($name:ident, $($consts:ident),+) => {
9+
#[derive(tensor_macros::MultilinearMap)]
10+
pub struct $name<$(const $consts: usize),+, F>
11+
where F: Default + Copy + AddAssign + Mul<F, Output = F>,
812
{
9-
pub coefficients: coeff_builder!($($const),+; F),
13+
pub coefficients: coeff_builder!($($consts),+; F),
1014
}
1115

12-
impl<$(const $const: usize),+, F: Default + Copy> Default for $name<$($const),+, F> {
16+
impl<$(const $consts: usize),+, F: Default + Copy + AddAssign + Mul<F, Output = F>> Default for $name<$($consts),+, F> {
1317
fn default() -> Self {
14-
let coefficients = <def_builder!($($const),+; F)>::default();
18+
let coefficients = <def_builder!($($consts),+; F)>::default();
1519
$name { coefficients }
1620
}
1721

1822
}
1923

20-
impl<$(const $const: usize),+, F> Debug for $name<$($const),+, F>
24+
impl<$(const $consts: usize),+, F> Debug for $name<$($consts),+, F>
2125
where
22-
F: Default + Copy + Debug,
26+
F: Default + Copy + Debug + AddAssign + Mul<F, Output = F>,
2327
{
2428
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
2529
f.debug_struct(stringify!($name))
@@ -28,34 +32,31 @@ macro_rules! tensor {
2832
}
2933
}
3034

31-
impl<$(const $const: usize),+, F> Add for $name<$($const),+, F>
35+
impl<$(const $consts: usize),+, F> Add for $name<$($consts),+, F>
3236
where
33-
F: Add<Output = F> + Copy + Default,
37+
F: Add<Output = F> + Copy + Default + AddAssign + Mul<F, Output = F>,
3438
{
3539
type Output = Self;
3640

3741
fn add(self, other: Self) -> Self::Output {
3842
let mut result = Self::default();
39-
add_tensors!(result.coefficients, self.coefficients, other.coefficients; $($const),+);
43+
add_tensors!(result.coefficients, self.coefficients, other.coefficients; $($consts),+);
4044
result
4145
}
4246
}
4347

44-
impl<$(const $const: usize),+, F> Mul<F> for $name<$($const),+, F>
48+
impl<$(const $consts: usize),+, F> Mul<F> for $name<$($consts),+, F>
4549
where
46-
F: Mul<Output = F> + Copy + Default,
50+
F: Mul<Output = F> + Copy + Default + AddAssign,
4751
{
4852
type Output = Self;
4953

5054
fn mul(self, scalar: F) -> Self::Output {
5155
let mut result = Self::default();
52-
scalar_mul_tensor!(result.coefficients, self.coefficients, scalar; $($const),+);
56+
scalar_mul_tensor!(result.coefficients, self.coefficients, scalar; $($consts),+);
5357
result
5458
}
5559
}
56-
57-
58-
5960
}
6061
}
6162

@@ -103,11 +104,14 @@ macro_rules! scalar_mul_tensor {
103104
};
104105
}
105106

107+
tensor!(TensorTester, M, N, P);
108+
106109
#[cfg(test)]
107110
mod tests {
108111

109112
use super::*;
110113
tensor!(Tensor2, M, N);
114+
111115
tensor!(Tensor3, M, N, P);
112116

113117
use log::{debug, info};
@@ -161,4 +165,41 @@ mod tests {
161165
let tensor2 = tensor1 * scalar;
162166
info!("output: {:?}", tensor2.coefficients);
163167
}
168+
169+
#[test]
170+
fn multilinear_map() {
171+
log();
172+
// / 1 0 0 \
173+
// tensor = \ 0 1 0 /
174+
let mut tensor = Tensor2::<2, 3, f64>::default();
175+
tensor.coefficients.0[0].0[0] = 1.0;
176+
tensor.coefficients.0[1].0[1] = 1.0;
177+
debug!("tensor: {:?}", tensor);
178+
179+
// / -1 \
180+
// v_0 = \ 1 /
181+
let mut v_0 = V::default();
182+
v_0.0[0] = -1.0;
183+
v_0.0[1] = 1.0;
184+
debug!("v_0: {:?}", v_0);
185+
186+
// / 1 \
187+
// | 2 |
188+
// v_1 = \ 3 /
189+
let mut v_1 = V::default();
190+
v_1.0[0] = 1.0;
191+
v_1.0[1] = 2.0;
192+
v_1.0[2] = 3.0;
193+
debug!("v_1: {:?}", v_1);
194+
195+
// / 1 \
196+
// tensor.map(_,v_1) = \ 2 /
197+
//
198+
// then the next is:
199+
// / 1 \
200+
// tensor.map(v_0, v_1) = < -1 1 > \ 2 / = -1 + 2 = 1
201+
let output = tensor.multilinear_map(v_0, v_1);
202+
info!("output: {:?}", output);
203+
assert_eq!(output, 1.0);
204+
}
164205
}

0 commit comments

Comments
 (0)