Skip to content

Commit dc20fa5

Browse files
authored
Merge pull request tensorflow#223 from AndreaCatania/setget
Added set and get to the Tensor
2 parents 55d3786 + b968cff commit dc20fa5

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

src/lib.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,60 @@ impl<T: TensorType> Tensor<T> {
12111211
Ok(self)
12121212
}
12131213

1214+
/// Set one single value on the tensor.
1215+
///
1216+
/// ```
1217+
/// # use tensorflow::Tensor;
1218+
/// let mut a = Tensor::<i32>::new(&[3, 3, 3]);
1219+
///
1220+
/// a.set(&[0, 0, 1], 10);
1221+
/// assert_eq!(a[0 + 0 + 1], 10);
1222+
///
1223+
/// a.set(&[2, 2, 0], 9);
1224+
/// assert_eq!(a[2*9 + 2*3 + 0], 9);
1225+
/// ```
1226+
pub fn set(&mut self, indices: &[u64], value: T) {
1227+
let index = self.get_index(indices);
1228+
self[index] = value;
1229+
}
1230+
1231+
/// Get one single value from the Tensor.
1232+
///
1233+
/// ```
1234+
/// # use tensorflow::Tensor;
1235+
/// let mut a = Tensor::<i32>::new(&[2, 3, 5]);
1236+
///
1237+
/// a[1*15 + 1*5 + 1] = 5;
1238+
/// assert_eq!(a.get(&[1, 1, 1]), 5);
1239+
/// ```
1240+
pub fn get(&self, indices: &[u64]) -> T {
1241+
let index = self.get_index(indices);
1242+
self[index].clone()
1243+
}
1244+
1245+
/// Get the array index from rows / columns indices.
1246+
///
1247+
/// ```
1248+
/// # use tensorflow::Tensor;
1249+
/// let a = Tensor::<f32>::new(&[3, 3, 3]);
1250+
///
1251+
/// assert_eq!(a.get_index(&[2, 2, 2]), 26);
1252+
/// assert_eq!(a.get_index(&[1, 2, 2]), 17);
1253+
/// assert_eq!(a.get_index(&[1, 2, 0]), 15);
1254+
/// assert_eq!(a.get_index(&[1, 0, 1]), 10);
1255+
/// ```
1256+
pub fn get_index(&self, indices: &[u64]) -> usize {
1257+
assert!(self.dims.len() == indices.len());
1258+
let mut index = 0;
1259+
let mut d = 1;
1260+
for i in (0..indices.len()).rev() {
1261+
assert!(self.dims[i] > indices[i]);
1262+
index += indices[i] * d;
1263+
d *= self.dims[i];
1264+
}
1265+
index as usize
1266+
}
1267+
12141268
/// Returns the tensor's dimensions.
12151269
pub fn dims(&self) -> &[u64] {
12161270
&self.dims

0 commit comments

Comments
 (0)