Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.
Merged
Prev Previous commit
Next Next commit
Added node_sets as a kwarg to get & ids
  • Loading branch information
Joni Herttuainen committed May 31, 2023
commit e88c06a74015bce65f3cb472bd96ea9d91c67701
4 changes: 2 additions & 2 deletions bluepysnap/edges/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def population_names(self):
"""Defines all sorted edge population names from the Circuit."""
return sorted(self._circuit.to_libsonata.edge_populations)

def ids(self, group=None, sample=None, limit=None):
def ids(self, group=None, sample=None, limit=None, **_): # pylint: disable=arguments-differ
"""Edge CircuitEdgeIds corresponding to edges ``edge_ids``.

Args:
Expand Down Expand Up @@ -81,7 +81,7 @@ def ids(self, group=None, sample=None, limit=None):
fun = lambda x: (x.ids(group), x.name)
return self._get_ids_from_pop(fun, CircuitEdgeIds, sample=sample, limit=limit)

def get(self, edge_ids=None, properties=None): # pylint: disable=arguments-renamed
def get(self, edge_ids=None, properties=None): # pylint: disable=arguments-differ
"""Edge properties as pandas DataFrame.

Args:
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def data_units(self):
@property
def node_set(self):
"""Returns the node set for the report."""
return self.simulation.node_sets[self.to_libsonata.cells]
return self.simulation.node_sets.content[self.to_libsonata.cells]

@property
def simulation(self):
Expand Down
6 changes: 3 additions & 3 deletions bluepysnap/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ def _get_ids_from_pop(self, fun_to_apply, returned_ids_cls, sample=None, limit=N
return res

@abc.abstractmethod
def ids(self, group=None, sample=None, limit=None):
def ids(self, group=None, sample=None, limit=None, node_sets=None):
"""Resolves the ids of the NetworkObject."""

@abc.abstractmethod
def get(self, group=None, properties=None):
def get(self, group=None, properties=None, node_sets=None): # pylint: disable=too-many-locals
"""Returns the properties of the NetworkObject."""
ids = self.ids(group)
ids = self.ids(group, node_sets=node_sets)
properties = utils.ensure_list(properties)
# We don t convert to set properties itself to keep the column order.
properties_set = set(properties)
Expand Down
5 changes: 4 additions & 1 deletion bluepysnap/node_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ def __init__(self, filepath):
self.content = utils.load_json(filepath)
self._instance = libsonata.NodeSets.from_file(filepath)

def resolve_ids(self, node_set_name, population):
def get_ids(self, node_set_name, population, raise_missing_property=True):
"""Get the resolved node set using name as key."""
try:
return self._instance.materialize(node_set_name, population).flatten()
except libsonata.SonataError as e:
if "No such attribute" in e.args[0]:
if not raise_missing_property:
return []
raise BluepySnapError(*e.args) from e

def __iter__(self):
Expand Down
36 changes: 22 additions & 14 deletions bluepysnap/nodes/node_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,27 +265,32 @@ def _check_property(self, prop):
if prop not in self.property_names:
raise BluepySnapError(f"No such property: '{prop}'")

def _get_node_set(self, node_set_name):
def _get_node_set(self, node_set_name, raise_missing_prop, node_sets):
"""Returns the node set named 'node_set_name'."""
if node_set_name not in self._node_sets:
node_sets = node_sets or self._node_sets
if node_set_name not in node_sets:
raise BluepySnapError(f"Undefined node set: '{node_set_name}'")
return self._node_sets.resolve_ids(node_set_name, self._population)
return node_sets.get_ids(node_set_name, self._population, raise_missing_prop)

def _resolve_nodesets(self, queries):
def _resolve_nodesets(self, queries, raise_missing_prop, node_sets):
def _resolve(queries, queries_key):
if queries_key == query.NODE_SET_KEY:
if query.AND_KEY not in queries:
queries[query.AND_KEY] = []
queries[query.AND_KEY].append(
{query.NODE_ID_KEY: self._get_node_set(queries[queries_key])}
{
query.NODE_ID_KEY: self._get_node_set(
queries[queries_key], raise_missing_prop, node_sets
)
}
)
del queries[queries_key]

resolved_queries = deepcopy(queries)
query.traverse_queries_bottom_up(resolved_queries, _resolve)
return resolved_queries

def _node_ids_by_filter(self, queries, raise_missing_prop):
def _node_ids_by_filter(self, queries, raise_missing_prop, node_sets):
"""Return node IDs if their properties match the `queries` dict.

`props` values could be:
Expand All @@ -302,7 +307,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
>>> { Node.X: (0, 1), Node.MTYPE: 'L1_SLAC' }]})

"""
queries = self._resolve_nodesets(queries)
queries = self._resolve_nodesets(queries, raise_missing_prop, node_sets)
if raise_missing_prop:
properties = query.get_properties(queries)
if not properties.issubset(self._data.columns):
Expand All @@ -311,7 +316,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
idx = query.resolve_ids(self._data, self.name, queries)
return self._data.index[idx].values

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
def ids(self, group=None, limit=None, sample=None, raise_missing_property=True, node_sets=None):
"""Node IDs corresponding to node ``group``.

Args:
Expand All @@ -328,22 +333,23 @@ def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
raise_missing_property (bool): if True, raises if a property is not listed in this
population. Otherwise the ids are just not selected if a property is missing.

node_sets (NodeSets): If specified, use this NodeSets instance to resolve node sets.
Otherwise use the NodeSets instance of the circuit.

Returns:
numpy.array: A numpy array of IDs.
"""
# pylint: disable=too-many-branches
preserve_order = False
if isinstance(group, str):
group = self._get_node_set(group)
group = self._get_node_set(group, raise_missing_property, node_sets)
elif isinstance(group, CircuitNodeIds):
group = group.filter_population(self.name).get_ids()

if group is None:
result = self._data.index.values
elif isinstance(group, Mapping):
result = self._node_ids_by_filter(
queries=group, raise_missing_prop=raise_missing_property
)
result = self._node_ids_by_filter(group, raise_missing_property, node_sets)
elif isinstance(group, np.ndarray):
result = group
self._check_ids(result)
Expand Down Expand Up @@ -374,14 +380,16 @@ def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
else:
return np.unique(result)

def get(self, group=None, properties=None):
def get(self, group=None, properties=None, node_sets=None):
"""Node properties as a pandas Series or DataFrame.

Args:
group (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None):
see :ref:`Group Concept`
properties (list|str|None): If specified, return only the properties in the list.
Otherwise return all properties.
node_sets (NodeSets): If specified, use this NodeSets instance to resolve node sets.
Otherwise use the NodeSets instance of the circuit.

Returns:
value/pandas.Series/pandas.DataFrame:
Expand Down Expand Up @@ -450,7 +458,7 @@ def get(self, group=None, properties=None):
elif isinstance(group, CircuitNodeId):
group = self.ids(group)[0]
else:
group = self.ids(group)
group = self.ids(group, node_sets=node_sets)
result = result.loc[group]

if properties is not None:
Expand Down
8 changes: 4 additions & 4 deletions bluepysnap/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def property_values(self, prop):
for value in pop.property_values(prop)
)

def ids(self, group=None, sample=None, limit=None):
def ids(self, group=None, sample=None, limit=None, node_sets=None):
"""Returns the CircuitNodeIds corresponding to the nodes from ``group``.

Args:
Expand Down Expand Up @@ -116,10 +116,10 @@ def ids(self, group=None, sample=None, limit=None):
if diff.size != 0:
raise BluepySnapError(f"Population {diff} does not exist in the circuit.")

fun = lambda x: (x.ids(group, raise_missing_property=False), x.name)
fun = lambda x: (x.ids(group, raise_missing_property=False, node_sets=node_sets), x.name)
return self._get_ids_from_pop(fun, CircuitNodeIds, sample=sample, limit=limit)

def get(self, group=None, properties=None): # pylint: disable=arguments-differ
def get(self, group=None, properties=None, node_sets=None): # pylint: disable=arguments-differ
"""Node properties as a pandas DataFrame.

Args:
Expand All @@ -141,7 +141,7 @@ def get(self, group=None, properties=None): # pylint: disable=arguments-differ
if properties is None:
# not strictly needed, but ensure that the properties are always in the same order
properties = sorted(self.property_names)
return super().get(group, properties)
return super().get(group, properties, node_sets)

def __getstate__(self):
"""Make Nodes pickle-able, without storing state of caches."""
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def simulator(self):
def node_sets(self):
"""Returns the NodeSets object bound to the simulation."""
node_sets_file = self.to_libsonata.node_sets_file
return NodeSets(node_sets_file).content if node_sets_file else {}
return NodeSets(node_sets_file) if node_sets_file else {}

@cached_property
def spikes(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/data/node_sets_extra.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"ExtraLayer2": {
"layer": 2
}
}
10 changes: 6 additions & 4 deletions tests/test_node_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ def test_init(self):
"failing": {"unknown_property": [0]},
}

def test_resolve_ids(self):
assert_array_equal(self.test_obj.resolve_ids("Node2_L6_Y", self.test_pop), [])
def test_get_ids(self):
assert_array_equal(self.test_obj.get_ids("Node2_L6_Y", self.test_pop), [])
assert_array_equal(
self.test_obj.resolve_ids("double_combined", self.test_pop),
self.test_obj.get_ids("double_combined", self.test_pop),
[0, 1, 2],
)
with pytest.raises(BluepySnapError, match="No such attribute"):
self.test_obj.resolve_ids("failing", self.test_pop)
self.test_obj.get_ids("failing", self.test_pop)

assert self.test_obj.get_ids("failing", self.test_pop, raise_missing_property=False) == []

def test_iter(self):
expected = set(json.loads((TEST_DATA_DIR / "node_sets_file.json").read_text()))
Expand Down
19 changes: 19 additions & 0 deletions tests/test_nodes/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from bluepysnap.circuit import Circuit
from bluepysnap.circuit_ids import CircuitNodeId, CircuitNodeIds
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.node_sets import NodeSets
from bluepysnap.utils import IDS_DTYPE

from utils import TEST_DATA_DIR
Expand Down Expand Up @@ -412,6 +413,24 @@ def test_get(self):
with pytest.raises(BluepySnapError, match="Unknown properties required: {'unknown'}"):
self.test_obj.get(properties="unknown")

def test_functionality_with_separate_node_set(self):
with pytest.raises(BluepySnapError, match="Undefined node set"):
self.test_obj.ids("ExtraLayer2")

node_sets = NodeSets(str(TEST_DATA_DIR / "node_sets_extra.json"))

assert self.test_obj.ids("ExtraLayer2", node_sets=node_sets) == CircuitNodeIds.from_arrays(
["default", "default2"], [0, 3]
)

with pytest.raises(BluepySnapError, match="Undefined node set"):
self.test_obj.get("ExtraLayer2")

pdt.assert_frame_equal(
self.test_obj.get("ExtraLayer2", node_sets=node_sets),
self.test_obj.get("Layer2"),
)

def test_pickle(self, tmp_path):
pickle_path = tmp_path / "pickle.pkl"

Expand Down
4 changes: 3 additions & 1 deletion tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PopulationSomaReport,
SomaReport,
)
from bluepysnap.node_sets import NodeSets
from bluepysnap.spike_report import PopulationSpikeReport, SpikeReport

from utils import TEST_DATA_DIR, copy_test_data, edit_config
Expand All @@ -40,7 +41,8 @@ def test_all():
assert simulation.conditions.celsius == 34.0
assert simulation.conditions.v_init == -80

assert simulation.node_sets == {"Layer23": {"layer": [2, 3]}}
assert isinstance(simulation.node_sets, NodeSets)
assert simulation.node_sets.content == {"Layer23": {"layer": [2, 3]}}
assert isinstance(simulation.spikes, SpikeReport)
assert isinstance(simulation.spikes["default"], PopulationSpikeReport)

Expand Down