diff --git a/convokit/forecaster/TransformerDecoderModel.py b/convokit/forecaster/TransformerDecoderModel.py index 39c36f29..2cfc1517 100644 --- a/convokit/forecaster/TransformerDecoderModel.py +++ b/convokit/forecaster/TransformerDecoderModel.py @@ -272,7 +272,7 @@ def fit(self, train_contexts, val_contexts): train_dataset=train_dataset, args=SFTConfig( dataset_text_field="text", - max_seq_length=self.max_seq_length, + max_length=self.max_seq_length, per_device_train_batch_size=self.config.per_device_batch_size, gradient_accumulation_steps=self.config.gradient_accumulation_steps, warmup_steps=10, @@ -292,7 +292,7 @@ def fit(self, train_contexts, val_contexts): ), ) trainer.train() - _ = self._tune_threshold(self, val_contexts) + _ = self._tune_threshold(val_contexts) return def _tune_threshold(self, val_contexts):