Check GPU without torch
daavoo committed Jan 17, 2025
commit d07ee3966509692e3d7f4402b673c7244d756eeb
12 changes: 10 additions & 2 deletions src/structured_qa/model_loaders.py
@@ -1,7 +1,15 @@
-import torch
+import subprocess
 from llama_cpp import Llama
 
 
+def gpu_available():
+    try:
+        subprocess.check_output("nvidia-smi")
+        return True
+    except Exception:
+        return False
+
+
 def load_llama_cpp_model(model_id: str) -> Llama:
     """
     Loads the given model_id using Llama.from_pretrained.
@@ -22,6 +30,6 @@ def load_llama_cpp_model(model_id: str) -> Llama:
         filename=filename,
         n_ctx=0,  # 0 means that the model limit will be used, instead of the default (512) or other hardcoded value
         verbose=False,
-        n_gpu_layers=-1 if torch.cuda.is_available() else 0,
+        n_gpu_layers=-1 if gpu_available() else 0,
     )
     return model
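
For reference, a minimal usage sketch of the new check (not part of the diff; the model_id value and its repo/filename split are hypothetical, inferred only from the filename= argument visible above):

    # Sketch only: gpu_available() shells out to nvidia-smi, so it returns True
    # only when the NVIDIA driver CLI is installed and exits without error.
    from structured_qa.model_loaders import gpu_available, load_llama_cpp_model

    if gpu_available():
        print("nvidia-smi found: all layers offloaded to GPU (n_gpu_layers=-1)")
    else:
        print("no NVIDIA GPU detected: running on CPU (n_gpu_layers=0)")

    # Hypothetical model_id in "<repo_id>/<filename>.gguf" form.
    model = load_llama_cpp_model(
        "bartowski/Qwen2.5-7B-Instruct-GGUF/Qwen2.5-7B-Instruct-Q4_K_M.gguf"
    )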