@@ -1211,6 +1211,60 @@ impl<T: TensorType> Tensor<T> {
1211
1211
Ok ( self )
1212
1212
}
1213
1213
1214
+ /// Set one single value on the tensor.
1215
+ ///
1216
+ /// ```
1217
+ /// # use tensorflow::Tensor;
1218
+ /// let mut a = Tensor::<i32>::new(&[3, 3, 3]);
1219
+ ///
1220
+ /// a.set(&[0, 0, 1], 10);
1221
+ /// assert_eq!(a[0 + 0 + 1], 10);
1222
+ ///
1223
+ /// a.set(&[2, 2, 0], 9);
1224
+ /// assert_eq!(a[2*9 + 2*3 + 0], 9);
1225
+ /// ```
1226
+ pub fn set ( & mut self , indices : & [ u64 ] , value : T ) {
1227
+ let index = self . get_index ( indices) ;
1228
+ self [ index] = value;
1229
+ }
1230
+
1231
+ /// Get one single value from the Tensor.
1232
+ ///
1233
+ /// ```
1234
+ /// # use tensorflow::Tensor;
1235
+ /// let mut a = Tensor::<i32>::new(&[2, 3, 5]);
1236
+ ///
1237
+ /// a[1*15 + 1*5 + 1] = 5;
1238
+ /// assert_eq!(a.get(&[1, 1, 1]), 5);
1239
+ /// ```
1240
+ pub fn get ( & self , indices : & [ u64 ] ) -> T {
1241
+ let index = self . get_index ( indices) ;
1242
+ self [ index] . clone ( )
1243
+ }
1244
+
1245
+ /// Get the array index from rows / columns indices.
1246
+ ///
1247
+ /// ```
1248
+ /// # use tensorflow::Tensor;
1249
+ /// let a = Tensor::<f32>::new(&[3, 3, 3]);
1250
+ ///
1251
+ /// assert_eq!(a.get_index(&[2, 2, 2]), 26);
1252
+ /// assert_eq!(a.get_index(&[1, 2, 2]), 17);
1253
+ /// assert_eq!(a.get_index(&[1, 2, 0]), 15);
1254
+ /// assert_eq!(a.get_index(&[1, 0, 1]), 10);
1255
+ /// ```
1256
+ pub fn get_index ( & self , indices : & [ u64 ] ) -> usize {
1257
+ assert ! ( self . dims. len( ) == indices. len( ) ) ;
1258
+ let mut index = 0 ;
1259
+ let mut d = 1 ;
1260
+ for i in ( 0 ..indices. len ( ) ) . rev ( ) {
1261
+ assert ! ( self . dims[ i] > indices[ i] ) ;
1262
+ index += indices[ i] * d;
1263
+ d *= self . dims [ i] ;
1264
+ }
1265
+ index as usize
1266
+ }
1267
+
1214
1268
/// Returns the tensor's dimensions.
1215
1269
pub fn dims ( & self ) -> & [ u64 ] {
1216
1270
& self . dims
0 commit comments