Fix dimensions of output array when n_parameters_to_modulate is equal to 1
kklur committed Mar 1, 2023
commit a64fb3dafd46184eedafa41b9cf5fb9af45f2e56
5 changes: 3 additions & 2 deletions neuralpredictors/layers/modulators/mlp.py
@@ -12,8 +12,9 @@ def __init__(
     ):
         super().__init__()
         warnings.warn("Ignoring input {} when creating {}".format(repr(kwargs), self.__class__.__name__))
+        self.n_parameters_to_modulate = n_parameters_to_modulate
         self.modulator_networks = nn.ModuleList()
-        for _ in range(n_parameters_to_modulate):
+        for _ in range(self.n_parameters_to_modulate):

             prev_output = input_channels
             feat = []
@@ -36,7 +37,7 @@ def forward(self, x, behavior):
         for network in self.modulator_networks:
             # Make modulation positive. Exponential would result in exploding modulation -> Elu+1
             mods.append(nn.functional.elu(network(behavior)) + 1)
-        mods = torch.stack(mods)
+        mods = mods[0] if self.n_parameters_to_modulate == 1 else torch.stack(mods)
         return x * mods
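Why the single-parameter case needed special handling, as a minimal sketch (not part of the PR; the tensor shapes below are illustrative assumptions, not taken from the repository): torch.stack on a one-element list prepends a dimension of size 1, so the stacked modulation no longer broadcasts against x the way a single network output does. Returning mods[0] in that case keeps the output shape identical to x.

import torch

# Hypothetical shapes for illustration only: a batch of 16 samples, 7 modulated outputs.
x = torch.randn(16, 7)                                  # response to be modulated
mod = torch.nn.functional.elu(torch.randn(16, 7)) + 1   # stand-in for one modulator network's positive output (Elu+1)

# Previous behaviour: stacking a one-element list adds a leading dimension of size 1,
# so the modulated output broadcasts to shape (1, 16, 7) instead of (16, 7).
stacked = torch.stack([mod])
print((x * stacked).shape)   # torch.Size([1, 16, 7])

# Patched behaviour: with n_parameters_to_modulate == 1 the tensor is used directly,
# so the output keeps the same shape as x.
print((x * mod).shape)       # torch.Size([16, 7])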