Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add possibility to subtract behavior mean in NeuroNormalizer
  • Loading branch information
kklur committed Feb 27, 2023
commit 3f4baf65242860b330838810c7259ddf120c7ef5
11 changes: 8 additions & 3 deletions neuralpredictors/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ class NeuroNormalizer(MovieTransform, StaticTransform, Invertible):
1% of the mean std (to avoid division by 0)
"""

def __init__(self, data, stats_source="all", exclude=None, inputs_mean=None, inputs_std=None):
def __init__(
self, data, stats_source="all", exclude=None, inputs_mean=None, inputs_std=None, subtract_behavior_mean=False
):

self.exclude = exclude or []

Expand Down Expand Up @@ -362,10 +364,13 @@ def __init__(self, data, stats_source="all", exclude=None, inputs_mean=None, inp
if "behavior" in data.data_keys:
s = np.array(data.statistics["behavior"][stats_source]["std"])

self.behavior_mean = (
0 if not subtract_behavior_mean else np.array(data.statistics["behavior"][stats_source]["mean"])
)
self._behavior_precision = 1 / s
# -- behavior
transforms["behavior"] = lambda x: x * self._behavior_precision
itransforms["behavior"] = lambda x: x / self._behavior_precision
transforms["behavior"] = lambda x: (x - self.behavior_mean) * self._behavior_precision
itransforms["behavior"] = lambda x: x / self._behavior_precision + self.behavior_mean

self._transforms = transforms
self._itransforms = itransforms
Expand Down