Add tests to Trainer (#6605)

* Add tests to Trainer

* Test if removing long breaks everything

* Remove ugly hack

* Fix distributed test

* Use float for number of epochs
This commit is contained in:
Sylvain Gugger
2020-08-20 11:13:50 -04:00
committed by GitHub
parent 039d8d65fc
commit 573bdb0a5d
6 changed files with 313 additions and 136 deletions

View File

@@ -62,7 +62,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
else:
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
batch[k] = torch.tensor([f[k] for f in features])
return batch