# RFC-0011-InferenceMode.md

Lines changed: 7 additions & 36 deletions
In this RFC we introduce the following new concepts:

**Inference mode**: a thread-local state that can be turned on via an RAII guard/context manager. (Either you are in inference mode, or you are not.) Intuitively, inference mode lets you run inference-only operations with better performance than normal mode.

- No operation creates an autograd graph, even if the inputs have requires_grad=True.
- Setting requires_grad in inference mode will update the requires_grad field on tensors, but it doesn't affect any behavior inside InferenceMode.
- Things that continue to work:
  - Inplace operations on both normal and inference tensors are OK.
    - An inplace operation on an inference tensor is guaranteed not to bump its version counter (VC).
    - NB: if you do an inplace operation on a normal tensor, you WILL get a version counter bump.
  - View operations on both normal and inference tensors are OK.
    - A view operation on an inference tensor is guaranteed not to allocate view metadata.
    - A view operation on a normal tensor produces a normal tensor (creation_meta=NO_GRAD_FN); the behavior is the same as creating a view inside NoGrad mode.
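A minimal sketch of these rules in Python, using the public `torch.inference_mode` context manager (tensor shapes and names are illustrative):

```
import torch

x = torch.ones(2, 3, requires_grad=True)  # normal tensor, created outside

with torch.inference_mode():
    y = x * 2                  # no autograd graph is recorded
    print(y.grad_fn)           # None, even though x.requires_grad is True
    print(y.is_inference())    # True: y is an inference tensor

    y.add_(1)                  # inplace on an inference tensor: OK, no version counter bump
    v = y.view(3, 2)           # view of an inference tensor: OK, no view metadata allocated

    y.requires_grad_(True)     # updates the requires_grad field...
    z = y + 1
    print(z.grad_fn)           # ...but still None: behavior inside InferenceMode is unchanged
```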
**Inference tensors** are tensors that are constructed if and only if inference mode is enabled, with the exception of views on normal tensors. Non-inference tensors are called **normal tensors**.

* Q: Why not views on normal tensors? A: Because we guarantee performance on inference tensors, but views on normal tensors require additional safety checks (e.g. normal tensor --(view)--> --(inplace)--> this should properly bump the version on the base, which requires the view op to produce a normal tensor).
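A small illustration of the definition, assuming the public `Tensor.is_inference()` query:

```
import torch

normal = torch.randn(4)           # created outside inference mode: normal tensor

with torch.inference_mode():
    inf = torch.randn(4)          # created inside: inference tensor
    print(inf.is_inference())     # True
    v = normal.view(2, 2)         # the exception: a view of a normal tensor
    print(v.is_inference())       # False: still a normal tensor

print(normal.is_inference())      # False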
* Impl: functional ops on normal tensors are allowed because we cannot conveniently ban them (the VariableType/InplaceOrView kernels are all skipped).
* Mixing inference and normal tensors, even for functional operations, is forbidden.
  * Why? For simplicity of implementation. In particular, if you save an inference tensor for backwards, you're likely to hit an error in a weird place (better to error early). By forbidding mixed operations, it is impossible for this situation to occur.
  * Impl: inference tensors are guaranteed not to have AutogradMeta.
  * Impl: inference tensors are guaranteed to have is_leaf=True.
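A sketch of the mixing rule in Python (the exact error message is an implementation detail and may vary across versions):

```
import torch

with torch.inference_mode():
    inf = torch.ones(3)        # inference tensor

print(inf.is_leaf)             # True: inference tensors are always leaves
print(inf.grad_fn)             # None: no AutogradMeta, so no graph node

normal = torch.ones(3, requires_grad=True)
try:
    _ = inf + normal           # mixing inference and normal tensors is forbidden
except RuntimeError as e:
    print("rejected:", e)
```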
**Normal tensor** has both Autograd & InplaceOrView keys. This includes both `requires_grad=true` and `requires_grad=false` tensors (see the [Ideal end state] section for more details).

Additional notes:
- All inference tensors are created in inference mode, but not all tensors created in inference mode are inference tensors. For example, a view of a normal tensor created in inference mode is still a normal tensor (but with special `creation_meta=NO_GRAD_FN`!).
- (Autograd & !InplaceOrView) and (!Autograd & InplaceOrView) are invalid states; we don't have such tensors.
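A sketch of the `NO_GRAD_FN` note: such a view behaves like one created under `torch.no_grad()`, so modifying it inplace with grad mode enabled raises (the exact error wording is version-dependent):

```
import torch

base = torch.randn(4, requires_grad=True)  # normal tensor

with torch.inference_mode():
    v = base.view(2, 2)       # still a normal tensor, with special creation_meta
    print(v.is_inference())   # False

try:
    v.mul_(2)                 # inplace on such a view outside inference mode is rejected,
except RuntimeError as e:     # just like a view created under torch.no_grad()
    print("rejected:", e)
```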
# Expected Behavior

## Implementation:

1. Inference mode: InplaceOrView is not in the included key set; Autograd is in the excluded key set.
2. Normal mode: InplaceOrView is in the included key set; Autograd is not in the excluded key set.
3. In the VariableType kernel, throw an error if an input is an inference tensor.
4. In the InplaceOrView kernel, throw an error if the Autograd keyset is not already in the excluded set.
5. In the VariableType kernel, throw an error if an input is a view with `NO_VARIABLE_TYPE_VIEW` creation_meta.
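Behaviorally, checks (3) and (4) mean that view and inplace ops on inference tensors fail in normal mode; a sketch of the error rows in the table below (error text may differ across versions):

```
import torch

with torch.inference_mode():
    inf = torch.ones(2, 2)     # inference tensor

# NormalMode + inference tensor + view/inplace: ERROR4 in the table below
for op in (lambda t: t.view(4), lambda t: t.add_(1)):
    try:
        op(inf)
    except RuntimeError as e:
        print("rejected:", e)
```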
## Behavior

| Mode | Input | Op | Goes through kernels | Produced output |
|------|-------|----|----------------------|-----------------|
| InferenceMode | All inference tensors | functional | CPU | inference tensor |
| InferenceMode | All inference tensors | view | CPU | inference tensor |
| InferenceMode | All inference tensors | inplace | CPU | inference tensor |
| InferenceMode | Contains normal tensor | functional | InplaceOrView (fallthrough), CPU | inference tensor |
| InferenceMode | Contains normal tensor | view | InplaceOrView, CPU | normal tensor (with creation_meta=NO_VARIABLE_TYPE_VIEW) |
| InferenceMode | Contains normal tensor | inplace | InplaceOrView, CPU | normal tensor (the input itself, with an updated version) |
| NormalMode | All inference tensors | functional | InplaceOrView (fallthrough), CPU | normal tensor (see note*) |
| NormalMode | All inference tensors | view | InplaceOrView (ERROR4!), CPU | — |
| NormalMode | All inference tensors | inplace | InplaceOrView (ERROR4!), CPU | — |
| NormalMode | Mixed normal and inference tensors | functional | VariableType (ERROR3!), InplaceOrView, CPU | — |
| NormalMode | Mixed normal and inference tensors | view | VariableType (ERROR3!), InplaceOrView, CPU | — |
| NormalMode | Mixed normal and inference tensors | inplace | VariableType (ERROR3!), InplaceOrView, CPU | — |

## Additional notes:

1. ERROR3 means the op hits check (3) in the implementation section, and ERROR4 means it hits check (4).
2. Functional ops on inference tensors might run slower outside InferenceMode than inside, but that's fine; we don't care much about the performance of this case.
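A sketch of the NormalMode + all-inference-tensors + functional row, which succeeds and produces a normal tensor:

```
import torch

with torch.inference_mode():
    a = torch.ones(3)
    b = torch.ones(3)

c = a + b                   # functional op on inference tensors in normal mode: allowed
print(c.is_inference())     # False: the output is a normal tensor (see note* above)
```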
## Alternative implementations we've considered and why they don't work:

1. For NormalMode + all inference tensors + functional op, an alternative behavior we'd prefer but didn't implement is throwing an error by forcing the op to go through the VariableType kernel and hit the assert_no_inference_tensor check. But to do that we'd have to add c10::autograd_dispatch_keyset to the globally enabled set, and doing that might accidentally call an autograd kernel from a backend that doesn't match the tensor input. Thus we allow functional ops to run without throwing an error.

2. Why implementations (1) and (2)?
```
// 1. When InferenceMode is enabled, Autograd dispatch keys are excluded
// but not InplaceOrView key.
// ...
// broke our invariant: "Autograd keys must be in excluded set before
// reaching InplaceOrView kernel".
```
# Ideal end state

The ideal end state is that we can skip the VariableType kernel when requires_grad=False, which means we don't always go through the VariableType kernel in normal mode. But this work is currently blocked for the following reason: