Skip to content
Prev Previous commit
Next Next commit
Remove train/test split
  • Loading branch information
merrymercy authored May 29, 2023
commit f6c66966dd3be36955cf8c3509892f8c21797f65
25 changes: 14 additions & 11 deletions fastchat/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,20 @@ def make_supervised_data_module(
raw_data = json.load(open(data_args.data_path, "r"))

    # Train/test split removed — train on the full dataset.
    # (Original split code kept below, commented out, for reference.)
perm = np.random.permutation(len(raw_data))
split = int(len(perm) * 0.98)
train_indices = perm[:split]
eval_indices = perm[split:]
train_raw_data = [raw_data[i] for i in train_indices]
eval_raw_data = [raw_data[i] for i in eval_indices]
rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")

train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer)
eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
# perm = np.random.permutation(len(raw_data))
# split = int(len(perm) * 0.98)
# train_indices = perm[:split]
# eval_indices = perm[split:]
# train_raw_data = [raw_data[i] for i in train_indices]
# eval_raw_data = [raw_data[i] for i in eval_indices]
# rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")

# train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer)
# eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer)

rank0_print(f"#train {len(raw_data)}")
train_dataset = dataset_cls(raw_data, tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None)


def train():
Expand Down