diff --git a/ramalama/arg_types.py b/ramalama/arg_types.py index 4a779a8d2..f45c10404 100644 --- a/ramalama/arg_types.py +++ b/ramalama/arg_types.py @@ -63,3 +63,45 @@ class ChatSubArgsType(Protocol): class ChatArgsType(DefaultArgsType, ChatSubArgsType): pass + + +class ServeRunArgsType(DefaultArgsType, Protocol): + """Args for serve and run commands""" + + MODEL: str + port: int | None + name: str | None + rag: str | None + subcommand: str + detach: bool | None + api: str | None + image: str + host: str | None + generate: str | None + context: int + cache_reuse: int + authfile: str | None + device: list[str] | None + env: list[str] + ARGS: list[str] | None # For run command + mcp: list[str] | None + summarize_after: int + # Chat/run specific options + color: COLOR_OPTIONS + prefix: str + rag_image: str | None + + +ServeRunArgs = protocol_to_dataclass(ServeRunArgsType) + + +class RagArgsType(ServeRunArgsType, Protocol): + """Args when using RAG functionality - wraps model args""" + + model_args: ServeRunArgsType + model_host: str + model_port: int + rag: str # Override to make rag required (not optional) + + +RagArgs = protocol_to_dataclass(RagArgsType) diff --git a/ramalama/cli.py b/ramalama/cli.py index a35636b1a..6d60ce98b 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -9,7 +9,7 @@ import urllib.error from datetime import datetime, timezone from textwrap import dedent -from typing import get_args +from typing import Any, get_args from urllib.parse import urlparse # if autocomplete doesn't exist, just do nothing, don't break @@ -151,7 +151,7 @@ def add_argument(self, *args, help=None, default=None, completer=None, **kwargs) kwargs['help'] += f' (default: {default})' action = super().add_argument(*args, **kwargs) if completer is not None: - action.completer = completer + action.completer = completer # type: ignore[attr-defined] return action @@ -314,7 +314,7 @@ def parse_arguments(parser): def post_parse_setup(args): """Perform additional setup after parsing arguments.""" - def map_https_to_transport(input: str) -> str | None: + def map_https_to_transport(input: str) -> str: if input.startswith("https://") or input.startswith("http://"): url = urlparse(input) # detect if the whole repo is defined or a specific file @@ -468,7 +468,7 @@ def bench_cli(args): model.bench(args, assemble_command(args)) -def add_network_argument(parser, dflt="none"): +def add_network_argument(parser, dflt: str | None = "none"): # Disable network access by default, and give the option to pass any supported network mode into # podman if needed: # https://docs.podman.io/en/latest/markdown/podman-run.1.html#network-mode-net @@ -573,13 +573,11 @@ def _list_models_from_store(args): size_sum += file.size last_modified = max(file.modified, last_modified) - ret.append( - { - "name": f"{model} (partial)" if is_partially_downloaded else model, - "modified": datetime.fromtimestamp(last_modified, tz=local_timezone).isoformat(), - "size": size_sum, - } - ) + ret.append({ + "name": f"{model} (partial)" if is_partially_downloaded else model, + "modified": datetime.fromtimestamp(last_modified, tz=local_timezone).isoformat(), + "size": size_sum, + }) # sort the listed models according to the desired order ret.sort(key=lambda entry: entry[args.sort], reverse=args.order == "desc") @@ -592,7 +590,7 @@ def _list_models(args): def info_cli(args): - info = { + info: dict[str, Any] = { "Accelerator": get_accel(), "Config": load_file_config(), "Engine": { @@ -1335,6 +1333,9 @@ def version_parser(subparsers): class AddPathOrUrl(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): + if not isinstance(values, list): + raise ValueError("AddPathOrUrl can only be used with the settings `nargs='+'`") + setattr(namespace, self.dest, []) namespace.urls = [] for value in values: @@ -1516,7 +1517,7 @@ def inspect_cli(args): def main() -> None: - def eprint(e, exit_code): + def eprint(e: Exception | str, exit_code: int): try: if args.debug: logger.exception(e) @@ -1545,7 +1546,7 @@ def eprint(e, exit_code): except urllib.error.HTTPError as e: eprint(f"pulling {e.geturl()} failed: {e}", errno.EINVAL) except HelpException: - parser.print_help() + parser.print_help() # type: ignore[possibly-unbound] except (ConnectionError, IndexError, KeyError, ValueError, NoRefFileFound) as e: eprint(e, errno.EINVAL) except NotImplementedError as e: @@ -1563,12 +1564,11 @@ def eprint(e, exit_code): except ParseError as e: eprint(f"Failed to parse model: {e}", errno.EINVAL) except SafetensorModelNotSupported: - eprint( - f"""Safetensor models are not supported. Please convert it to GGUF via: -$ ramalama convert --gguf= {args.model} -$ ramalama run -""", - errno.ENOTSUP, + message = ( + "Safetensor models are not supported. Please convert it to GGUF via:\n" + f"$ ramalama convert --gguf= {args.model} \n" # type: ignore[possibly-unbound] + "$ ramalama run \n" ) + eprint(message, errno.ENOTSUP) except NoGGUFModelFileFound: - eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) + eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) # type: ignore diff --git a/ramalama/common.py b/ramalama/common.py index b4700042b..64c7f2176 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -13,9 +13,9 @@ import string import subprocess import sys -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast, get_args +from typing import IO, TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, TypedDict, cast, get_args import yaml @@ -137,7 +137,15 @@ def exec_cmd(args, stdout2null: bool = False, stderr2null: bool = False): raise -def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_all=False, encoding=None, env=None): +def run_cmd( + args: Sequence[str], + cwd: str | None = None, + stdout: int | IO[Any] | None = subprocess.PIPE, + ignore_stderr: bool = False, + ignore_all: bool = False, + encoding: str | None = None, + env: dict[str, str] | None = None, +) -> subprocess.CompletedProcess[Any]: """ Run the given command arguments. diff --git a/ramalama/config.py b/ramalama/config.py index d616ade92..f693258ed 100644 --- a/ramalama/config.py +++ b/ramalama/config.py @@ -99,7 +99,6 @@ def get_inference_spec_files() -> dict[str, Path]: files: dict[str, Path] = {} for spec_dir in get_all_inference_spec_dirs("engines"): - # Give preference to .yaml, then .json spec files file_extensions = ["*.yaml", "*.yml", "*.json"] for file_extension in file_extensions: @@ -117,7 +116,6 @@ def get_inference_schema_files() -> dict[str, Path]: files: dict[str, Path] = {} for schema_dir in get_all_inference_spec_dirs("schema"): - for spec_file in sorted(Path(schema_dir).glob("schema.*.json")): file = Path(spec_file) version = file.name.replace("schema.", "").replace(".json", "") @@ -279,8 +277,7 @@ def _finalize_engine(self: "Config"): If Podman is detected on macOS without a configured machine, it falls back on docker availability. """ - is_podman = self.engine is not None and os.path.basename(self.engine) == "podman" - if is_podman and sys.platform == "darwin": + if self.engine is not None and os.path.basename(self.engine) == "podman" and sys.platform == "darwin": run_with_podman_engine = apple_vm(self.engine, self) if not run_with_podman_engine and not self.is_set("engine"): self.engine = "docker" if available("docker") else None diff --git a/ramalama/engine.py b/ramalama/engine.py index 42a487672..ab402a402 100644 --- a/ramalama/engine.py +++ b/ramalama/engine.py @@ -5,8 +5,9 @@ import sys import time from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence from http.client import HTTPConnection, HTTPException -from typing import Any, Callable +from typing import Any, cast # Live reference for checking global vars import ramalama.common @@ -20,10 +21,10 @@ class BaseEngine(ABC): def __init__(self, args): base = os.path.basename(args.engine) - self.use_docker = base == "docker" - self.use_podman = base == "podman" + self.use_docker: bool = base == "docker" + self.use_podman: bool = base == "podman" self.args = args - self.exec_args = [self.args.engine] + self.exec_args: list[str] = [self.args.engine] self.base_args() self.add_labels() self.add_network() @@ -35,10 +36,10 @@ def __init__(self, args): @abstractmethod def base_args(self): ... - def add_label(self, label): + def add_label(self, label: str): self.add(["--label", label]) - def add_name(self, name): + def add_name(self, name: str): self.add(["--name", name]) def add_labels(self): @@ -112,8 +113,8 @@ def handle_podman_specifics(self): if getattr(self.args, "podman_keep_groups", None): self.exec_args += ["--group-add", "keep-groups"] - def add(self, newargs): - self.exec_args += newargs + def add(self, newargs: Sequence[str]): + self.exec_args.extend(newargs) def add_args(self, *args: str) -> None: self.add(args) @@ -319,7 +320,7 @@ def containers(args): raise (e) -def info(args): +def info(args) -> list[Any] | str | dict[str, Any]: conman = str(args.engine) if args.engine is not None else None if conman == "" or conman is None: raise ValueError("no container manager (Podman, Docker) found") @@ -334,7 +335,7 @@ def info(args): return str(e) -def inspect(args, name, format=None, ignore_stderr=False): +def inspect(args, name: str, format: str | None = None, ignore_stderr: bool = False): if not name: raise ValueError("must specify a container name") conman = str(args.engine) if args.engine is not None else None @@ -349,7 +350,7 @@ def inspect(args, name, format=None, ignore_stderr=False): return run_cmd(conman_args, ignore_stderr=ignore_stderr).stdout.decode("utf-8").strip() -def logs(args, name, ignore_stderr=False): +def logs(args, name: str, ignore_stderr: bool = False): if not name: raise ValueError("must specify a container name") conman = str(args.engine) if args.engine is not None else None @@ -360,7 +361,7 @@ def logs(args, name, ignore_stderr=False): return run_cmd(conman_args, ignore_stderr=ignore_stderr).stdout.decode("utf-8").strip() -def stop_container(args, name, remove=False): +def stop_container(args, name: str, remove: bool = False): if not name: raise ValueError("must specify a container name") conman = str(args.engine) if args.engine is not None else None @@ -416,7 +417,7 @@ def stop_container(args, name, remove=False): raise -def add_labels(args, add_label): +def add_labels(args, add_label: Callable[[str], None]): label_map = { "MODEL": "ai.ramalama.model", "engine": "ai.ramalama.engine", @@ -452,7 +453,7 @@ def is_healthy(args, timeout: int = 3, model_name: str | None = None): model_names = [m["name"] for m in body["models"]] if not model_name: # The transport and tag is not included in the model name returned by the endpoint - model_name = args.MODEL.split("://")[-1] + model_name = cast(str, args.MODEL.split("://")[-1]) model_name = model_name.split(":")[0] if not any(model_name in name for name in model_names): logger.debug(f'Container {args.name} does not include "{model_name}" in the model list: {model_names}') diff --git a/ramalama/rag.py b/ramalama/rag.py index 608e8c977..fec138bb2 100644 --- a/ramalama/rag.py +++ b/ramalama/rag.py @@ -3,7 +3,9 @@ import tempfile from functools import partial from textwrap import dedent +from typing import Literal +from ramalama.arg_types import RagArgsType, ServeRunArgsType from ramalama.chat import ChatOperationalArgs from ramalama.common import accel_image, perror, set_accel_env_vars from ramalama.compat import StrEnum @@ -149,7 +151,7 @@ class RagTransport(OCI): type: str = "Model+RAG" - def __init__(self, imodel: Transport, cmd: list[str], args): + def __init__(self, imodel: Transport, cmd: list[str], args: RagArgsType): super().__init__(args.rag, args.store, args.engine) self.imodel = imodel self.model_cmd = cmd @@ -163,21 +165,21 @@ def exists(self) -> bool: return os.path.exists(self.model) return super().exists() - def new_engine(self, args): + def new_engine(self, args: RagArgsType) -> "RagEngine": return RagEngine(args, sourcetype=self.kind) - def setup_mounts(self, args): + def setup_mounts(self, args: RagArgsType) -> None: pass - def chat_operational_args(self, args): + def chat_operational_args(self, args: RagArgsType) -> ChatOperationalArgs: return ChatOperationalArgs(name=args.model_args.name) - def _handle_container_chat(self, args, pid): + def _handle_container_chat(self, args: RagArgsType, server_pid: int) -> Literal[0]: # Clear args.rag so RamaLamaShell doesn't treat it as local data for RAG context args.rag = None - super()._handle_container_chat(args, pid) + return super()._handle_container_chat(args, server_pid) - def _start_model(self, args, cmd: list[str]): + def _start_model(self, args: ServeRunArgsType, cmd: list[str]) -> int | None: pid = self.imodel._fork_and_serve(args, self.model_cmd) if pid: _, status = os.waitpid(pid, 0) @@ -188,15 +190,15 @@ def _start_model(self, args, cmd: list[str]): ) return pid - def serve(self, args, cmd: list[str]): + def serve(self, args: RagArgsType, cmd: list[str]): pid = self._start_model(args.model_args, cmd) if pid: super().serve(args, cmd) - def run(self, args, cmd: list[str]): + def run(self, args: RagArgsType, server_cmd: list[str]): args.model_args.name = self.imodel.get_container_name(args.model_args) - super().run(args, cmd) + super().run(args, server_cmd) - def wait_for_healthy(self, args): + def wait_for_healthy(self, args: RagArgsType) -> None: self.imodel.wait_for_healthy(args.model_args) wait_for_healthy(args, partial(is_healthy, model_name=f"{self.imodel.model_name}+rag")) diff --git a/ramalama/shortnames.py b/ramalama/shortnames.py index 077609325..d398814fe 100644 --- a/ramalama/shortnames.py +++ b/ramalama/shortnames.py @@ -53,5 +53,5 @@ def __init__(self): def _strip_quotes(self, s) -> str: return s.strip("'\"") - def resolve(self, model) -> str | None: + def resolve(self, model: str) -> str: return self.shortnames.get(model, model) diff --git a/ramalama/transports/base.py b/ramalama/transports/base.py index 45fffa291..828280ff7 100644 --- a/ramalama/transports/base.py +++ b/ramalama/transports/base.py @@ -6,7 +6,10 @@ import sys import time from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + from ramalama.chat import ChatOperationalArgs import ramalama.chat as chat from ramalama.common import ( @@ -428,7 +431,7 @@ def _connect_and_chat(self, args, server_pid): chat.chat(args) return 0 - def chat_operational_args(self, args): + def chat_operational_args(self, args) -> "ChatOperationalArgs | None": return None def wait_for_healthy(self, args): @@ -708,8 +711,12 @@ def print_pull_message(self, model_name): def compute_ports(exclude: list[str] | None = None) -> list[int]: - excluded = exclude and set(map(int, exclude)) or set() - ports = list(sorted(set(range(DEFAULT_PORT_RANGE[0], DEFAULT_PORT_RANGE[1] + 1)) - excluded)) + excluded = set() if exclude is None else set(map(int, exclude)) + ports = [p for p in range(DEFAULT_PORT_RANGE[0], DEFAULT_PORT_RANGE[1] + 1) if p not in excluded] + + if not ports: + raise ValueError("All ports in the DEFAULT_PORT_RANGE were exhausted by the exclusion list.") + first_port = ports.pop(0) random.shuffle(ports) # try always the first port before the randomized others diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 4ff88b6d7..e7c0a7b3e 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -15,11 +15,22 @@ from ramalama.log_levels import LogLevel +@pytest.fixture(autouse=True) +def isolate_config(): + """ + Isolate tests from user configuration files by mocking config loading. + This fixture is automatically used for all tests in this module. + Individual tests can override by explicitly patching if needed. + """ + with patch("ramalama.config.load_file_config", return_value={}): + with patch("ramalama.config.apple_vm", return_value=False): + yield + + def test_correct_config_defaults(monkeypatch): monkeypatch.delenv("RAMALAMA_IMAGE", raising=False) - with patch("ramalama.config.load_file_config", return_value={}): - with patch("ramalama.config.load_env_config", return_value={}): - cfg = default_config() + with patch("ramalama.config.load_env_config", return_value={}): + cfg = default_config() assert cfg.carimage == "registry.access.redhat.com/ubi10-micro:latest" assert cfg.container in [True, False] # depends on env/system