Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions ramalama/arg_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,45 @@ class ChatSubArgsType(Protocol):

class ChatArgsType(DefaultArgsType, ChatSubArgsType):
pass


class ServeRunArgsType(DefaultArgsType, Protocol):
    """Args for serve and run commands.

    Structural (Protocol) description of the parsed CLI namespace shared by
    the ``serve`` and ``run`` subcommands; extends ``DefaultArgsType``.
    """

    MODEL: str  # model reference (positional argument) — always present
    port: int | None
    name: str | None  # container name, when running under an engine
    rag: str | None  # optional RAG source; see RagArgsType where it is required
    subcommand: str  # which subcommand was invoked (e.g. serve or run)
    detach: bool | None
    api: str | None
    image: str
    host: str | None
    generate: str | None
    context: int
    cache_reuse: int
    authfile: str | None
    device: list[str] | None
    env: list[str]
    ARGS: list[str] | None  # For run command
    mcp: list[str] | None
    summarize_after: int
    # Chat/run specific options
    color: COLOR_OPTIONS
    prefix: str
    rag_image: str | None


# Concrete dataclass derived from the protocol — presumably used to build
# args objects at runtime (e.g. in tests or internal callers); TODO confirm.
ServeRunArgs = protocol_to_dataclass(ServeRunArgsType)


class RagArgsType(ServeRunArgsType, Protocol):
    """Args when using RAG functionality - wraps model args.

    Extends ``ServeRunArgsType`` with the nested args used to launch the
    underlying model container alongside the RAG container.
    """

    model_args: ServeRunArgsType  # args forwarded to the wrapped model
    model_host: str
    model_port: int
    rag: str  # Override to make rag required (not optional)


# Concrete dataclass derived from the protocol (see ServeRunArgs).
RagArgs = protocol_to_dataclass(RagArgsType)
42 changes: 21 additions & 21 deletions ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.error
from datetime import datetime, timezone
from textwrap import dedent
from typing import get_args
from typing import Any, get_args
from urllib.parse import urlparse

# if autocomplete doesn't exist, just do nothing, don't break
Expand Down Expand Up @@ -151,7 +151,7 @@ def add_argument(self, *args, help=None, default=None, completer=None, **kwargs)
kwargs['help'] += f' (default: {default})'
action = super().add_argument(*args, **kwargs)
if completer is not None:
action.completer = completer
action.completer = completer # type: ignore[attr-defined]
return action


Expand Down Expand Up @@ -314,7 +314,7 @@ def parse_arguments(parser):
def post_parse_setup(args):
"""Perform additional setup after parsing arguments."""

def map_https_to_transport(input: str) -> str | None:
def map_https_to_transport(input: str) -> str:
if input.startswith("https://") or input.startswith("http://"):
url = urlparse(input)
# detect if the whole repo is defined or a specific file
Expand Down Expand Up @@ -468,7 +468,7 @@ def bench_cli(args):
model.bench(args, assemble_command(args))


def add_network_argument(parser, dflt="none"):
def add_network_argument(parser, dflt: str | None = "none"):
# Disable network access by default, and give the option to pass any supported network mode into
# podman if needed:
# https://docs.podman.io/en/latest/markdown/podman-run.1.html#network-mode-net
Expand Down Expand Up @@ -573,13 +573,11 @@ def _list_models_from_store(args):
size_sum += file.size
last_modified = max(file.modified, last_modified)

ret.append(
{
"name": f"{model} (partial)" if is_partially_downloaded else model,
"modified": datetime.fromtimestamp(last_modified, tz=local_timezone).isoformat(),
"size": size_sum,
}
)
ret.append({
"name": f"{model} (partial)" if is_partially_downloaded else model,
"modified": datetime.fromtimestamp(last_modified, tz=local_timezone).isoformat(),
"size": size_sum,
})

# sort the listed models according to the desired order
ret.sort(key=lambda entry: entry[args.sort], reverse=args.order == "desc")
Expand All @@ -592,7 +590,7 @@ def _list_models(args):


def info_cli(args):
info = {
info: dict[str, Any] = {
"Accelerator": get_accel(),
"Config": load_file_config(),
"Engine": {
Expand Down Expand Up @@ -1335,6 +1333,9 @@ def version_parser(subparsers):

class AddPathOrUrl(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if not isinstance(values, list):
raise ValueError("AddPathOrUrl can only be used with the settings `nargs='+'`")

setattr(namespace, self.dest, [])
namespace.urls = []
for value in values:
Expand Down Expand Up @@ -1516,7 +1517,7 @@ def inspect_cli(args):


def main() -> None:
def eprint(e, exit_code):
def eprint(e: Exception | str, exit_code: int):
try:
if args.debug:
logger.exception(e)
Expand Down Expand Up @@ -1545,7 +1546,7 @@ def eprint(e, exit_code):
except urllib.error.HTTPError as e:
eprint(f"pulling {e.geturl()} failed: {e}", errno.EINVAL)
except HelpException:
parser.print_help()
parser.print_help() # type: ignore[possibly-unbound]
except (ConnectionError, IndexError, KeyError, ValueError, NoRefFileFound) as e:
eprint(e, errno.EINVAL)
except NotImplementedError as e:
Expand All @@ -1563,12 +1564,11 @@ def eprint(e, exit_code):
except ParseError as e:
eprint(f"Failed to parse model: {e}", errno.EINVAL)
except SafetensorModelNotSupported:
eprint(
f"""Safetensor models are not supported. Please convert it to GGUF via:
$ ramalama convert --gguf=<quantization> {args.model} <oci-name>
$ ramalama run <oci-name>
""",
errno.ENOTSUP,
message = (
"Safetensor models are not supported. Please convert it to GGUF via:\n"
f"$ ramalama convert --gguf=<quantization> {args.model} <oci-name>\n" # type: ignore[possibly-unbound]
"$ ramalama run <oci-name>\n"
)
eprint(message, errno.ENOTSUP)
except NoGGUFModelFileFound:
eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT)
eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the type: ignore on line 1569 and for better clarity, it's good practice to specify the error code being ignored. In this case, args could be unbound if init_cli() fails.

Suggested change
eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) # type: ignore
eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) # type: ignore[possibly-unbound]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line exceeds 120 characters with the error specified

14 changes: 11 additions & 3 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import string
import subprocess
import sys
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import lru_cache
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast, get_args
from typing import IO, TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, TypedDict, cast, get_args

import yaml

Expand Down Expand Up @@ -137,7 +137,15 @@ def exec_cmd(args, stdout2null: bool = False, stderr2null: bool = False):
raise


def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_all=False, encoding=None, env=None):
def run_cmd(
args: Sequence[str],
cwd: str | None = None,
stdout: int | IO[Any] | None = subprocess.PIPE,
ignore_stderr: bool = False,
ignore_all: bool = False,
encoding: str | None = None,
env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess[Any]:
"""
Run the given command arguments.

Expand Down
5 changes: 1 addition & 4 deletions ramalama/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def get_inference_spec_files() -> dict[str, Path]:
files: dict[str, Path] = {}

for spec_dir in get_all_inference_spec_dirs("engines"):

# Give preference to .yaml, then .json spec files
file_extensions = ["*.yaml", "*.yml", "*.json"]
for file_extension in file_extensions:
Expand All @@ -117,7 +116,6 @@ def get_inference_schema_files() -> dict[str, Path]:
files: dict[str, Path] = {}

for schema_dir in get_all_inference_spec_dirs("schema"):

for spec_file in sorted(Path(schema_dir).glob("schema.*.json")):
file = Path(spec_file)
version = file.name.replace("schema.", "").replace(".json", "")
Expand Down Expand Up @@ -279,8 +277,7 @@ def _finalize_engine(self: "Config"):

If Podman is detected on macOS without a configured machine, it falls back on docker availability.
"""
is_podman = self.engine is not None and os.path.basename(self.engine) == "podman"
if is_podman and sys.platform == "darwin":
if self.engine is not None and os.path.basename(self.engine) == "podman" and sys.platform == "darwin":
run_with_podman_engine = apple_vm(self.engine, self)
if not run_with_podman_engine and not self.is_set("engine"):
self.engine = "docker" if available("docker") else None
Expand Down
29 changes: 15 additions & 14 deletions ramalama/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import sys
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from http.client import HTTPConnection, HTTPException
from typing import Any, Callable
from typing import Any, cast

# Live reference for checking global vars
import ramalama.common
Expand All @@ -20,10 +21,10 @@ class BaseEngine(ABC):

def __init__(self, args):
base = os.path.basename(args.engine)
self.use_docker = base == "docker"
self.use_podman = base == "podman"
self.use_docker: bool = base == "docker"
self.use_podman: bool = base == "podman"
self.args = args
self.exec_args = [self.args.engine]
self.exec_args: list[str] = [self.args.engine]
self.base_args()
self.add_labels()
self.add_network()
Expand All @@ -35,10 +36,10 @@ def __init__(self, args):
@abstractmethod
def base_args(self): ...

def add_label(self, label):
def add_label(self, label: str):
self.add(["--label", label])

def add_name(self, name):
def add_name(self, name: str):
self.add(["--name", name])

def add_labels(self):
Expand Down Expand Up @@ -112,8 +113,8 @@ def handle_podman_specifics(self):
if getattr(self.args, "podman_keep_groups", None):
self.exec_args += ["--group-add", "keep-groups"]

def add(self, newargs):
self.exec_args += newargs
def add(self, newargs: Sequence[str]):
self.exec_args.extend(newargs)

def add_args(self, *args: str) -> None:
self.add(args)
Expand Down Expand Up @@ -319,7 +320,7 @@ def containers(args):
raise (e)


def info(args):
def info(args) -> list[Any] | str | dict[str, Any]:
conman = str(args.engine) if args.engine is not None else None
if conman == "" or conman is None:
raise ValueError("no container manager (Podman, Docker) found")
Expand All @@ -334,7 +335,7 @@ def info(args):
return str(e)


def inspect(args, name, format=None, ignore_stderr=False):
def inspect(args, name: str, format: str | None = None, ignore_stderr: bool = False):
if not name:
raise ValueError("must specify a container name")
conman = str(args.engine) if args.engine is not None else None
Expand All @@ -349,7 +350,7 @@ def inspect(args, name, format=None, ignore_stderr=False):
return run_cmd(conman_args, ignore_stderr=ignore_stderr).stdout.decode("utf-8").strip()


def logs(args, name, ignore_stderr=False):
def logs(args, name: str, ignore_stderr: bool = False):
if not name:
raise ValueError("must specify a container name")
conman = str(args.engine) if args.engine is not None else None
Expand All @@ -360,7 +361,7 @@ def logs(args, name, ignore_stderr=False):
return run_cmd(conman_args, ignore_stderr=ignore_stderr).stdout.decode("utf-8").strip()


def stop_container(args, name, remove=False):
def stop_container(args, name: str, remove: bool = False):
if not name:
raise ValueError("must specify a container name")
conman = str(args.engine) if args.engine is not None else None
Expand Down Expand Up @@ -416,7 +417,7 @@ def stop_container(args, name, remove=False):
raise


def add_labels(args, add_label):
def add_labels(args, add_label: Callable[[str], None]):
label_map = {
"MODEL": "ai.ramalama.model",
"engine": "ai.ramalama.engine",
Expand Down Expand Up @@ -452,7 +453,7 @@ def is_healthy(args, timeout: int = 3, model_name: str | None = None):
model_names = [m["name"] for m in body["models"]]
if not model_name:
# The transport and tag is not included in the model name returned by the endpoint
model_name = args.MODEL.split("://")[-1]
model_name = cast(str, args.MODEL.split("://")[-1])
model_name = model_name.split(":")[0]
if not any(model_name in name for name in model_names):
logger.debug(f'Container {args.name} does not include "{model_name}" in the model list: {model_names}')
Expand Down
24 changes: 13 additions & 11 deletions ramalama/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import tempfile
from functools import partial
from textwrap import dedent
from typing import Literal

from ramalama.arg_types import RagArgsType, ServeRunArgsType
from ramalama.chat import ChatOperationalArgs
from ramalama.common import accel_image, perror, set_accel_env_vars
from ramalama.compat import StrEnum
Expand Down Expand Up @@ -149,7 +151,7 @@ class RagTransport(OCI):

type: str = "Model+RAG"

def __init__(self, imodel: Transport, cmd: list[str], args):
def __init__(self, imodel: Transport, cmd: list[str], args: RagArgsType):
super().__init__(args.rag, args.store, args.engine)
self.imodel = imodel
self.model_cmd = cmd
Expand All @@ -163,21 +165,21 @@ def exists(self) -> bool:
return os.path.exists(self.model)
return super().exists()

def new_engine(self, args):
def new_engine(self, args: RagArgsType) -> "RagEngine":
return RagEngine(args, sourcetype=self.kind)

def setup_mounts(self, args):
def setup_mounts(self, args: RagArgsType) -> None:
pass

def chat_operational_args(self, args):
def chat_operational_args(self, args: RagArgsType) -> ChatOperationalArgs:
return ChatOperationalArgs(name=args.model_args.name)

def _handle_container_chat(self, args, pid):
def _handle_container_chat(self, args: RagArgsType, server_pid: int) -> Literal[0]:
# Clear args.rag so RamaLamaShell doesn't treat it as local data for RAG context
args.rag = None
super()._handle_container_chat(args, pid)
return super()._handle_container_chat(args, server_pid)

def _start_model(self, args, cmd: list[str]):
def _start_model(self, args: ServeRunArgsType, cmd: list[str]) -> int | None:
pid = self.imodel._fork_and_serve(args, self.model_cmd)
if pid:
_, status = os.waitpid(pid, 0)
Expand All @@ -188,15 +190,15 @@ def _start_model(self, args, cmd: list[str]):
)
return pid

def serve(self, args, cmd: list[str]):
def serve(self, args: RagArgsType, cmd: list[str]):
pid = self._start_model(args.model_args, cmd)
if pid:
super().serve(args, cmd)

def run(self, args, cmd: list[str]):
def run(self, args: RagArgsType, server_cmd: list[str]):
args.model_args.name = self.imodel.get_container_name(args.model_args)
super().run(args, cmd)
super().run(args, server_cmd)

def wait_for_healthy(self, args):
def wait_for_healthy(self, args: RagArgsType) -> None:
self.imodel.wait_for_healthy(args.model_args)
wait_for_healthy(args, partial(is_healthy, model_name=f"{self.imodel.model_name}+rag"))
2 changes: 1 addition & 1 deletion ramalama/shortnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ def __init__(self):
def _strip_quotes(self, s) -> str:
return s.strip("'\"")

def resolve(self, model) -> str | None:
def resolve(self, model: str) -> str:
return self.shortnames.get(model, model)
Loading
Loading