neuralpredictors/data/transforms.py (11 changes: 8 additions & 3 deletions)
@@ -316,7 +316,9 @@ class NeuroNormalizer(MovieTransform, StaticTransform, Invertible):
    1% of the mean std (to avoid division by 0)
    """

-    def __init__(self, data, stats_source="all", exclude=None, inputs_mean=None, inputs_std=None):
+    def __init__(
+        self, data, stats_source="all", exclude=None, inputs_mean=None, inputs_std=None, subtract_behavior_mean=False
+    ):

        self.exclude = exclude or []

@@ -362,10 +364,13 @@ def __init__(self, data, stats_source="all", exclude=None, inputs_mean=None, inp
if "behavior" in data.data_keys:
s = np.array(data.statistics["behavior"][stats_source]["std"])

self.behavior_mean = (
0 if not subtract_behavior_mean else np.array(data.statistics["behavior"][stats_source]["mean"])
)
self._behavior_precision = 1 / s
# -- behavior
transforms["behavior"] = lambda x: x * self._behavior_precision
itransforms["behavior"] = lambda x: x / self._behavior_precision
transforms["behavior"] = lambda x: (x - self.behavior_mean) * self._behavior_precision
itransforms["behavior"] = lambda x: x / self._behavior_precision + self.behavior_mean

self._transforms = transforms
self._itransforms = itransforms
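For context, a minimal sketch (not part of the diff) of what the new subtract_behavior_mean flag changes: with the flag off, behavior is only divided by its std; with it on, the mean is subtracted first, and the inverse transform adds it back. The statistics and behavior values below are made up for illustration; the real class reads them from data.statistics["behavior"][stats_source].

import numpy as np

# made-up stand-ins for data.statistics["behavior"][stats_source]["mean"] / ["std"]
behavior_mean = np.array([2.0, 0.5, 10.0])
behavior_std = np.array([1.5, 0.2, 4.0])
behavior = np.array([3.5, 0.7, 14.0])

precision = 1 / behavior_std

# previous transform: scale only
scaled = behavior * precision

# transform with subtract_behavior_mean=True: center, then scale
centered_scaled = (behavior - behavior_mean) * precision

# the inverse transform undoes both steps
recovered = centered_scaled / precision + behavior_mean
assert np.allclose(recovered, behavior)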
neuralpredictors/layers/modulators/mlp.py (94 changes: 94 additions & 0 deletions, new file)
@@ -0,0 +1,94 @@
import warnings
from collections import OrderedDict

import torch
from torch import nn
from torch.nn.init import xavier_normal_


class MLP(nn.Module):
def __init__(
self, neurons, input_channels=3, hidden_channels=10, layers=2, bias=True, n_parameters_to_modulate=1, **kwargs
):
super().__init__()
warnings.warn("Ignoring input {} when creating {}".format(repr(kwargs), self.__class__.__name__))
self.n_parameters_to_modulate = n_parameters_to_modulate
self.modulator_networks = nn.ModuleList()
for _ in range(self.n_parameters_to_modulate):

prev_output = input_channels
feat = []
for _ in range(layers - 1):
feat.extend([nn.Linear(prev_output, hidden_channels, bias=bias), nn.ReLU()])
prev_output = hidden_channels

feat.extend([nn.Linear(prev_output, neurons, bias=bias), nn.ReLU()])
self.modulator_networks.append(nn.Sequential(*feat))

    def regularizer(self):
        # L1 penalty: mean absolute weight over the linear layers of all modulator networks
        weights = [m.weight.abs().mean() for net in self.modulator_networks for m in net if isinstance(m, nn.Linear)]
        return torch.stack(weights).mean()

    def initialize(self):
        # xavier-initialize the weight of every linear layer in the modulator networks
        for module in self.modules():
            if isinstance(module, nn.Linear):
                xavier_normal_(module.weight)

def forward(self, x, behavior):
mods = []
for network in self.modulator_networks:
# Make modulation positive. Exponential would result in exploding modulation -> Elu+1
mods.append(nn.functional.elu(network(behavior)) + 1)
mods = mods[0] if self.n_parameters_to_modulate == 1 else torch.stack(mods)
return x * mods


class StaticModulator(torch.nn.ModuleDict):
_base_modulator = None

def __init__(
self,
n_neurons,
input_channels=3,
hidden_channels=5,
layers=2,
gamma_modulator=0,
bias=True,
n_parameters_to_modulate=1,
**kwargs
):
warnings.warn("Ignoring input {} when creating {}".format(repr(kwargs), self.__class__.__name__))
super().__init__()
self.gamma_modulator = gamma_modulator
for k, n in n_neurons.items():
if isinstance(input_channels, OrderedDict):
ic = input_channels[k]
else:
ic = input_channels
self.add_module(
k,
self._base_modulator(
n, ic, hidden_channels, layers=layers, bias=bias, n_parameters_to_modulate=n_parameters_to_modulate
),
)

    def initialize(self):
        for modulator in self.values():
            modulator.initialize()

def regularizer(self, data_key):
return self[data_key].regularizer() * self.gamma_modulator


class MLPModulator(StaticModulator):
_base_modulator = MLP


def NoModulator(*args, **kwargs):
    """
    Dummy constructor that accepts and ignores any arguments and returns None,
    signalling that no modulator is used.

    Args:
        *args: will be ignored
        **kwargs: will be ignored

    Returns:
        None
    """
    return None
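A hedged usage sketch of the new modulator classes follows; the data key name, neuron count, batch size, and number of behavior channels are made-up examples, and the random tensors stand in for readout outputs and behavioral covariates.

from collections import OrderedDict

import torch

# one hypothetical session with 200 neurons and 3 behavior channels (e.g. pupil size, its derivative, running speed)
n_neurons = OrderedDict(session1=200)
modulator = MLPModulator(n_neurons, input_channels=3, hidden_channels=5, layers=2, gamma_modulator=0.1)
modulator.initialize()

responses = torch.rand(16, 200)  # stand-in for readout outputs, batch of 16
behavior = torch.rand(16, 3)     # behavioral covariates for the same batch

modulated = modulator["session1"](responses, behavior)  # elementwise multiplicative modulation, shape (16, 200)
penalty = modulator.regularizer("session1")             # L1 weight penalty scaled by gamma_modulator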