forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_c10_kernel.cpp
More file actions
74 lines (56 loc) · 1.84 KB
/
test_c10_kernel.cpp
File metadata and controls
74 lines (56 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <cstdint>
#include <string_view>

#include <ATen/core/op_registration/op_registration.h>
#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/torch.h>
namespace torch::nativert {
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
return a + b;
}
TEST(C10KernelTest, computeInternal) {
auto registrar = c10::RegisterOperators().op(
"test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);
static constexpr std::string_view source =
R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";
auto graph = stringToGraph(source);
const auto& nodes = graph->nodes();
auto it = nodes.begin();
std::advance(it, 1);
const Node& node = *it;
auto a = at::randn({6, 6, 6});
auto b = at::randn({6, 6, 6});
auto frame = ExecutionFrame(*graph);
frame.setIValue(graph->getValue("a")->id(), a);
frame.setIValue(graph->getValue("b")->id(), b);
auto kernel = C10Kernel(&node);
kernel.computeInternal(frame);
at::Tensor expected = a + b;
EXPECT_TRUE(
torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
}
TEST(ScalarBinaryOpKernelTest, computeInternal) {
static constexpr std::string_view source =
R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";
auto graph = stringToGraph(source);
const auto& nodes = graph->nodes();
auto it = nodes.begin();
std::advance(it, 1);
const Node& node = *it;
auto a = 1;
auto b = 2;
auto frame = ExecutionFrame(*graph);
frame.setIValue(graph->getValue("a")->id(), a);
frame.setIValue(graph->getValue("b")->id(), b);
auto kernel = ScalarBinaryOpKernel(&node);
kernel.computeInternal(frame);
auto expected = a + b;
EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), expected);
}
} // namespace torch::nativert