35 changes: 23 additions & 12 deletions auto_gptq/modeling/_base.py
@@ -232,6 +232,14 @@ def quantize(

        examples = self._prepare_examples_for_quantization(examples, batch_size)

+        def nested_move_to_device(v, device):
+            if isinstance(v, torch.Tensor):
+                return move_to_device(v, device)
+            elif isinstance(v, (list, tuple)):
+                return type(v)([nested_move_to_device(e, device) for e in v])
+            else:
+                return v
+
        class LayerHijacker(nn.Module):
            """hijack layer's forward pass to cache data"""
@@ -259,10 +267,11 @@ def forward(self, inp=None, **kwargs):
                one_kwargs = dict()
                for k, v in kwargs.items():  # make sure other arguments also be captured
                    if k not in ["hidden_states", "attention_mask", "position_ids"]:
-                        if isinstance(v, torch.Tensor):
-                            one_kwargs[k] = move_to_device(v, self.data_device)
-                        else:
-                            one_kwargs[k] = v
+                        # if isinstance(v, torch.Tensor):
+                        #     one_kwargs[k] = move_to_device(v, self.data_device)
+                        # else:
+                        #     one_kwargs[k] = v
+                        one_kwargs[k] = nested_move_to_device(v, self.data_device)
                layer_input_kwargs.append(one_kwargs)
                raise ValueError
@@ -355,10 +364,11 @@ def tmp(_, inp, out):
                if layer_position_ids is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
-                    if isinstance(v, torch.Tensor):
-                        additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
-                    else:
-                        additional_layer_inputs[k] = v
+                    # if isinstance(v, torch.Tensor):
+                    #     additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
+                    # else:
+                    #     additional_layer_inputs[k] = v
+                    additional_layer_inputs[k] = nested_move_to_device(v, cur_layer_device)
                layer(layer_input, **additional_layer_inputs)
            for h in handles:
                h.remove()
@@ -389,10 +399,11 @@ def tmp(_, inp, out):
                if layer_position_ids is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
-                    if isinstance(v, torch.Tensor):
-                        additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
-                    else:
-                        additional_layer_inputs[k] = v
+                    # if isinstance(v, torch.Tensor):
+                    #     additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
+                    # else:
+                    #     additional_layer_inputs[k] = v
+                    additional_layer_inputs[k] = nested_move_to_device(v, cur_layer_device)
                layer_output = move_to_device(
                    layer(layer_input, **additional_layer_inputs)[0],
                    cur_layer_device if cache_examples_on_gpu else CPU
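One design note, offered as an observation on the helper rather than on the PR: the recursion covers lists and tuples only, so a kwarg that arrived as a dict of tensors would fall through to the else branch and keep its original devices. Continuing the sketch above:

# Dicts are not recursed into under this definition; the value is returned as-is.
d = {"k": torch.zeros(1)}
assert nested_move_to_device(d, torch.device("cpu")) is d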