Add tests to Trainer (#6605)
* Add tests to Trainer * Test if removing long breaks everything * Remove ugly hack * Fix distributed test * Use float for number of epochs
This commit is contained in:
@@ -62,7 +62,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = torch.stack([f[k] for f in features])
|
||||
else:
|
||||
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
|
||||
batch[k] = torch.tensor([f[k] for f in features])
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user