Description
The gather grad clause in Nx.Defn.Grad doesn't work with vectorized tensors. Unlike other ops where axis adjustment or devectorization fixes the issue, gather's indices are designed for the inner (non-vectorized) shape and become invalid when devectorized to the full shape.
x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.vectorize(:batch)
fun = fn x -> Nx.sum(Nx.gather(x, Nx.tensor([[0], [2]]))) end
Nx.Defn.grad(x, fun)
# ** (ArgumentError) cannot reshape, current shape {2} is not compatible with new shape {4, 3}

Devectorizing doesn't work because index 2 in [[0], [2]] is valid for the inner shape {3} but out of bounds for the devectorized shape {2, 3}, where it lands on axis 0 of size 2. The grad clause needs to be rewritten to work natively with vectorized tensors. Found while working on #1533; all other ops are fixed there, and gather is the remaining gap.
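For illustration only, here is a minimal sketch of what the indices would have to look like on the devectorized {2, 3} tensor. It assumes Nx ~> 0.7 on the default binary backend. Each inner index is turned into a full index by prepending its batch coordinate, and the gradient of the sum is then the scatter-add of ones at those positions. The names `batch`, `inner` and `full_idx` are hypothetical. This is a workaround sketch to show the shape mismatch, not the vectorized grad clause this issue asks for.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Devectorized view of the example input: shape {2, 3}
x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# Per-entry indices from the report, only valid for the inner shape {3}
idx = Nx.tensor([[0], [2]])

# Batch coordinate for every index row, shape {2, 2, 1}: [[[0], [0]], [[1], [1]]]
batch = Nx.iota({2, 2, 1}, axis: 0)
# The same inner indices repeated for each batch entry, shape {2, 2, 1}
inner = Nx.broadcast(idx, {2, 2, 1})
# Full indices into the {2, 3} tensor, shape {2, 2, 2}: [[[0, 0], [0, 2]], [[1, 0], [1, 2]]]
full_idx = Nx.concatenate([batch, inner], axis: 2)

# Forward pass: one row of gathered values per batch entry
gathered = Nx.gather(x, full_idx)
IO.inspect(Nx.to_list(gathered))
# [[1.0, 3.0], [4.0, 6.0]]

# Backward pass for sum(gather(x, idx)): scatter-add 1.0 at every gathered position
grad =
  Nx.indexed_add(
    Nx.broadcast(0.0, {2, 3}),
    Nx.reshape(full_idx, {4, 2}),
    Nx.broadcast(1.0, {4})
  )

IO.inspect(Nx.to_list(grad))
# [[1.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
```

The sketch makes the batch axis explicit in the indices. Reusing the original [[0], [2]] against the {2, 3} tensor is exactly what fails in the reported error.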