Skip to content

Commit 2d7c2cf

Browse files
committed
Clone the ivalue custom object smart pointers before passing them to the C++ side.
1 parent 36323ef commit 2d7c2cf

5 files changed

Lines changed: 20 additions & 4 deletions

File tree

src/wrappers/jit.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ impl From<&str> for IValue {
239239
}
240240

241241
impl IValue {
242+
#![allow(unused_unsafe)]
242243
pub(super) fn to_c(&self) -> Result<*mut CIValue, TchError> {
243244
let c = unsafe_torch_err!(match self {
244245
IValue::Tensor(tensor) => ati_tensor(tensor.c_tensor),
@@ -296,7 +297,10 @@ impl IValue {
296297
}
297298
dict
298299
}
299-
IValue::Object(Object { c_ivalue }) => *c_ivalue,
300+
IValue::Object(Object { c_ivalue }) => {
301+
// Clone the object if necessary before passing the pointer to the C++ side.
302+
unsafe_torch_err!(ati_clone(*c_ivalue))
303+
}
300304
});
301305
Ok(c)
302306
}

tests/jit_tests.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ fn jit_double_free() {
153153
&Tensor::of_slice(&[4_f32, 5_f32, 6_f32]).into(),
154154
],
155155
);
156-
if false {
157-
let _output = foo.method_is("add_them", &[&input.unwrap()]);
158-
}
156+
let result = foo.method_is("add_them", &[&input.unwrap()]);
157+
let result = match result.unwrap() {
158+
IValue::Tensor(tensor) => tensor,
159+
result => panic!("expected a tensor got {:?}", result),
160+
};
161+
assert_eq!(Vec::<f64>::from(&result), [5.0, 7.0, 9.0])
159162
}

torch-sys/libtch/torch_api.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,13 @@ ivalue ati_object_method_(ivalue i, char *method_name, ivalue *ivalues, int niva
14081408
return nullptr;
14091409
}
14101410

1411+
ivalue ati_clone(ivalue i) {
1412+
PROTECT(
1413+
return new torch::jit::IValue(*i);
1414+
)
1415+
return nullptr;
1416+
}
1417+
14111418
void ati_free(ivalue i) {
14121419
delete(i);
14131420
}

torch-sys/libtch/torch_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ int ati_tag(ivalue);
215215

216216
ivalue ati_object_method_(ivalue i, char *method_name, ivalue *ivalues, int nivalues);
217217

218+
ivalue ati_clone(ivalue);
218219
void ati_free(ivalue);
219220

220221
void at_set_graph_executor_optimize(bool);

torch-sys/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ extern "C" {
238238
pub fn ati_to_tensor_list(arg: *mut CIValue, outputs: *mut *mut C_tensor, n: c_int);
239239
pub fn ati_to_string(arg: *mut CIValue) -> *mut c_char;
240240

241+
pub fn ati_clone(arg: *mut CIValue) -> *mut CIValue;
241242
pub fn ati_free(arg: *mut CIValue);
242243

243244
pub fn ati_object_method_(

0 commit comments

Comments
 (0)