Fix tensor label type inference in default collator (#5250)

* allow tensor label inputs to default collator

* replace try/except with type check
This commit is contained in:
Joe Davison
2020-07-01 10:40:14 -06:00
committed by GitHub
parent fe81f7d12c
commit 35befd9ce3
2 changed files with 10 additions and 1 deletions

View File

@@ -43,7 +43,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if "label" in first and first["label"] is not None:
dtype = torch.long if type(first["label"]) is int else torch.float
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
dtype = torch.long if isinstance(label, int) else torch.float
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], torch.Tensor):