Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
higher order tensor
  • Loading branch information
Autoparallel committed Sep 27, 2024
commit 7350209da3736a89523e70e82cc54f60d5a0c6df
111 changes: 111 additions & 0 deletions src/tensor/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use core::{
marker::PhantomData,
ops::{Add, Mul},
};

use super::{Tensor, V};

trait TensorProduct<F> {
type T1: TensorProduct<F>;
type T2: TensorProduct<F>;

fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self;

fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2;

fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1;
}

impl<const M: usize, F: Default + Add<Output = F> + Mul<Output = F> + Copy> TensorProduct<F>
for V<M, F>
{
type T1 = Self;
type T2 = V<1, F>; // Scalar ring

fn tensor_product(tensor_1: Self::T1, _tensor_2: Self::T2) -> Self {
tensor_1
}

fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 {
let val = self
.0
.iter()
.zip(tensor_1.0.iter())
.fold(F::default(), |acc, (a, b)| acc + (*a * *b));
V([val])
}

fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 {
*self * tensor_2.0[0]
}
}

impl<const M: usize, const N: usize, F: Default + Add<Output = F> + Mul<Output = F> + Copy>
TensorProduct<F> for Tensor<M, N, F>
where
[(); M * N]:,
{
type T1 = V<M, F>;
type T2 = V<N, F>;

fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self {
todo!()
}

fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 {
todo!()
}

fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 {
todo!()
}
}

#[derive(Clone)]
pub struct HigherTensor<T1: TensorProduct<F>, T2: TensorProduct<F>, F> {
tensor_1: T1,
tensor_2: T2,
_p: PhantomData<F>,
}

impl<T1: TensorProduct<F>, T2: TensorProduct<F>, F> TensorProduct<F> for HigherTensor<T1, T2, F> {
type T1 = T1;
type T2 = T2;

fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self {
todo!()
}

fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 {
todo!()
}

fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 {
todo!()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn intro() {
let tensor_1 = Tensor::<3, 2, f64> {
coefficients: V::<2>(V::<3>([1, 2, 3])),
};
let tensor_2 = tensor_1.clone();

let tensor = HigherTensor {
tensor_1,
tensor_2,
_p: PhantomData,
};

let nested_tensor = HigherTensor {
tensor_1: tensor.clone(),
tensor_2: tensor.clone(),
_p: PhantomData,
};
}
}
2 changes: 2 additions & 0 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use core::ops::AddAssign;

use super::*;

pub mod extension;
pub mod macros;

#[derive(Clone)]
pub struct Tensor<const M: usize, const N: usize, F>
where
[(); M * N]:,
Expand Down