Skip to content

Commit cf95331

Browse files
Initial check-in of GPU XLA thunks dialect
PiperOrigin-RevId: 323077049 Change-Id: I4a9b32e6772aa342b74954d27edfdf43e4b40f62
1 parent 0943136 commit cf95331

File tree

8 files changed

+240
-3
lines changed

8 files changed

+240
-3
lines changed

tensorflow/compiler/mlir/runlit.cfg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@
7373
'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
7474
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
7575
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
76-
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'
76+
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
77+
'xla-thunks-opt'
7778
]
7879
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
7980
llvm_config.add_tool_substitutions(tools, tool_dirs)

tensorflow/compiler/xla/service/gpu/BUILD

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ load(
2727
"if_cuda_is_configured",
2828
)
2929
load("//tensorflow:tensorflow.bzl", "if_nccl")
30+
load("//third_party/mlir:tblgen.bzl", "gentbl")
3031

3132
package(
3233
default_visibility = [":friends"],
@@ -1875,3 +1876,49 @@ cc_library(
18751876
"@com_google_absl//absl/types:span",
18761877
],
18771878
)
1879+
1880+
gentbl(
1881+
name = "xla_thunks_ops_inc_gen",
1882+
tbl_outs = [
1883+
("-gen-op-decls", "ir/xla_thunks_ops.h.inc"),
1884+
("-gen-op-defs", "ir/xla_thunks_ops.cc.inc"),
1885+
("-gen-struct-attr-decls", "ir/xla_thunks_structs.h.inc"),
1886+
("-gen-struct-attr-defs", "ir/xla_thunks_structs.cc.inc"),
1887+
],
1888+
tblgen = "@llvm-project//mlir:mlir-tblgen",
1889+
td_file = "ir/xla_thunks_ops.td",
1890+
td_srcs = [
1891+
"@llvm-project//mlir:LLVMOpsTdFiles",
1892+
],
1893+
)
1894+
1895+
cc_library(
1896+
name = "xla_thunks_ops",
1897+
srcs = [
1898+
"ir/xla_thunks_ops.cc",
1899+
"ir/xla_thunks_ops.cc.inc",
1900+
"ir/xla_thunks_ops.h.inc",
1901+
],
1902+
hdrs = [
1903+
"ir/xla_thunks_ops.h",
1904+
],
1905+
deps = [
1906+
":xla_thunks_ops_inc_gen",
1907+
"//tensorflow/compiler/mlir/hlo",
1908+
"@llvm-project//mlir:IR",
1909+
"@llvm-project//mlir:LLVMDialect",
1910+
],
1911+
)
1912+
1913+
# Library with XLA thunks dialect static initialization.
1914+
cc_library(
1915+
name = "xla_thunks_dialect_registration",
1916+
srcs = [
1917+
"ir/dialect_registration.cc",
1918+
],
1919+
deps = [
1920+
":xla_thunks_ops",
1921+
"@llvm-project//mlir:IR",
1922+
],
1923+
alwayslink = 1,
1924+
)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h"
17+
18+
// Static initialization for GPU thunks op registration.
19+
static mlir::DialectRegistration<mlir::xla_thunks::XLAThunksDialect>
20+
xla_thunks_ops;
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// This file defines the operations used in the Thunk dialect.
17+
18+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h"
19+
20+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
21+
#include "mlir/IR/Builders.h" // from @llvm-project
22+
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
23+
#include "mlir/IR/StandardTypes.h" // from @llvm-project
24+
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
25+
26+
namespace mlir {
27+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_structs.cc.inc"
28+
namespace xla_thunks {
29+
30+
XLAThunksDialect::XLAThunksDialect(MLIRContext *context)
31+
: Dialect(getDialectNamespace(), context) {
32+
addOperations<
33+
#define GET_OP_LIST
34+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc"
35+
>();
36+
}
37+
38+
#define GET_OP_CLASSES
39+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc"
40+
41+
} // namespace xla_thunks
42+
} // namespace mlir
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_
17+
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_
18+
19+
#include "mlir/IR/Dialect.h" // from @llvm-project
20+
#include "mlir/IR/OpDefinition.h" // from @llvm-project
21+
#include "mlir/IR/OpImplementation.h" // from @llvm-project
22+
23+
namespace mlir {
24+
class OpBuilder;
25+
26+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_structs.h.inc"
27+
28+
namespace xla_thunks {
29+
30+
class XLAThunksDialect : public Dialect {
31+
public:
32+
explicit XLAThunksDialect(MLIRContext *context);
33+
static StringRef getDialectNamespace() { return "xla_thunks"; }
34+
};
35+
36+
#define GET_OP_CLASSES
37+
#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h.inc"
38+
39+
} // namespace xla_thunks
40+
} // namespace mlir
41+
42+
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// Operation definition file for GPU thunks.
17+
18+
#ifndef XLA_THUNKS_OPS
19+
#define XLA_THUNKS_OPS
20+
21+
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
22+
include "mlir/IR/OpBase.td"
23+
24+
class LLVMPointerTo<Type ty>
25+
: ContainerType<ty,
26+
CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isPointerTy()">,
27+
"$_self.cast<::mlir::LLVM::LLVMType>().getPointerElementTy()",
28+
"LLVM pointer">;
29+
30+
def XLAThunks_Dialect : Dialect {
31+
let name = "xla_thunks";
32+
let cppNamespace = "xla_thunks";
33+
}
34+
35+
class ThunkOp<string mnemonic, list<OpTrait> traits = []> :
36+
Op<XLAThunks_Dialect, mnemonic, traits>;
37+
38+
def AllocationSlice : StructAttr<"AllocationSlice", XLAThunks_Dialect, [
39+
StructFieldAttr<"allocation_index", I64Attr>,
40+
StructFieldAttr<"offset", I64Attr>,
41+
StructFieldAttr<"size", I64Attr>,
42+
]> {
43+
let description = "Defines a slice of an allocation for XLA thunk ops";
44+
}
45+
46+
def MemzeroThunkOp : ThunkOp<"execute_memzero_thunk"> {
47+
let arguments = (ins
48+
LLVMPointerTo<LLVMI<8>>:$execute_params,
49+
AllocationSlice:$allocation_slice
50+
);
51+
let results = (outs
52+
I<1>:$ok,
53+
LLVMPointerTo<LLVMI<8>>:$error_message
54+
);
55+
}
56+
57+
#endif // XLA_THUNKS_OPS

tensorflow/compiler/xla/service/gpu/tests/BUILD

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,15 +479,28 @@ glob_lit_tests(
479479
"no_pip",
480480
],
481481
driver = "@llvm-project//mlir:run_lit.sh",
482-
test_file_exts = ["hlo"],
482+
test_file_exts = [
483+
"hlo",
484+
"mlir",
485+
],
483486
)
484487

485488
# Bundle together all of the test utilities that are used by tests.
486489
filegroup(
487490
name = "test_utilities",
488491
testonly = True,
489492
data = [
490-
"//tensorflow/compiler/xla/service/gpu/tests:hlo_to_llvm_ir",
493+
":hlo_to_llvm_ir",
494+
":xla-thunks-opt",
491495
"@llvm-project//llvm:FileCheck",
492496
],
493497
)
498+
499+
# Binary with only the thunks dialect registered, for testing purposes.
500+
tf_cc_binary(
501+
name = "xla-thunks-opt",
502+
deps = [
503+
"//tensorflow/compiler/mlir:tf_mlir_opt_main",
504+
"//tensorflow/compiler/xla/service/gpu:xla_thunks_dialect_registration",
505+
],
506+
)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: xla-thunks-opt %s | FileCheck --color --dump-input=fail %s
2+
3+
func @main( %execute_params: !llvm<"i8*"> ) {
4+
// CHECK: "xla_thunks.execute_memzero_thunk"
5+
// CHECK-SAME: {allocation_index = 0 : i64, offset = 128 : i64, size = 1024 : i64}
6+
// CHECK-SAME: (!llvm<"i8*">) -> (i1, !llvm<"i8*">)
7+
%ok, %error_message =
8+
"xla_thunks.execute_memzero_thunk"( %execute_params )
9+
{ allocation_slice = { allocation_index = 0
10+
, offset = 128
11+
, size = 1024 } }
12+
: (!llvm<"i8*">) -> (i1, !llvm<"i8*">)
13+
return
14+
}
15+

0 commit comments

Comments
 (0)