Skip to content

Commit 225f2e3

Browse files
committed
update jit_ext access to torchfn_to_thunder registry : test
1 parent 8abf040 commit 225f2e3

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

thunder/core/jit_ext.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions
7272
from thunder.core.symbol import Symbol
7373
from thunder.core.trace import TraceCtx, TraceResults
74-
from thunder.torch import _torch_to_thunder_function_map
74+
from thunder.torch import _torch_to_thunder_function_map, maybe_get_torch_to_thunder_symbol
7575
from thunder.clang import _clang_fn_set
7676
from thunder.core.pytree import tree_map, tree_iter
7777
from thunder.torch.experimental.dtensor_torch_and_aten_ops import register_dtensor_and_aten_function
@@ -386,14 +386,6 @@ def wrapper(*args, **kwargs):
386386
return wrapper
387387

388388

389-
_general_jit_lookaside_map.update(
390-
{
391-
k: ensure_recursive_proxies(interpreter_needs_wrap(record_source_loc_in_symbol_header(v)))
392-
for k, v in _torch_to_thunder_function_map.items()
393-
}
394-
)
395-
396-
397389
def register_general_jit_lookaside(diverted_fn):
398390
def lookaside_wrapper(lookaside):
399391
_general_jit_lookaside_map[diverted_fn] = lookaside
@@ -1166,6 +1158,10 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable:
11661158
# NOTE clang operations are not symbols, but we still prevent their internals from being jitted
11671159
recursively_proxy(*args, **kwargs)
11681160
lookaside = interpreter_needs_wrap(record_source_loc_in_symbol_header(fn))
1161+
elif (torch_lookaside := maybe_get_torch_to_thunder_symbol(fn)) is not None:
1162+
lookaside = ensure_recursive_proxies(
1163+
interpreter_needs_wrap(record_source_loc_in_symbol_header(torch_lookaside))
1164+
)
11691165
elif (general_jit_lookaside := _general_jit_lookaside_map.get(fn, None)) is not None:
11701166
lookaside = general_jit_lookaside
11711167
else:

thunder/torch/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ def register_function(torchfn, thunderfn_impl):
247247
_torch_to_thunder_function_map[torchfn] = thunderfn_impl
248248

249249

250+
def maybe_get_torch_to_thunder_symbol(torchfn):
251+
return _torch_to_thunder_function_map.get(torchfn, None)
252+
253+
250254
def _copy_(a, b, /):
251255
cd = get_compile_data()
252256
b = clang.maybe_convert_to_dtype(b, a.dtype)

0 commit comments

Comments
 (0)