add batch size for causal model
t-jingweiyi committed Oct 7, 2023
commit ae62d080e916a96624d7cde459f589061b847950
6 changes: 6 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,6 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}
65 changes: 39 additions & 26 deletions modeling.py
@@ -127,7 +127,7 @@ def handler(signum, frame):
else:
time.sleep(3)
return "Z"


class vllmModel(EvalModel):
model_path: str
@@ -149,12 +149,11 @@ def load(self):
self.model = LLM(
model=self.model_path,
trust_remote_code=self.trust_remote_code,
- tensor_parallel_size=self.tensor_parallel_size
+ tensor_parallel_size=self.tensor_parallel_size,
)
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
- self.model_path,
- trust_remote_code=self.trust_remote_code
+ self.model_path, trust_remote_code=self.trust_remote_code
)

def count_text_length(self, text: str) -> int:
@@ -175,18 +174,16 @@ def run(self, prompts: str, **kwargs) -> str:

do_sample = kwargs.pop("do_sample", True)
max_output_length = kwargs.pop("max_output_length", self.max_output_length)
- temperature = 0 if not do_sample else kwargs.pop("temperature", self.temperature)
+ temperature = (
+     0 if not do_sample else kwargs.pop("temperature", self.temperature)
+ )

sampling_params = SamplingParams(
-     temperature=temperature,
-     max_tokens=max_output_length,
-     **kwargs
- )
- outputs = self.model.generate(
-     prompts, sampling_params
+     temperature=temperature, max_tokens=max_output_length, **kwargs
)
+ outputs = self.model.generate(prompts, sampling_params)
return [output.outputs[0].text for output in outputs]

def get_choice(self, text: str, **kwargs) -> Tuple[float, float]:
raise NotImplementedError
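For reference, a minimal usage sketch of the reformatted vllmModel.run above (hypothetical call, not part of this commit; the model path is a placeholder). Passing do_sample=False forces the temperature to 0, which vLLM treats as greedy decoding.

# Hypothetical usage sketch of the vllmModel.run shown in this diff:
model = vllmModel(model_path="path/to/model")  # placeholder checkpoint path
texts = model.run(["Question: 2 + 2 = ?\nAnswer:"], do_sample=False)  # temperature -> 0, greedy
print(texts[0])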

@@ -253,6 +250,8 @@ def get_choice(self, text: str, **kwargs) -> Tuple[float, float]:


class CausalModel(SeqToSeqModel):
+ batch_size: int = 1

def load(self):
if self.model is None:
args = {}
@@ -268,24 +267,38 @@ def load(self):
# self.model.to(self.device)
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
- self.model_path, trust_remote_code=True
+ self.model_path, trust_remote_code=True, padding_side="left"
)
+ if self.tokenizer.pad_token is None:
+     self.tokenizer.pad_token = self.tokenizer.eos_token

def run(self, prompt: str, **kwargs) -> str:
self.load()
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
- if "RWForCausalLM" in str(type(self.model)):
-     inputs.pop("token_type_ids")  # Not used by Falcon model
+ all_outputs = []

- outputs = self.model.generate(
-     **inputs,
-     max_new_tokens=self.max_output_length,
-     pad_token_id=self.tokenizer.eos_token_id,  # Avoid pad token warning
-     do_sample=self.do_sample,
-     **kwargs,
- )
- batch_size, length = inputs.input_ids.shape
- return self.tokenizer.decode(outputs[0, length:], skip_special_tokens=True)
+ for i in range(0, len(prompt), self.batch_size):
+     batch_prompt = prompt[i : i + self.batch_size]
+     inputs = self.tokenizer(batch_prompt, return_tensors="pt", padding=True).to(
+         self.device
+     )
+     if "RWForCausalLM" in str(type(self.model)):
+         inputs.pop("token_type_ids")  # Not used by Falcon model

+     outputs = self.model.generate(
+         **inputs,
+         max_new_tokens=self.max_output_length,
+         pad_token_id=self.tokenizer.eos_token_id,  # Avoid pad token warning
+         do_sample=self.do_sample,
+         **kwargs,
+     )
+     batch_size, length = inputs.input_ids.shape
+     all_outputs.extend(
+         self.tokenizer.batch_decode(
+             outputs[:, length:], skip_special_tokens=True
+         )
+     )

+ return all_outputs

def get_choice(self, text: str, **kwargs) -> Tuple[float, float]:
self.load()
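A short usage sketch of the batched generation path added above (hypothetical; the checkpoint name and prompts are illustrative, not from this commit). Prompts are consumed batch_size at a time; because the tokenizer now pads on the left, every prompt in a batch ends at the same position, so outputs[:, length:] strips exactly the padded prompt before batch_decode.

# Hypothetical example, assuming the CausalModel fields shown in this diff:
model = CausalModel(model_path="gpt2", batch_size=2)  # "gpt2" is a placeholder checkpoint
prompts = [
    "Q: 2 + 2 = ?\nA:",
    "Q: What is the capital of France?\nA:",
    "Q: 3 * 3 = ?\nA:",
]
answers = model.run(prompts)  # two batches (2 + 1); returns a list of 3 strings
for p, a in zip(prompts, answers):
    print(p, "->", a)

Note that run keeps its original prompt: str and -> str annotations even though it now iterates over a list of prompts and returns a list of strings.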
@@ -570,7 +583,7 @@ def select_model(model_name: str, **kwargs) -> EvalModel:
openai=OpenAIModel,
rwkv=RWKVModel,
gptq=GPTQModel,
- vllm=vllmModel
+ vllm=vllmModel,
)
model_class = model_map.get(model_name)
if model_class is None:
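Finally, a hedged sketch of how the new batch_size field could be passed through select_model (illustrative only; it assumes the remaining keyword arguments are forwarded to the selected model class and that CausalModel is registered under a key such as "causal", neither of which is visible in this hunk).

# Hypothetical wiring; model name, path, and batch size are placeholders:
model = select_model("causal", model_path="path/to/checkpoint", batch_size=4)
outputs = model.run(["Reply with one word: hello?"])
print(outputs)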