Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.
Merged
84 changes: 8 additions & 76 deletions bluepysnap/node_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,84 +20,13 @@
For more information see:
https://github.com/AllenInstitute/sonata/blob/master/docs/SONATA_DEVELOPER_GUIDE.md#node-sets-file
"""
from collections.abc import Mapping
from copy import deepcopy

import numpy as np
import libsonata

from bluepysnap import utils
from bluepysnap.exceptions import BluepySnapError


def _sanitize(node_set):
"""Sanitize standard node set (not compounds).

Set a single value instead of a one element list.
Sorted and unique values for the lists of values.

Args:
node_set (Mapping): A standard non compound node set.

Return:
map: The sanitized node set.
"""
for key, values in node_set.items():
if isinstance(values, list):
if len(values) == 1:
node_set[key] = values[0]
else:
# sorted unique value list
node_set[key] = np.unique(np.asarray(values)).tolist()
return node_set


def _resolve_set(content, resolved, node_set_name):
"""Resolve the node set 'node_set_name' from content.

The resolved node set is returned and the resolved dict is updated in place with the
resolved node set.

Args:
content (dict): the global dictionary containing all unresolved node sets.
resolved (dict): the global resolved dictionary containing the already resolved node sets.
node_set_name (str): the name of the current node set to resolve.

Returns:
dict: the resolved node set.

Notes:
If the node set is a compound node set then all the sub node sets are also resolved and
stored inside the resolved dictionary.
"""
if node_set_name in resolved:
# return already resolved node_sets
return resolved[node_set_name]

# keep the content intact
set_value = deepcopy(content.get(node_set_name))
if set_value is None:
raise BluepySnapError(f"Missing node_set: '{node_set_name}'")
if not isinstance(set_value, (Mapping, list)) or not set_value:
raise BluepySnapError(f"Ambiguous node_set: { {node_set_name: set_value} }")
if isinstance(set_value, Mapping):
resolved[node_set_name] = _sanitize(set_value)
return resolved[node_set_name]

# compounds only
res = [_resolve_set(content, resolved, sub_set_name) for sub_set_name in set_value]

resolved[node_set_name] = {"$or": res}
return resolved[node_set_name]


def _resolve(content):
"""Resolve all node sets in content."""
resolved = {}
for set_name in content:
_resolve_set(content, resolved, set_name)
return resolved


class NodeSets:
"""Access to node sets data."""

Expand All @@ -111,12 +40,15 @@ def __init__(self, filepath):
NodeSets: A NodeSets object.
"""
self.content = utils.load_json(filepath)
self.resolved = _resolve(self.content)
self._instance = libsonata.NodeSets.from_file(filepath)

def __getitem__(self, node_set_name):
def resolve_ids(self, node_set_name, population):
"""Get the resolved node set using name as key."""
return self.resolved[node_set_name]
try:
return self._instance.materialize(node_set_name, population).flatten()
except libsonata.SonataError as e:
raise BluepySnapError(*e.args) from e

def __iter__(self):
"""Iter through the different node sets names."""
return iter(self.resolved)
return iter(self._instance.names)
6 changes: 4 additions & 2 deletions bluepysnap/nodes/node_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,16 @@ def _get_node_set(self, node_set_name):
"""Returns the node set named 'node_set_name'."""
if node_set_name not in self._node_sets:
raise BluepySnapError(f"Undefined node set: '{node_set_name}'")
return self._node_sets[node_set_name]
return self._node_sets.resolve_ids(node_set_name, self._population)

def _resolve_nodesets(self, queries):
def _resolve(queries, queries_key):
if queries_key == query.NODE_SET_KEY:
if query.AND_KEY not in queries:
queries[query.AND_KEY] = []
queries[query.AND_KEY].append(self._get_node_set(queries[queries_key]))
queries[query.AND_KEY].append(
{query.NODE_ID_KEY: self._get_node_set(queries[queries_key])}
)
del queries[queries_key]

resolved_queries = deepcopy(queries)
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def simulator(self):
def node_sets(self):
"""Returns the NodeSets object bound to the simulation."""
node_sets_file = self.to_libsonata.node_sets_file
return NodeSets(node_sets_file) if node_sets_file else {}
return NodeSets(node_sets_file).content if node_sets_file else {}

@cached_property
def spikes(self):
Expand Down
7 changes: 0 additions & 7 deletions tests/data/node_sets.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
"Layer23": {
"layer": [2,3]
},
"Empty_nodes": {
"node_id": []
},
"Node2012": {
"node_id": [2,0,1,2]
},
Expand All @@ -23,10 +20,6 @@
"mtype": "L6_Y",
"node_id": [0]
},
"Empty_L6_Y": {
"node_id": [],
"mtype": "L6_Y"
},
"Population_default": {
"population": "default"
},
Expand Down
7 changes: 5 additions & 2 deletions tests/data/node_sets_file.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@
"population": "default",
"mtype": "L6_Y"
},
"combined": ["Node2_L6_Y", "Layer23"]
}
"combined": ["Node2_L6_Y", "Layer23"],
"failing": {
"unknown_property": [0]
}
}
59 changes: 14 additions & 45 deletions tests/test_node_sets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from unittest.mock import patch

import libsonata
import pytest
from numpy.testing import assert_array_equal

import bluepysnap.node_sets as test_module
from bluepysnap.exceptions import BluepySnapError
Expand All @@ -12,6 +14,9 @@
class TestNodeSets:
def setup_method(self):
self.test_obj = test_module.NodeSets(str(TEST_DATA_DIR / "node_sets_file.json"))
self.test_pop = libsonata.NodeStorage(str(TEST_DATA_DIR / "nodes.h5")).open_population(
"default"
)

def test_init(self):
assert self.test_obj.content == {
Expand All @@ -20,54 +25,18 @@ def test_init(self):
"Layer23": {"layer": [3, 2, 2]},
"population_default_L6": {"population": "default", "mtype": "L6_Y"},
"combined": ["Node2_L6_Y", "Layer23"],
"failing": {"unknown_property": [0]},
}

# node_id from Node2_L6_Y should be 2 and not [2], layer from Layer23 should be [2, 3]
assert self.test_obj.resolved == {
"Node2_L6_Y": {"mtype": "L6_Y", "node_id": [20, 30]},
"Layer23": {"layer": [2, 3]},
"combined": {"$or": [{"mtype": "L6_Y", "node_id": [20, 30]}, {"layer": [2, 3]}]},
"population_default_L6": {"population": "default", "mtype": "L6_Y"},
"double_combined": {
"$or": [
{"$or": [{"mtype": "L6_Y", "node_id": [20, 30]}, {"layer": [2, 3]}]},
{"population": "default", "mtype": "L6_Y"},
]
},
}

def test_get(self):
assert self.test_obj["Node2_L6_Y"] == {"mtype": "L6_Y", "node_id": [20, 30]}
assert self.test_obj["double_combined"] == {
"$or": [
{"$or": [{"mtype": "L6_Y", "node_id": [20, 30]}, {"layer": [2, 3]}]},
{"population": "default", "mtype": "L6_Y"},
]
}
def test_resolve_ids(self):
assert_array_equal(self.test_obj.resolve_ids("Node2_L6_Y", self.test_pop), [])
assert_array_equal(
self.test_obj.resolve_ids("double_combined", self.test_pop),
[0, 1, 2],
)
with pytest.raises(BluepySnapError, match="No such attribute"):
self.test_obj.resolve_ids("failing", self.test_pop)

def test_iter(self):
expected = set(json.loads((TEST_DATA_DIR / "node_sets_file.json").read_text()))
assert set(self.test_obj) == expected


@patch("bluepysnap.utils.load_json")
def test_fail_resolve(mock_load):
mock_load.return_value = {"empty_dict": {}}
with pytest.raises(BluepySnapError):
test_module.NodeSets(str(TEST_DATA_DIR / "node_sets_file.json"))

mock_load.return_value = {"empty_list": []}
with pytest.raises(BluepySnapError):
test_module.NodeSets(str(TEST_DATA_DIR / "node_sets_file.json"))

mock_load.return_value = {"int": 1}
with pytest.raises(BluepySnapError):
test_module.NodeSets(str(TEST_DATA_DIR / "node_sets_file.json"))

mock_load.return_value = {"bool": True}
with pytest.raises(BluepySnapError):
test_module.NodeSets(str(TEST_DATA_DIR / "node_sets_file.json"))

mock_load.return_value = {"combined": ["known", "unknown"], "known": {"v": 1}}
with pytest.raises(BluepySnapError):
test_module.NodeSets(str(TEST_DATA_DIR / "node_sets_file.json"))
2 changes: 0 additions & 2 deletions tests/test_nodes/test_node_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,11 @@ def test_ids(self):

npt.assert_equal(_call("Layer2"), [0])
npt.assert_equal(_call("Layer23"), [0])
npt.assert_equal(_call("Empty_nodes"), [])
npt.assert_equal(_call("Node2012"), [0, 1, 2]) # reordered + duplicates are removed
npt.assert_equal(_call("Node12_L6_Y"), [1, 2])
npt.assert_equal(_call("Node2_L6_Y"), [2])

npt.assert_equal(_call("Node0_L6_Y"), []) # return empty if disjoint samples
npt.assert_equal(_call("Empty_L6_Y"), []) # return empty if empty node_id = []
npt.assert_equal(_call("Population_default"), [0, 1, 2]) # return all ids
npt.assert_equal(_call("Population_default2"), []) # return empty if diff population
npt.assert_equal(_call("Population_default_L6_Y"), [1, 2]) # population + other query ok
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_all():
assert simulation.conditions.celsius == 34.0
assert simulation.conditions.v_init == -80

assert simulation.node_sets.resolved == {"Layer23": {"layer": [2, 3]}}
assert simulation.node_sets == {"Layer23": {"layer": [2, 3]}}
assert isinstance(simulation.spikes, SpikeReport)
assert isinstance(simulation.spikes["default"], PopulationSpikeReport)

Expand Down