diff --git a/src/lib.rs b/src/lib.rs index 8546fca75e..54ad671b97 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1211,6 +1211,60 @@ impl Tensor { Ok(self) } + /// Set one single value on the tensor. + /// + /// ``` + /// # use tensorflow::Tensor; + /// let mut a = Tensor::::new(&[3, 3, 3]); + /// + /// a.set(&[0, 0, 1], 10); + /// assert_eq!(a[0 + 0 + 1], 10); + /// + /// a.set(&[2, 2, 0], 9); + /// assert_eq!(a[2*9 + 2*3 + 0], 9); + /// ``` + pub fn set(&mut self, indices: &[u64], value: T) { + let index = self.get_index(indices); + self[index] = value; + } + + /// Get one single value from the Tensor. + /// + /// ``` + /// # use tensorflow::Tensor; + /// let mut a = Tensor::::new(&[2, 3, 5]); + /// + /// a[1*15 + 1*5 + 1] = 5; + /// assert_eq!(a.get(&[1, 1, 1]), 5); + /// ``` + pub fn get(&self, indices: &[u64]) -> T { + let index = self.get_index(indices); + self[index].clone() + } + + /// Get the array index from rows / columns indices. + /// + /// ``` + /// # use tensorflow::Tensor; + /// let a = Tensor::::new(&[3, 3, 3]); + /// + /// assert_eq!(a.get_index(&[2, 2, 2]), 26); + /// assert_eq!(a.get_index(&[1, 2, 2]), 17); + /// assert_eq!(a.get_index(&[1, 2, 0]), 15); + /// assert_eq!(a.get_index(&[1, 0, 1]), 10); + /// ``` + pub fn get_index(&self, indices: &[u64]) -> usize { + assert!(self.dims.len() == indices.len()); + let mut index = 0; + let mut d = 1; + for i in (0..indices.len()).rev() { + assert!(self.dims[i] > indices[i]); + index += indices[i] * d; + d *= self.dims[i]; + } + index as usize + } + /// Returns the tensor's dimensions. pub fn dims(&self) -> &[u64] { &self.dims