
Replace optional callbacks with Nx.block #946

@josevalim

Description

Introduce Nx.block as an extensible, named computation block API that preserves defn tracing/shape inference while allowing backend-specific overrides through a protocol implemented by user structs. The goal is to generalize today's optional mechanism (implemented through Nx.Shared.optional and optional callbacks in Nx.Backend) from backend-owned dispatch to user-extensible dispatch, so advanced kernels (for example, custom shared objects for native acceleration or backend-specific implementations) can be plugged in without patching Nx/EXLA source code.

Background: how optional and runtime_call work today

Nx currently has two similar escape hatches:

  1. Nx.Shared.optional/4 for backend-specific optional ops with fallback
  2. Nx.runtime_call for arbitrary runtime callbacks

They solve adjacent problems but with different trade-offs.

Nx.Shared.optional/4

Nx.Shared.optional(function_name, args, output, default_impl) dispatches in three steps:

  1. If the backend exports the operation itself (function_name/arity), call it
  2. Otherwise, if the backend exports optional/3, call backend.optional(function_name, args, default_impl)
  3. Otherwise, run default_impl.(args)

This is the current mechanism used by operations that may or may not be provided by a backend.
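The three-step lookup can be sketched as plain Elixir dispatch. This is a toy model, not the actual Nx source; FakeBackend and Dispatch are hypothetical names:

```elixir
# Hypothetical sketch of the dispatch order in Nx.Shared.optional/4.
defmodule FakeBackend do
  # This backend does not export :my_op/1, but does export optional/3,
  # so step 2 of the dispatch fires.
  def optional(_name, [x], default_impl), do: default_impl.([x * 10])
end

defmodule Dispatch do
  def optional(backend, name, args, default_impl) do
    cond do
      # 1. backend exports the operation itself
      function_exported?(backend, name, length(args)) ->
        apply(backend, name, args)

      # 2. backend exports a generic optional/3 hook
      function_exported?(backend, :optional, 3) ->
        backend.optional(name, args, default_impl)

      # 3. fall back to the default implementation
      true ->
        default_impl.(args)
    end
  end
end

Dispatch.optional(FakeBackend, :my_op, [2], fn [x] -> x + 1 end)
# => 21 (FakeBackend has no my_op/1, so optional/3 is called)
```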

At the expression level, Nx.Defn.Expr.optional/4 builds a node:

%Nx.Defn.Expr{
  op: :optional,
  args: [call_expr, default_expr, callback]
}

Nx.Defn.Evaluator then either:

  • executes the backend op if it is exported, or
  • evaluates the traced default expression

EXLA additionally handles :optional nodes in compiler-specific ways, including special handling for known operations and a generic fallback path.

Nx.runtime_call

Nx.runtime_call creates a :runtime_call expression node carrying:

%Nx.Defn.Expr{
  op: :runtime_call,
  args: [tensor_expr_or_container, callback_fun, out_template]
}

This is flexible, but it requires explicit output templates and does not infer output shape from a defn-traced callback body.
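A toy illustration of why the explicit template is required: the callback is opaque to tracing, so the caller must describe the output shape and type up front. The runtime_call closure below is a stand-in that mimics the expression node, not the actual Nx API:

```elixir
# Toy model: at trace time only the template is known; the callback runs later.
out_template = %{shape: {4, 4}, type: {:f, 32}}

runtime_call = fn _input, callback, template ->
  # the callback cannot be inspected to infer the output, so the
  # template travels alongside it in the expression node
  {:runtime_call_node, callback, template}
end

{:runtime_call_node, _cb, tmpl} =
  runtime_call.(:input, fn t -> t end, out_template)

tmpl.shape
# => {4, 4}
```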

Problems with the current approach

The issue ("Make optional/custom calls mechanism extensible") points to a gap between internal optional dispatch and user-level extensibility.

  • Backend-owned extensibility: only backend modules can provide custom implementations via exported ops or optional/3
  • Limited user integration: users cannot naturally attach per-call implementation structs/config to custom kernels inside defn
  • Two separate mental models: optional has fallback and defn tracing; runtime_call has callback flexibility but no shape inference
  • No unified named-block abstraction: there is no first-class operation representing "this named block may have a custom implementation, otherwise run traced default"

Proposed solution: Nx.block

Add a new API that represents a named, extensible computation block:

Nx.block(struct, container, fn container, struct -> ... end)

Where:

  • struct is a user-defined or Nx-defined struct that serves as both the block identity and its compile-time configuration. The struct module name replaces the need for a separate name atom and keyword options; backends pattern-match on the struct type to select specialized implementations.
  • container holds the tensor inputs (tuple/list/container). Only tensor-like, runtime values go here.
  • The anonymous function receives the container and the configuration struct, and provides the default (defn-compatible) implementation.

Static values vs tensors: Static configuration (like eps, mode, and other function-specific options) lives in the struct fields. Tensor inputs live in the container. This separation is important because static values are known at compile time and can influence compilation decisions, while tensors flow through the computation graph.
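As a concrete sketch, the static configuration could live in a plain struct. Axon.FlashAttention is the hypothetical module name used throughout this proposal, not an existing Axon module; the field names are taken from the examples in this issue:

```elixir
# Hypothetical block struct: static, compile-time configuration lives in
# the struct fields. Tensor inputs never appear here -- they go in the
# container argument of Nx.block.
defmodule Axon.FlashAttention do
  defstruct causal: false, block_size: 64
end

config = %Axon.FlashAttention{causal: true, block_size: 128}
config.block_size
# => 128
```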

API shape examples

Nx-defined linear algebra operation with static configuration:

Nx.block(%Nx.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a, fn a, %Nx.LinAlg.QR{} = config ->
  # default QR implementation provided by Nx.LinAlg
  ...
end)

User-defined extensible block (e.g., from a library like Axon):

Nx.block(%Axon.FlashAttention{causal: true, block_size: 128}, {q, k, v},
  fn {q, k, v}, %Axon.FlashAttention{} = config ->
    # default portable attention implementation
    scores = Nx.dot(q, [2], k, [2])
    scores |> Nx.softmax() |> Nx.dot(v)
  end)

Expected semantics of Nx.block

Nx.block is designed to combine strengths of both existing mechanisms:

  • Output inference: infer output shape, type and gradient via traced default implementation (defn-like behavior)
  • Struct-based identity: the struct module name serves as the block identity, so no separate name atom is needed; this is sufficient for debugging, compilation, and matching
  • Extensible dispatch: backends pattern-match on the struct type; user-defined structs can have backend protocol implementations provided by third parties
  • Safe fallback: if no backend-specific or protocol implementation is available, the default implementation runs

At the expression level:

%Nx.Defn.Expr{
  op: :block,
  args: [struct, args, default_implementation_fn]
}

The struct carries both the block identity (via its module) and static configuration (via its fields). Backends pattern-match on the struct, which is equivalent to today's pattern matching on the name of the optional call.

Backend-owned protocol model

Each backend can define its own protocol for extensibility over Nx.block.

For example, EXLA could define something similar to the following:

defprotocol EXLA.CustomCall do
  @doc "Returns a custom call implementation for the given block struct"
  def custom_call(struct, args, client)
end

Then a third-party library (e.g., Axon) or a user could provide a backend-specific implementation:

defimpl EXLA.CustomCall, for: Axon.FlashAttention do
  alias EXLA.MLIR.Value

  def custom_call(%Axon.FlashAttention{causal: causal, block_size: bs}, {%Value{} = q, %Value{} = k, %Value{} = v}, %EXLA.Client{} = client) do
    case client do
      %{platform: :host} -> Value.custom_call(..., [q, k, v])
      %{platform: :cuda} -> Value.custom_call(..., [q, k, v])
      %{platform: :rocm} -> Value.custom_call(..., [q, k, v])
    end
  end
end

Finally, the EXLA compiler could rescue Protocol.UndefinedError to fall back to its default behavior when no extensions are registered.
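The fallback check can also be sketched without exceptions: every Elixir protocol exposes impl_for/1, which returns nil when no implementation is registered for a struct. The protocol and module names below are toy illustrations, not EXLA itself:

```elixir
# Toy protocol standing in for a backend's extension protocol.
defprotocol CustomCall do
  def custom_call(struct, args)
end

defmodule Known do
  defstruct []
end

# A third party registers a specialized implementation for Known.
defimpl CustomCall, for: Known do
  def custom_call(%Known{}, args), do: {:custom, args}
end

defmodule Unknown do
  defstruct []
end

# The compiler asks whether an implementation exists before dispatching;
# nil means "run the traced default implementation".
compile = fn struct, args, default ->
  case CustomCall.impl_for(struct) do
    nil -> default.(args)
    impl -> impl.custom_call(struct, args)
  end
end

compile.(%Known{}, [1], fn args -> {:default, args} end)
# => {:custom, [1]}
compile.(%Unknown{}, [1], fn args -> {:default, args} end)
# => {:default, [1]}
```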

Nx.block vs optional vs runtime_call

| Aspect | optional (today) | runtime_call (today) | Nx.block (proposed) |
|---|---|---|---|
| Primary use | Backend optional ops | Runtime callback invocation | Named extensible compute block |
| Fallback implementation | Yes | Callback itself is primary | Yes (default impl fn) |
| Output inference | Yes (from expr path) | No (requires explicit template) | Yes (from traced default impl) |
| User-extensible via struct/protocol | No | Not protocol-based | Yes (via backend protocols) |
| Identity mechanism | Op atom (:qr, etc.) | No name | Struct module name |
| Static config | Keyword opts | N/A | Struct fields |
| Best for custom kernel config | Limited | Possible but shape-heavy | Yes |
