forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_safeguard.py
More file actions
47 lines (43 loc) · 2.05 KB
/
_safeguard.py
File metadata and controls
47 lines (43 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# mypy: allow-untyped-defs
import torch
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode
class AutogradStateOpsFailSafeguard(TorchFunctionMode):
"""
Detect grad state ops during exporting the graph and fail the process by
raising an error, to avoid unexpected behavior. Those grad mode ops could be:
`torch.no_grad`
`torch.enable_grad`
`torch.set_grad_enabled`
Export with predispatch mode is exempted.
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
unsupported_grad_mode_ops = [
torch._C._set_grad_enabled,
]
# It's only enabled while tracing, by confirming the torch dispatch mode is
# any active PROXY. This is to allow the autograd ops out of tracing.
current_state = torch._C.is_grad_enabled()
if func in unsupported_grad_mode_ops:
if len(args) != 1:
raise AssertionError(
f"Expected exactly 1 argument for grad mode op, but got {len(args)}"
)
changed_state = args[0]
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
# Intend to check if it's not the pre_dispatch mode. It's allowed to use
# autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
if (
mode
and isinstance(mode, ProxyTorchDispatchMode)
and not mode.pre_dispatch
and changed_state != current_state
):
raise RuntimeError(
f"Encountered autograd state manager op {func} trying to change global autograd state "
"while exporting. This is unsafe because we don't capture this op in torch.export "
"today, hence we can't reflect the user intention soundly. You can fix this by "
"adding a torch.no_grad() context around the export call."
)
return func(*args, **kwargs)