Description
Introduce Nx.block as an extensible, named computation block API that preserves defn tracing/shape inference while allowing backend-specific overrides through a protocol implemented by user structs. The goal is to generalize today's optional mechanism (implemented through Nx.Shared.optional and optional callbacks in Nx.Backend) from backend-owned dispatch to user-extensible dispatch, so advanced kernels (for example, custom shared objects for native acceleration or backend-specific implementations) can be plugged in without patching Nx/EXLA source code.
## Background: how `optional` and `runtime_call` work today
Nx currently has two similar escape hatches:
- `Nx.Shared.optional/4` for backend-specific optional ops with a fallback
- `Nx.runtime_call` for arbitrary runtime callbacks

They solve adjacent problems, but with different trade-offs.
## `Nx.Shared.optional/4`
`Nx.Shared.optional(function_name, args, output, default_impl)` dispatches in three steps:

1. If the backend exports the operation itself (`function_name/arity`), call it
2. Else, if the backend exports `optional/3`, call `backend.optional(function_name, args, default_impl)`
3. Else, run `default_impl.(args)`
This is the current mechanism used by operations that may or may not be provided by a backend.
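The three-step dispatch can be sketched as follows (a simplified reconstruction for illustration, not the actual `Nx.Shared` source; the module name and clause shapes are assumptions):

```elixir
defmodule OptionalDispatchSketch do
  # Simplified sketch of the Nx.Shared.optional/4 dispatch order
  # (illustrative only; argument handling differs in the real code).
  def optional(function_name, [%Nx.Tensor{data: %backend{}} | _] = args, _output, default_impl) do
    cond do
      # 1. The backend implements the operation directly
      function_exported?(backend, function_name, length(args)) ->
        apply(backend, function_name, args)

      # 2. The backend provides the generic optional/3 hook
      function_exported?(backend, :optional, 3) ->
        backend.optional(function_name, args, default_impl)

      # 3. Fall back to the portable default implementation
      true ->
        default_impl.(args)
    end
  end
end
```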
At the expression level, `Nx.Defn.Expr.optional/4` builds a node:

```elixir
%Nx.Defn.Expr{
  op: :optional,
  args: [call_expr, default_expr, callback]
}
```

`Nx.Defn.Evaluator` then either:
- executes the backend op if it is exported, or
- evaluates the traced default expression
EXLA additionally handles :optional nodes in compiler-specific ways, including special handling for known operations and a generic fallback path.
## `Nx.runtime_call`
`Nx.runtime_call` creates a `:runtime_call` expression node carrying:

```elixir
%Nx.Defn.Expr{
  op: :runtime_call,
  args: [tensor_expr_or_container, callback_fun, out_template]
}
```

This is flexible, but it requires explicit output templates and does not infer the output shape from a defn-traced callback body.
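In practice that means the caller must hand-build the output template, along the following lines (a hypothetical usage sketch; `MyNative.fancy_kernel` and the template shape are made-up placeholders, and the argument order follows the expression node above):

```elixir
# The output template must be supplied explicitly, because
# runtime_call cannot infer shape/type from the callback body.
out_template = Nx.template({4, 4}, :f32)

result =
  Nx.runtime_call(input, fn tensor ->
    # arbitrary host-side computation (hypothetical)
    MyNative.fancy_kernel(tensor)
  end, out_template)
```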
## Problems with the current approach
The issue ("Make optional/custom calls mechanism extensible") points to a gap between internal optional dispatch and user-level extensibility.
- Backend-owned extensibility: only backend modules can provide custom implementations via exported ops or `optional/3`
- Limited user integration: users cannot naturally attach per-call implementation structs/config to custom kernels inside `defn`
- Two separate mental models: `optional` has fallback and defn tracing; `runtime_call` has callback flexibility but no shape inference
- No unified named-block abstraction: there is no first-class operation representing "this named block may have a custom implementation, otherwise run the traced default"
## Proposed solution: `Nx.block`
Add a new API that represents a named, extensible computation block:
```elixir
Nx.block(struct, container, fn container, struct -> ... end)
```

Where:

- `struct` is a user-defined or Nx-defined struct that serves as both the block identity and its compile-time configuration. The struct module name replaces the need for a separate name and keyword arguments; backends pattern-match on the struct type to select specialized implementations
- `container` holds the tensor inputs (tuple/list/container). Only tensor-like, runtime values go here
- The anonymous function receives the container and the configuration struct, and provides the default (defn-compatible) implementation
Static values vs tensors: Static configuration (like eps, mode, and other function-specific options) lives in the struct fields. Tensor inputs live in the container. This separation is important because static values are known at compile time and can influence compilation decisions, while tensors flow through the computation graph.
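Under this split, a block's configuration struct could be an ordinary Elixir struct (a sketch; the module name matches the Axon example later in this proposal, and the defaults are assumptions):

```elixir
# Static, compile-time options live in the struct fields;
# tensors are passed separately in the container argument.
defmodule Axon.FlashAttention do
  defstruct causal: false, block_size: 128
end
```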
## API shape examples
Nx-defined linear algebra operation with static configuration:
```elixir
Nx.block(%Nx.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a, fn a, %Nx.LinAlg.QR{} = config ->
  # default QR implementation provided by Nx.LinAlg
  ...
end)
```

A user-defined extensible block (e.g., from a library like Axon):
```elixir
Nx.block(%Axon.FlashAttention{causal: true, block_size: 128}, {q, k, v}, fn {q, k, v}, %Axon.FlashAttention{} = config ->
  # default portable attention implementation
  scores = Nx.dot(q, [2], k, [2])
  Nx.softmax(scores) |> Nx.dot(v)
end)
```

## Expected semantics of `Nx.block`
Nx.block is designed to combine strengths of both existing mechanisms:
- Output inference: infer output shape, type and gradient via traced default implementation (defn-like behavior)
- Struct-based identity: the struct module name serves as the block identity — no separate name atom needed. This is sufficient for debugging, compilation, and matching
- Extensible dispatch: backends pattern-match on the struct type; user-defined structs can have backend protocol implementations provided by third parties
- Safe fallback: if no backend-specific or protocol implementation is available, the default implementation runs
At the expression level:

```elixir
%Nx.Defn.Expr{
  op: :block,
  args: [struct, args, default_implementation_fn]
}
```

The struct carries both the block identity (via its module) and static configuration (via its fields). Backends pattern-match on the struct, which is equivalent to today's pattern matching on the name of the optional call.
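An evaluator or compiler encountering a `:block` node could then dispatch much like `optional` does today (a sketch; `backend_impl_for/1` and `eval/1` are hypothetical helpers, not existing Nx functions):

```elixir
# Sketch of how an evaluator might handle a :block expression node.
def eval_block(%Nx.Defn.Expr{op: :block, args: [struct, args, default_fn]}) do
  case backend_impl_for(struct) do
    # A backend/protocol implementation exists for this struct type
    {:ok, impl} ->
      impl.(struct, args)

    # No specialized implementation: evaluate the traced default
    :error ->
      eval(default_fn.(args, struct))
  end
end
```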
## Backend-owned protocol model
Each backend can define its own protocol for extensibility over Nx.block.
For example, EXLA could define something similar to the following:
```elixir
defprotocol EXLA.CustomCall do
  @doc "Returns a custom call implementation for the given block struct"
  def custom_call(struct, args, client)
end
```

Then a third-party library (e.g., Axon) or a user could provide a backend-specific implementation:
```elixir
defimpl EXLA.CustomCall, for: Axon.FlashAttention do
  alias EXLA.MLIR.Value

  def custom_call(%Axon.FlashAttention{causal: causal, block_size: bs}, {%Value{} = q, %Value{} = k, %Value{} = v}, %EXLA.Client{} = client) do
    case client do
      %{platform: :host} ->
        Value.custom_call(..., [q, k, v])

      %{platform: :cuda} ->
        Value.custom_call(..., [q, k, v])

      %{platform: :rocm} ->
        Value.custom_call(..., [q, k, v])
    end
  end
end
```
Finally, the EXLA compiler could catch Protocol.UndefinedError to fall back to its default behavior when no extensions are used.
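That fallback could be as simple as the following (a sketch; `compile_default/2` is a hypothetical helper standing in for EXLA's existing default compilation path):

```elixir
# Attempt the protocol dispatch; fall back to the traced default
# implementation when no defimpl exists for the struct's module.
def compile_block(struct, args, client, default_fn) do
  EXLA.CustomCall.custom_call(struct, args, client)
rescue
  Protocol.UndefinedError ->
    compile_default(default_fn, args)
end
```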
## `Nx.block` vs `optional` vs `runtime_call`
| Aspect | `optional` (today) | `runtime_call` (today) | `Nx.block` (proposed) |
|---|---|---|---|
| Primary use | Backend optional ops | Runtime callback invocation | Named extensible compute block |
| Fallback implementation | Yes | Callback itself is primary | Yes (default impl fn) |
| Output inference | Yes (from expr path) | No (requires explicit template) | Yes (from traced default impl) |
| User-extensible via struct/protocol | No | Not protocol-based | Yes (via backend protocols) |
| Identity mechanism | Op atom (`:qr`, etc.) | No name | Struct module name |
| Static config | Keyword opts | N/A | Struct fields |
| Best for custom kernel config | Limited | Possible but shape-heavy | Yes |
## References
- Nx, EXLA and Torchx source (current behavior):
  - `nx/lib/nx/shared.ex`
  - `nx/lib/nx/defn/expr.ex`
  - `nx/lib/nx/defn/tree.ex`
  - `nx/lib/nx/defn/evaluator.ex`
  - `exla/lib/exla/defn.ex`
  - `torchx/lib/torchx/backend.ex`
- Related APIs: