Commit 502c50c

Move out dm-tree and networkx from main dependencies (#324)

Moved these to the ldp[scg] subpackage.

1 parent 961bf8a · commit 502c50c
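For orientation (not part of the diff itself), here is a minimal sketch of the behavior this change is after, assuming the lazy-import helpers added in `src/ldp/graph/op_utils.py` below: a base `ldp` install no longer needs dm-tree or networkx at import time, and graph-specific code paths raise a guided ImportError instead.

```py
# Sketch: assumes only the base `ldp` package is installed, without the [scg] extra.
# Importing the graph utilities still works because networkx/dm-tree are now lazy.
from ldp.graph.op_utils import _lazy_import_networkx

try:
    nx = _lazy_import_networkx()  # succeeds only if networkx (ldp[scg]) is installed
    print("networkx available:", nx.__version__)
except ImportError as err:
    print(err)  # error message points at: pip install ldp[scg]
```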

11 files changed: +332 −35 lines

.github/workflows/tests.yml

Lines changed: 6 additions & 0 deletions
@@ -59,6 +59,12 @@ jobs:
         env:
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+      - run:
+          | # this is to check that we can run README examples with lazily imported dependencies (the [scg] subpackage)
+          uv pip uninstall dm-tree usearch networkx && uv run --no-sync pytest --noconftest tests/test_readme.py
+        env:
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
   test-lmi:
     runs-on: ubuntu-latest
     steps:

README.md

Lines changed: 9 additions & 1 deletion
@@ -185,7 +185,15 @@ which just converts the action constructed by plain python into a node in a comp
 
 For more advanced use-cases, LDP features a stochastic computation graph [^2]
 which enables differentiatiation with respect to agent parameters
-(including the weights of the LLM). The example computation graph below illustrates the functionality
+(including the weights of the LLM).
+
+You should install the `scg` subpackage to work with it:
+
+```bash
+pip install ldp[scg]
+```
+
+The example computation graph below illustrates the functionality
 
 ```py
 from ldp.graph import FxnOp, LLMCallOp, PromptOp, compute_graph

pyproject.toml

Lines changed: 6 additions & 3 deletions
@@ -23,18 +23,15 @@ classifiers = [
 ]
 dependencies = [
   "aiofiles",
-  "dm-tree",
   "fhaviary>=0.20.0", # For Environment.get_id
   "fhlmi", # For new LLM client interface
   "httpx",
-  "networkx[default]~=3.4", # Pin for pydot fix
   "numpy>=1.20", # For numpy.typing
   "pydantic~=2.0",
   "tenacity",
   "tiktoken",
   "tqdm",
   "typing-extensions; python_version <= '3.11'", # for typing.override
-  "usearch>=2.13", # For py.typed
 ]
 description = "Agent framework for constructing language model agents and training on constructive tasks."
 dynamic = ["version"]
@@ -66,6 +63,7 @@ dev = [
   "pytest-xdist",
   "pytest>=8", # Pin to keep recent
   "refurb>=2", # Pin to keep recent
+  "usearch>=2.13", # For py.typed
   "vcrpy>=6", # Pin for https://github.com/kevin1024/vcrpy/issues/884
 ]
 monitor = [
@@ -76,6 +74,7 @@ nn = [
   "dask-cuda>=24.8.2",
   "dask-jobqueue",
   "dask[distributed]",
+  "ldp[scg]",
   "tokenizers>0.20",
   "torch>=2.5,<2.6", # Temporarily pin <2.6 until someone fixes our CI with newer versions
   "transformers>=4.46,<4.51", # Temporarily pin <4.51 until someone fixes our CI with newer versions, and also https://github.com/huggingface/transformers/issues/37339
@@ -85,6 +84,10 @@ rich = [
   "rich",
   "tqdm>=4.56", # When tqdm.rich was created
 ]
+scg = [
+  "dm-tree",
+  "networkx[default]~=3.4", # Pin for pydot fix
+]
 server = [
   "fastapi>=0.109", # For Python 3.12 support
 ]
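As a quick aside (hypothetical helper, not part of the commit), one way to check at runtime whether the optional `scg` dependencies are importable in the current environment:

```py
# Hypothetical helper for illustration: report whether the [scg] extras can be imported.
import importlib.util


def scg_installed() -> bool:
    """True if both dm-tree (imported as `tree`) and networkx are importable."""
    return all(
        importlib.util.find_spec(name) is not None for name in ("tree", "networkx")
    )


print("scg extras available:", scg_installed())
```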

src/ldp/alg/algorithms.py

Lines changed: 5 additions & 2 deletions
@@ -4,13 +4,15 @@
 import itertools
 import random
 from collections.abc import Awaitable, Callable, Hashable, Iterable, Sequence
-from typing import Any, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Literal, TypeVar
 
-import networkx as nx
+if TYPE_CHECKING:
+    import networkx as nx
 import numpy as np
 from aviary.core import Message, Tool, ToolRequestMessage, join
 
 from ldp.graph import OpResult
+from ldp.graph.op_utils import _lazy_import_networkx
 from ldp.graph.ops import GradOutType
 
 
@@ -36,6 +38,7 @@ def to_network(  # noqa: C901
     Returns:
         Populated a NetworkX multi-edge directed graph.
     """
+    nx = _lazy_import_networkx()
 
     def gvizify(x: Any) -> str:
         """Stringify and then escape colons for Graphviz labels."""

src/ldp/data_structures.py

Lines changed: 8 additions & 4 deletions
@@ -9,16 +9,17 @@
 from uuid import UUID
 
 import aiofiles
-import networkx as nx
 from aviary.core import Message, ToolRequestMessage, ToolResponseMessage, join
 from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
 
 from ldp.graph import OpResult
+from ldp.graph.op_utils import _lazy_import_networkx
 from ldp.utils import discounted_returns
 
 if TYPE_CHECKING:
     from aviary.core import Environment
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -180,6 +181,7 @@ def __init__(self, root_id: str | UUID):
         """
         self.root_id = str(root_id)
 
+        nx = _lazy_import_networkx()
         self.tree = nx.DiGraph()  # the actual tree
         self.rev_tree = nx.DiGraph()  # the same as self.tree, but with reversed edges
 
@@ -292,7 +294,7 @@ def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:
             discount_factor: The discount factor to use when computing cumulative
                 future rewards.
         """
-        for step_id in nx.topological_sort(self.rev_tree):
+        for step_id in _lazy_import_networkx().topological_sort(self.rev_tree):
             step: Transition | None = self.tree.nodes[step_id]["transition"]
             if step is None:
                 continue
@@ -324,7 +326,9 @@ def compute_advantages(self) -> None:
         """
        state_values: dict[str, float] = {}
 
-        for step_id in cast("Iterable[str]", nx.topological_sort(self.tree)):
+        for step_id in cast(
+            "Iterable[str]", _lazy_import_networkx().topological_sort(self.tree)
+        ):
             # topological sort means we will update a parent node in-place before
             # descending to its children
 
@@ -412,7 +416,7 @@ def merge_identical_nodes(
         # old step ID -> new step ID
         node_remap: dict[str, str] = {self.root_id: self.root_id}
 
-        for step_id in nx.topological_sort(self.tree):
+        for step_id in _lazy_import_networkx().topological_sort(self.tree):
             step: Transition | None = self.tree.nodes[step_id]["transition"]
             if step is None:
                 continue

src/ldp/graph/common_ops.py

Lines changed: 15 additions & 6 deletions
@@ -8,11 +8,10 @@
 import logging
 from collections.abc import Awaitable, Callable
 from functools import lru_cache
-from typing import Generic, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Generic, TypeVar, cast, overload
 
 import numpy as np
 import tenacity
-import tree
 from aviary.core import Message, Tool, ToolRequestMessage, is_coroutine_callable
 from lmi import (
     EmbeddingModel,
@@ -24,11 +23,13 @@
 from lmi import LiteLLMModel as LLMModel
 from pydantic import BaseModel
 
-from .gradient_estimators import assign_constant_grads
 from .memory import Memory, MemoryModel, UIndexMemoryModel
-from .op_utils import CallID, get_call_id, get_training_mode
+from .op_utils import CallID, _lazy_import_tree, get_call_id, get_training_mode
 from .ops import GradInType, Op, OpCtx, ResultOrValue, TOutput_co
 
+if TYPE_CHECKING:
+    import tree
+
 logger = logging.getLogger(__name__)
 
 
@@ -76,6 +77,8 @@ def backward(
         grad_output: tree.Structure,
         call_id: CallID,
     ) -> GradInType:
+        from .gradient_estimators import assign_constant_grads
+
         return assign_constant_grads(input_args, input_kwargs, None)
 
 
@@ -101,7 +104,7 @@ def backward(
         call_id: CallID,
     ) -> GradInType:
         # Check that the grad_output structure is consistent with our config
-        tree.assert_same_structure(
+        _lazy_import_tree().assert_same_structure(
             grad_output, ctx.get(call_id, "output").value, check_types=False
         )
 
@@ -186,6 +189,8 @@ def backward(
         grad_output: tree.Structure,
         call_id: CallID,
     ) -> GradInType:
+        from .gradient_estimators import assign_constant_grads
+
         return assign_constant_grads(input_args, input_kwargs, 0.0)
 
 
@@ -397,7 +402,9 @@ def backward(
         # but not necessarily each message or tool.
 
         # tree.map_structure allows us to assign a gradient of 0 to all fields of config
-        grad_config = tree.map_structure(lambda _: 0.0, input_kwargs["config"])
+        grad_config = _lazy_import_tree().map_structure(
+            lambda _: 0.0, input_kwargs["config"]
+        )
         grad_kwargs = {"config": grad_config}
         for arg in ("msgs", "tools", "tool_choice"):
             if arg in input_kwargs:
@@ -472,6 +479,8 @@ def backward(
         call_id: CallID,
     ) -> GradInType:
         """Backward pass for memory retrieval - goes back to item."""
+        from .gradient_estimators import assign_constant_grads
+
         return assign_constant_grads(input_args, input_kwargs, 0.0)
src/ldp/graph/memory.py

Lines changed: 10 additions & 2 deletions
@@ -17,9 +17,10 @@
     field_validator,
     model_validator,
 )
-from usearch.index import Index
 
 if TYPE_CHECKING:
+    from usearch.index import Index  # noqa: F401
+
     from .common_ops import MemoryOp
     from .op_utils import CallID
     from .ops import Op, OpResult, TOutput_co
@@ -164,13 +165,20 @@ async def _search_index(
         """Search the internal Index, returning a 'matches' amount of Memories."""
 
 
-class UIndexMemoryModel(MemoryModel[Index]):
+class UIndexMemoryModel(MemoryModel["Index"]):
     """Memory model using a U-Search index."""
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         if not self.embedding_model.ndim:
             raise TypeError("Specify dimensions to the embedding model.")
+        try:
+            from usearch.index import Index
+        except ImportError as e:
+            raise ImportError(
+                "U-Search library not found. Unable to use UIndexMemoryModel."
+                " To install U-Search dependencies, please run `pip install usearch`."
+            ) from e
         self._index = Index(ndim=self.embedding_model.ndim)
 
     async def _add_to_index(self, embedding: np.ndarray) -> int:

src/ldp/graph/op_utils.py

Lines changed: 23 additions & 0 deletions
@@ -1,12 +1,35 @@
 import contextvars
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
+from types import ModuleType
 from uuid import UUID, uuid4
 
 from aviary.core import is_coroutine_callable
 from pydantic import BaseModel, field_serializer, field_validator
 
 
+def _lazy_import_networkx() -> ModuleType:
+    try:
+        import networkx as nx
+    except ImportError as e:
+        raise ImportError(
+            "networkx is required for compute graph operations. "
+            "Please install it with: pip install ldp[scg]"
+        ) from e
+    return nx
+
+
+def _lazy_import_tree() -> ModuleType:
+    try:
+        import tree
+    except ImportError as e:
+        raise ImportError(
+            "tree is required for compute graph operations. "
+            "Please install it with: pip install ldp[scg]"
+        ) from e
+    return tree
+
+
 class CallID(BaseModel):
     run_id: UUID
     fwd_id: UUID
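As a brief aside, here is a hedged usage sketch of the call-site pattern the rest of this commit adopts (see algorithms.py, data_structures.py, and ops.py): the optional modules are resolved inside function bodies, so importing `ldp.graph` stays free of the optional dependencies. The helper function below is hypothetical and only for illustration; the two lazy-import helpers are the ones added above.

```py
# Hypothetical example of the call-site pattern; only the two imported helpers are real.
from ldp.graph.op_utils import _lazy_import_networkx, _lazy_import_tree


def count_leaf_nodes(nested: dict) -> int:
    """Toy function that needs both optional deps, but only at call time."""
    nx = _lazy_import_networkx()  # raises a guided ImportError if networkx is missing
    tree = _lazy_import_tree()  # raises a guided ImportError if dm-tree is missing
    graph = nx.DiGraph()
    graph.add_nodes_from(tree.flatten(nested))  # dm-tree flattens nested structures
    return graph.number_of_nodes()
```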

src/ldp/graph/ops.py

Lines changed: 24 additions & 10 deletions
@@ -9,23 +9,34 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence
-from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar
 from uuid import UUID
 
-import networkx as nx
-import tree
+if TYPE_CHECKING:
+    import networkx as nx
+    import tree
 from pydantic import BaseModel, Field
 
-from .op_utils import CallID, compute_graph, get_call_id, get_training_mode, op_call
+from .op_utils import (
+    CallID,
+    _lazy_import_networkx,
+    _lazy_import_tree,
+    compute_graph,
+    get_call_id,
+    get_training_mode,
+    op_call,
+)
 
 logger = logging.getLogger(__name__)
 
 
-GradOutType: TypeAlias = tree.Structure | None  # None means the gradient has terminated
+GradOutType: TypeAlias = (
+    "tree.Structure | None"  # None means the gradient has terminated
+)
 GradInType: TypeAlias = tuple[Sequence[GradOutType], Mapping[str, GradOutType]]
 BackwardsType: TypeAlias = Callable[
     # Call signature of Op.backward
-    ["OpCtx", list, dict, tree.Structure, CallID],
+    ["OpCtx", list, dict, "tree.Structure", CallID],
     GradInType,
 ]
 TOutput_co = TypeVar("TOutput_co", covariant=True)
@@ -93,8 +104,11 @@ def compute_grads(
         (a) define the backward computation
         (b) store internal gradients for optimizer updates.
         """
-        # call ID -> [d op(x) / d x] for each op that consumes x
-        grad_outputs: dict[CallID, list[tree.Structure]] = defaultdict(list)
+        tree = _lazy_import_tree()
+
+        # call ID -> [d op(x) / d x] for each op that consumes x
+        # due to interaction between ruff and mypy, we need type ignore
+        grad_outputs: dict[CallID, list[tree.Structure]] = defaultdict(list)  # type: ignore[name-defined]
 
         # grad_outputs stores a list of output grads (corresponding to each consuming op call).
         # Since the root node is not consumed by any other node, we create a singleton list here.
@@ -212,7 +226,7 @@ def add_edges(graph: nx.DiGraph, node: OpResult) -> None:
                 graph.add_edge(*edge)
                 add_edges(graph, x)
 
-        graph = nx.DiGraph()
+        graph = _lazy_import_networkx().DiGraph()
         graph.add_node(self)
         add_edges(graph, self)
 
@@ -251,7 +265,7 @@ def traverse(
         """
         if topological_order:
             G = self.get_compute_graph()
-            for node in nx.topological_sort(G):
+            for node in _lazy_import_networkx().topological_sort(G):
                 if filter_fn(node):
                     yield node
