Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Add xid_type validation, configurable service names, and improved type hints
  • Loading branch information
nv-oviya committed Nov 4, 2025
commit 89a0a6c07b60cf8adfbe2b146eb0052ec7b68c06
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
class CUDAFaultInjector:
"""Manages CUDA fault injection library and deployment patching."""

VALID_XID_TYPES = {79, 48, 94, 95, 43, 74}

def __init__(self, lib_dir: Optional[Path] = None):
"""
Initialize CUDA fault injector.
Expand Down Expand Up @@ -125,6 +127,12 @@ def patch_deployment_for_cuda_fault(
Returns:
True if patch succeeded
"""
if xid_type not in self.VALID_XID_TYPES:
print(
f" ✗ Invalid xid_type: {xid_type}. Valid values: {sorted(self.VALID_XID_TYPES)}"
)
return False

print(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All print() calls should be replaced with calls to the logging module.

f"\n[→] Patching deployment to enable CUDA fault injection (XID {xid_type})..."
)
Expand All @@ -151,7 +159,11 @@ def patch_deployment_for_cuda_fault(
return False

def cleanup_cuda_fault_injection(
self, deployment_name: str, namespace: str, force_delete_pods: bool = True
self,
deployment_name: str,
namespace: str,
force_delete_pods: bool = True,
service_names: Optional[List[str]] = None,
) -> bool:
"""
Remove CUDA fault injection from deployment.
Expand All @@ -167,10 +179,13 @@ def cleanup_cuda_fault_injection(
deployment_name: Name of the deployment
namespace: Kubernetes namespace
force_delete_pods: If True, force delete pods to apply clean spec
service_names: Service names to check (default: ["VllmDecodeWorker", "VllmPrefillWorker"])

Returns:
True if cleanup succeeded
"""
if service_names is None:
service_names = ["VllmDecodeWorker", "VllmPrefillWorker"]
print("\n[→] Cleaning up CUDA fault injection...")

sys.path.insert(0, str(self.lib_dir))
Expand Down Expand Up @@ -211,7 +226,7 @@ def cleanup_cuda_fault_injection(
has_artifacts = False
artifact_details = []

for service_name in ["VllmDecodeWorker", "VllmPrefillWorker"]:
for service_name in service_names:
service = (
dgd.get("spec", {})
.get("services", {})
Expand Down Expand Up @@ -309,7 +324,7 @@ def cleanup_cuda_fault_injection(
if deleted_count > 0:
print(f" ✓ Deleted {deleted_count} pod(s)")
else:
print(" No pods to delete")
print(" [i] No pods to delete")

except Exception as e:
print(f" ⚠ Pod deletion: {e}")
Expand All @@ -324,7 +339,7 @@ def cleanup_cuda_fault_injection(
traceback.print_exc()
return False

def trigger_pod_restart(self, pods: List, namespace: str):
def trigger_pod_restart(self, pods: List[client.V1Pod], namespace: str):
"""
Delete pods to trigger restart with new env vars.

Expand Down
Loading