Make default_data_collator more flexible and deprecate old behavior (#5060)
* Make default_data_collator more flexible * Accept tensors for all features * Document code * Refactor * Formatting
This commit is contained in:
@@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
||||
# have the same attributes.
|
||||
# So we will look at the first element as a proxy for what attributes exist
|
||||
# on the whole batch.
|
||||
if not isinstance(features[0], dict):
|
||||
features = [vars(f) for f in features]
|
||||
|
||||
first = features[0]
|
||||
batch = {}
|
||||
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if hasattr(first, "label") and first.label is not None:
|
||||
if type(first.label) is int:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
if "label" in first:
|
||||
dtype = torch.long if type(first["label"]) is int else torch.float
|
||||
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
||||
elif "label_ids" in first:
|
||||
if isinstance(first["label_ids"], torch.Tensor):
|
||||
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
||||
else:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
||||
if type(first.label_ids[0]) is int:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
else:
|
||||
batch = {}
|
||||
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
|
||||
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
||||
|
||||
# Handling of all other possible attributes.
|
||||
# Handling of all other possible keys.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in vars(first).items():
|
||||
for k, v in first.items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = torch.stack([f[k] for f in features])
|
||||
else:
|
||||
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user