Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def create_app_wrapper(self):
instance_types = self.config.machine_types
env = self.config.envs
local_interactive = self.config.local_interactive
image_pull_secrets = self.config.image_pull_secrets
return generate_appwrapper(
name=name,
namespace=namespace,
Expand All @@ -100,6 +101,7 @@ def create_app_wrapper(self):
instance_types=instance_types,
env=env,
local_interactive=local_interactive,
image_pull_secrets=image_pull_secrets,
)

# creates a new cluster with the provided or default spec
Expand Down
1 change: 1 addition & 0 deletions src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ class ClusterConfiguration:
envs: dict = field(default_factory=dict)
image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
local_interactive: bool = False
image_pull_secrets: list = field(default_factory=list)
20 changes: 20 additions & 0 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def update_image(spec, image):
container["image"] = image


def update_image_pull_secrets(spec, image_pull_secrets):
if image_pull_secrets:
if "imagePullSecrets" not in spec:
spec["imagePullSecrets"] = []
for image_pull_secret in image_pull_secrets:
spec["imagePullSecrets"].append({"name": image_pull_secret})


def update_env(spec, env):
containers = spec.get("containers")
for container in containers:
Expand Down Expand Up @@ -178,6 +186,7 @@ def update_nodes(
image,
instascale,
env,
image_pull_secrets,
):
if "generictemplate" in item.keys():
head = item.get("generictemplate").get("spec").get("headGroupSpec")
Expand All @@ -193,6 +202,7 @@ def update_nodes(
for comp in [head, worker]:
spec = comp.get("template").get("spec")
update_affinity(spec, appwrapper_name, instascale)
update_image_pull_secrets(spec, image_pull_secrets)
update_image(spec, image)
update_env(spec, env)
if comp == head:
Expand Down Expand Up @@ -295,6 +305,7 @@ def generate_appwrapper(
instance_types: list,
env,
local_interactive: bool,
image_pull_secrets: list,
):
user_yaml = read_template(template)
appwrapper_name, cluster_name = gen_names(name)
Expand All @@ -318,6 +329,7 @@ def generate_appwrapper(
image,
instascale,
env,
image_pull_secrets,
)
update_dashboard_route(route_item, cluster_name, namespace)
if local_interactive:
Expand Down Expand Up @@ -409,6 +421,12 @@ def main(): # pragma: no cover
default=False,
help="Enable local interactive mode",
)
parser.add_argument(
"--image-pull-secrets",
required=False,
default=[],
help="Set image pull secrets for private registries",
)

args = parser.parse_args()
name = args.name
Expand All @@ -425,6 +443,7 @@ def main(): # pragma: no cover
namespace = args.namespace
local_interactive = args.local_interactive
env = {}
image_pull_secrets = args.image_pull_secrets

outfile = generate_appwrapper(
name,
Expand All @@ -441,6 +460,7 @@ def main(): # pragma: no cover
instance_types,
local_interactive,
env,
image_pull_secrets,
)
return outfile

Expand Down
4 changes: 4 additions & 0 deletions tests/test-case.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ spec:
cpu: 2
memory: 8G
nvidia.com/gpu: 0
imagePullSecrets:
- name: unit-test-pull-secret
rayVersion: 2.1.0
workerGroupSpecs:
- groupName: small-group-unit-test-cluster
Expand Down Expand Up @@ -164,6 +166,8 @@ spec:
cpu: 3
memory: 5G
nvidia.com/gpu: 7
imagePullSecrets:
- name: unit-test-pull-secret
initContainers:
- command:
- sh
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def test_config_creation():
gpu=7,
instascale=True,
machine_types=["cpu.small", "gpu.large"],
image_pull_secrets=["unit-test-pull-secret"],
)

assert config.name == "unit-test-cluster" and config.namespace == "ns"
Expand All @@ -234,6 +235,7 @@ def test_config_creation():
assert config.template == f"{parent}/src/codeflare_sdk/templates/base-template.yaml"
assert config.instascale
assert config.machine_types == ["cpu.small", "gpu.large"]
assert config.image_pull_secrets == ["unit-test-pull-secret"]
return config


Expand Down