Vectorized gather grad needs native vectorization support #1698

@blasphemetheus

The gather grad clause in Nx.Defn.Grad doesn't work with vectorized tensors. Unlike other ops where axis adjustment or devectorization fixes the issue, gather's indices are designed for the inner (non-vectorized) shape and become invalid when devectorized to the full shape.

```elixir
x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.vectorize(:batch)
fun = fn x -> Nx.sum(Nx.gather(x, Nx.tensor([[0], [2]]))) end
Nx.Defn.grad(x, fun)
# ** (ArgumentError) cannot reshape, current shape {2} is not compatible with new shape {4, 3}
```

Devectorizing doesn't work because the index 2 in [[0], [2]] is valid for the inner shape {3} but out of bounds for the devectorized shape {2, 3}: the one-element index rows then address the leading axis, which has size 2. The grad clause needs to be rewritten to work natively with vectorized tensors. Found while working on #1533; all other ops are fixed, and gather is the remaining gap.
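To make the index mismatch concrete, here is a minimal sketch on the devectorized tensor. It is not the actual fix in Nx.Defn.Grad, only an illustration of one possible approach: prepend a batch coordinate to every inner index so the indices address the full {2, 3} shape, then scatter ones back to get the gradient of the summed gather. The names `inner_idx`, `batch_idx`, `tiled_idx`, `full_idx`, `gathered` and `grad_x` are hypothetical, and the `Mix.install` version requirement is an assumption.

```elixir
# Assumption: run as a standalone script; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.9"}])

# Devectorized form of the tensor from the report: shape {2, 3}.
x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Indices written against the inner (non-vectorized) shape {3}.
inner_idx = Nx.tensor([[0], [2]])

# Hypothetical rewrite: pair every inner index with its batch coordinate,
# giving full indices of shape {2, 2, 2} that are valid for shape {2, 3}.
batch_idx = Nx.iota({2, 2, 1}, axis: 0)
tiled_idx = Nx.broadcast(inner_idx, {2, 2, 1})
full_idx = Nx.concatenate([batch_idx, tiled_idx], axis: 2)

# Per batch entry this picks positions 0 and 2:
# [[1.0, 3.0], [4.0, 6.0]]
gathered = Nx.gather(x, full_idx)

# The gradient of sum(gather) adds 1.0 at every gathered position:
# [[1.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
grad_x =
  Nx.indexed_add(
    Nx.broadcast(0.0, {2, 3}),
    Nx.reshape(full_idx, {4, 2}),
    Nx.broadcast(1.0, {4})
  )
```

A native vectorized grad clause would presumably need to produce the same `grad_x` without the caller building `full_idx` by hand.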
