-
Notifications
You must be signed in to change notification settings - Fork 287
chore typing and some bug fixes #2241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |||||
| import urllib.error | ||||||
| from datetime import datetime, timezone | ||||||
| from textwrap import dedent | ||||||
| from typing import get_args | ||||||
| from typing import Any, get_args | ||||||
| from urllib.parse import urlparse | ||||||
|
|
||||||
| # if autocomplete doesn't exist, just do nothing, don't break | ||||||
|
|
@@ -151,7 +151,7 @@ def add_argument(self, *args, help=None, default=None, completer=None, **kwargs) | |||||
| kwargs['help'] += f' (default: {default})' | ||||||
| action = super().add_argument(*args, **kwargs) | ||||||
| if completer is not None: | ||||||
| action.completer = completer | ||||||
| action.completer = completer # type: ignore[attr-defined] | ||||||
| return action | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -314,7 +314,7 @@ def parse_arguments(parser): | |||||
| def post_parse_setup(args): | ||||||
| """Perform additional setup after parsing arguments.""" | ||||||
|
|
||||||
| def map_https_to_transport(input: str) -> str | None: | ||||||
| def map_https_to_transport(input: str) -> str: | ||||||
| if input.startswith("https://") or input.startswith("http://"): | ||||||
| url = urlparse(input) | ||||||
| # detect if the whole repo is defined or a specific file | ||||||
|
|
@@ -468,7 +468,7 @@ def bench_cli(args): | |||||
| model.bench(args, assemble_command(args)) | ||||||
|
|
||||||
|
|
||||||
| def add_network_argument(parser, dflt="none"): | ||||||
| def add_network_argument(parser, dflt: str | None = "none"): | ||||||
| # Disable network access by default, and give the option to pass any supported network mode into | ||||||
| # podman if needed: | ||||||
| # https://docs.podman.io/en/latest/markdown/podman-run.1.html#network-mode-net | ||||||
|
|
@@ -573,13 +573,11 @@ def _list_models_from_store(args): | |||||
| size_sum += file.size | ||||||
| last_modified = max(file.modified, last_modified) | ||||||
|
|
||||||
| ret.append( | ||||||
| { | ||||||
| "name": f"{model} (partial)" if is_partially_downloaded else model, | ||||||
| "modified": datetime.fromtimestamp(last_modified, tz=local_timezone).isoformat(), | ||||||
| "size": size_sum, | ||||||
| } | ||||||
| ) | ||||||
| ret.append({ | ||||||
| "name": f"{model} (partial)" if is_partially_downloaded else model, | ||||||
| "modified": datetime.fromtimestamp(last_modified, tz=local_timezone).isoformat(), | ||||||
| "size": size_sum, | ||||||
| }) | ||||||
|
|
||||||
| # sort the listed models according to the desired order | ||||||
| ret.sort(key=lambda entry: entry[args.sort], reverse=args.order == "desc") | ||||||
|
|
@@ -592,7 +590,7 @@ def _list_models(args): | |||||
|
|
||||||
|
|
||||||
| def info_cli(args): | ||||||
| info = { | ||||||
| info: dict[str, Any] = { | ||||||
| "Accelerator": get_accel(), | ||||||
| "Config": load_file_config(), | ||||||
| "Engine": { | ||||||
|
|
@@ -1335,6 +1333,9 @@ def version_parser(subparsers): | |||||
|
|
||||||
| class AddPathOrUrl(argparse.Action): | ||||||
| def __call__(self, parser, namespace, values, option_string=None): | ||||||
| if not isinstance(values, list): | ||||||
| raise ValueError("AddPathOrUrl can only be used with the settings `nargs='+'`") | ||||||
|
|
||||||
| setattr(namespace, self.dest, []) | ||||||
| namespace.urls = [] | ||||||
| for value in values: | ||||||
|
|
@@ -1516,7 +1517,7 @@ def inspect_cli(args): | |||||
|
|
||||||
|
|
||||||
| def main() -> None: | ||||||
| def eprint(e, exit_code): | ||||||
| def eprint(e: Exception | str, exit_code: int): | ||||||
| try: | ||||||
| if args.debug: | ||||||
| logger.exception(e) | ||||||
|
|
@@ -1545,7 +1546,7 @@ def eprint(e, exit_code): | |||||
| except urllib.error.HTTPError as e: | ||||||
| eprint(f"pulling {e.geturl()} failed: {e}", errno.EINVAL) | ||||||
| except HelpException: | ||||||
| parser.print_help() | ||||||
| parser.print_help() # type: ignore[possibly-unbound] | ||||||
| except (ConnectionError, IndexError, KeyError, ValueError, NoRefFileFound) as e: | ||||||
| eprint(e, errno.EINVAL) | ||||||
| except NotImplementedError as e: | ||||||
|
|
@@ -1563,12 +1564,11 @@ def eprint(e, exit_code): | |||||
| except ParseError as e: | ||||||
| eprint(f"Failed to parse model: {e}", errno.EINVAL) | ||||||
| except SafetensorModelNotSupported: | ||||||
| eprint( | ||||||
| f"""Safetensor models are not supported. Please convert it to GGUF via: | ||||||
| $ ramalama convert --gguf=<quantization> {args.model} <oci-name> | ||||||
| $ ramalama run <oci-name> | ||||||
| """, | ||||||
| errno.ENOTSUP, | ||||||
| message = ( | ||||||
| "Safetensor models are not supported. Please convert it to GGUF via:" | ||||||
| f"$ ramalama convert --gguf=<quantization> {args.model} <oci-name>" # type: ignore[possibly-unbound] | ||||||
| "$ ramalama run <oci-name>" | ||||||
| ) | ||||||
| eprint(message, errno.ENOTSUP) | ||||||
sourcery-ai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| except NoGGUFModelFileFound: | ||||||
| eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) | ||||||
| eprint(f"No GGUF model file found for downloaded model '{args.model}'", errno.ENOENT) # type: ignore | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For consistency with the
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The line exceeds 120 characters with the error specified |
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,8 +5,9 @@ | |
| import sys | ||
| import time | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Callable, Sequence | ||
| from http.client import HTTPConnection, HTTPException | ||
| from typing import Any, Callable | ||
| from typing import Any, cast | ||
|
|
||
| # Live reference for checking global vars | ||
| import ramalama.common | ||
|
|
@@ -20,10 +21,10 @@ class BaseEngine(ABC): | |
|
|
||
| def __init__(self, args): | ||
| base = os.path.basename(args.engine) | ||
| self.use_docker = base == "docker" | ||
| self.use_podman = base == "podman" | ||
| self.use_docker: bool = base == "docker" | ||
| self.use_podman: bool = base == "podman" | ||
| self.args = args | ||
| self.exec_args = [self.args.engine] | ||
| self.exec_args: list[str] = [self.args.engine] | ||
| self.base_args() | ||
| self.add_labels() | ||
| self.add_network() | ||
|
|
@@ -35,10 +36,10 @@ def __init__(self, args): | |
| @abstractmethod | ||
| def base_args(self): ... | ||
|
|
||
| def add_label(self, label): | ||
| def add_label(self, label: str): | ||
| self.add(["--label", label]) | ||
|
|
||
| def add_name(self, name): | ||
| def add_name(self, name: str): | ||
| self.add(["--name", name]) | ||
|
|
||
| def add_labels(self): | ||
|
|
@@ -112,8 +113,8 @@ def handle_podman_specifics(self): | |
| if getattr(self.args, "podman_keep_groups", None): | ||
| self.exec_args += ["--group-add", "keep-groups"] | ||
|
|
||
| def add(self, newargs): | ||
| self.exec_args += newargs | ||
| def add(self, newargs: Sequence[str]): | ||
| self.exec_args.extend(newargs) | ||
|
|
||
| def add_args(self, *args: str) -> None: | ||
| self.add(args) | ||
|
|
@@ -319,7 +320,7 @@ def containers(args): | |
| raise (e) | ||
|
|
||
|
|
||
| def info(args): | ||
| def info(args) -> list[Any] | str | dict[str, Any]: | ||
| conman = str(args.engine) if args.engine is not None else None | ||
| if conman == "" or conman is None: | ||
| raise ValueError("no container manager (Podman, Docker) found") | ||
|
|
@@ -329,12 +330,12 @@ def info(args): | |
| output = run_cmd(conman_args).stdout.decode("utf-8").strip() | ||
| if output == "": | ||
| return [] | ||
| return json.loads(output) | ||
| return cast(dict[str, Any], json.loads(output)) | ||
|
||
| except FileNotFoundError as e: | ||
| return str(e) | ||
|
|
||
|
|
||
| def inspect(args, name, format=None, ignore_stderr=False): | ||
| def inspect(args, name: str, format: str | None = None, ignore_stderr: bool = False): | ||
| if not name: | ||
| raise ValueError("must specify a container name") | ||
| conman = str(args.engine) if args.engine is not None else None | ||
|
|
@@ -349,7 +350,7 @@ def inspect(args, name, format=None, ignore_stderr=False): | |
| return run_cmd(conman_args, ignore_stderr=ignore_stderr).stdout.decode("utf-8").strip() | ||
|
|
||
|
|
||
| def logs(args, name, ignore_stderr=False): | ||
| def logs(args, name: str, ignore_stderr: bool = False): | ||
| if not name: | ||
| raise ValueError("must specify a container name") | ||
| conman = str(args.engine) if args.engine is not None else None | ||
|
|
@@ -360,7 +361,7 @@ def logs(args, name, ignore_stderr=False): | |
| return run_cmd(conman_args, ignore_stderr=ignore_stderr).stdout.decode("utf-8").strip() | ||
|
|
||
|
|
||
| def stop_container(args, name, remove=False): | ||
| def stop_container(args, name: str, remove: bool = False): | ||
| if not name: | ||
| raise ValueError("must specify a container name") | ||
| conman = str(args.engine) if args.engine is not None else None | ||
|
|
@@ -416,7 +417,7 @@ def stop_container(args, name, remove=False): | |
| raise | ||
|
|
||
|
|
||
| def add_labels(args, add_label): | ||
| def add_labels(args, add_label: Callable[[str], None]): | ||
| label_map = { | ||
| "MODEL": "ai.ramalama.model", | ||
| "engine": "ai.ramalama.engine", | ||
|
|
@@ -452,7 +453,7 @@ def is_healthy(args, timeout: int = 3, model_name: str | None = None): | |
| model_names = [m["name"] for m in body["models"]] | ||
| if not model_name: | ||
| # The transport and tag is not included in the model name returned by the endpoint | ||
| model_name = args.MODEL.split("://")[-1] | ||
| model_name = cast(str, args.MODEL.split("://")[-1]) | ||
| model_name = model_name.split(":")[0] | ||
| if not any(model_name in name for name in model_names): | ||
| logger.debug(f'Container {args.name} does not include "{model_name}" in the model list: {model_names}') | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.