|
71 | 71 | from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions |
72 | 72 | from thunder.core.symbol import Symbol |
73 | 73 | from thunder.core.trace import TraceCtx, TraceResults |
74 | | -from thunder.torch import _torch_to_thunder_function_map |
| 74 | +from thunder.torch import _torch_to_thunder_function_map, maybe_get_torch_to_thunder_symbol |
75 | 75 | from thunder.clang import _clang_fn_set |
76 | 76 | from thunder.core.pytree import tree_map, tree_iter |
77 | 77 | from thunder.torch.experimental.dtensor_torch_and_aten_ops import register_dtensor_and_aten_function |
@@ -386,14 +386,6 @@ def wrapper(*args, **kwargs): |
386 | 386 | return wrapper |
387 | 387 |
|
388 | 388 |
|
389 | | -_general_jit_lookaside_map.update( |
390 | | - { |
391 | | - k: ensure_recursive_proxies(interpreter_needs_wrap(record_source_loc_in_symbol_header(v))) |
392 | | - for k, v in _torch_to_thunder_function_map.items() |
393 | | - } |
394 | | -) |
395 | | - |
396 | | - |
397 | 389 | def register_general_jit_lookaside(diverted_fn): |
398 | 390 | def lookaside_wrapper(lookaside): |
399 | 391 | _general_jit_lookaside_map[diverted_fn] = lookaside |
@@ -1166,6 +1158,10 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: |
1166 | 1158 | # NOTE clang operations are not symbols, but we still prevent their internals from being jitted |
1167 | 1159 | recursively_proxy(*args, **kwargs) |
1168 | 1160 | lookaside = interpreter_needs_wrap(record_source_loc_in_symbol_header(fn)) |
| 1161 | + elif (torch_lookaside := maybe_get_torch_to_thunder_symbol(fn)) is not None: |
| 1162 | + lookaside = ensure_recursive_proxies( |
| 1163 | + interpreter_needs_wrap(record_source_loc_in_symbol_header(torch_lookaside)) |
| 1164 | + ) |
1169 | 1165 | elif (general_jit_lookaside := _general_jit_lookaside_map.get(fn, None)) is not None: |
1170 | 1166 | lookaside = general_jit_lookaside |
1171 | 1167 | else: |
|
0 commit comments