add FlashAttentionKwargs and seq_idx to flat collator (#36456)
* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttnKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simplify wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}
This commit is contained in:
@@ -1974,9 +1974,11 @@ class DataCollatorWithFlattening(DefaultDataCollator):
|
|||||||
"""
|
"""
|
||||||
Data collator used for padding free approach. Does the following:
|
Data collator used for padding free approach. Does the following:
|
||||||
|
|
||||||
- concatate the entire mini batch into single long sequence [1, total_tokens]
|
- concatenates the entire mini batch into single long sequence of shape [1, total_tokens]
|
||||||
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
|
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
|
||||||
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
|
- no padding will be added, returns `input_ids`, `labels` and `position_ids` by default
|
||||||
|
- optionally returns the kwargs contained in FlashAttentionKwargs
|
||||||
|
- optionally returns seq_idx indicating which sequence each token belongs to
|
||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
@@ -1986,10 +1988,23 @@ class DataCollatorWithFlattening(DefaultDataCollator):
|
|||||||
</Tip>
|
</Tip>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
return_position_ids=True,
|
||||||
|
separator_id=-100,
|
||||||
|
return_flash_attn_kwargs=False,
|
||||||
|
return_seq_idx=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.return_position_ids = return_position_ids
|
self.return_position_ids = return_position_ids
|
||||||
self.separator_id = separator_id
|
self.separator_id = separator_id
|
||||||
|
self.return_flash_attn_kwargs = return_flash_attn_kwargs
|
||||||
|
self.return_seq_idx = return_seq_idx
|
||||||
|
self._int_64_keys = {"labels", "position_ids", "input_ids"}
|
||||||
|
self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"}
|
||||||
|
self._py_int_keys = {"max_length_q", "max_length_k"}
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None, separator_id=None):
|
def __call__(self, features, return_tensors=None, separator_id=None):
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
@@ -1997,15 +2012,52 @@ class DataCollatorWithFlattening(DefaultDataCollator):
|
|||||||
if separator_id is None:
|
if separator_id is None:
|
||||||
separator_id = self.separator_id
|
separator_id = self.separator_id
|
||||||
is_labels_provided = "labels" in features[0]
|
is_labels_provided = "labels" in features[0]
|
||||||
ret = {"input_ids": [], "labels": []}
|
batch = {"input_ids": [], "labels": []}
|
||||||
if self.return_position_ids:
|
if self.return_position_ids:
|
||||||
ret.update({"position_ids": []})
|
batch.update({"position_ids": []})
|
||||||
for idx in range(0, len(features)):
|
if self.return_seq_idx:
|
||||||
ret["input_ids"] += features[idx]["input_ids"]
|
batch.update({"seq_idx": []})
|
||||||
|
if self.return_flash_attn_kwargs:
|
||||||
|
cu_seq_lens = [0]
|
||||||
|
max_length = 0
|
||||||
|
for seq_idx, sample in enumerate(features):
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
batch["input_ids"] += input_ids
|
||||||
if is_labels_provided:
|
if is_labels_provided:
|
||||||
ret["labels"] += [separator_id] + features[idx]["labels"][1:]
|
batch["labels"] += [separator_id] + sample["labels"][1:]
|
||||||
else:
|
else:
|
||||||
ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
|
batch["labels"] += [separator_id] + input_ids[1:]
|
||||||
if self.return_position_ids:
|
if self.return_position_ids:
|
||||||
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
|
batch["position_ids"] += list(range(len(input_ids)))
|
||||||
return default_data_collator([ret], return_tensors)
|
if self.return_seq_idx:
|
||||||
|
batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))]
|
||||||
|
if self.return_flash_attn_kwargs:
|
||||||
|
cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids))
|
||||||
|
max_length = max(max_length, len(input_ids))
|
||||||
|
|
||||||
|
if self.return_flash_attn_kwargs:
|
||||||
|
batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
|
||||||
|
batch["max_length_q"] = batch["max_length_k"] = max_length
|
||||||
|
|
||||||
|
# FlashAttentionKwargs and seq_idx are expected to be int32s.
|
||||||
|
if return_tensors == "pt":
|
||||||
|
import torch
|
||||||
|
|
||||||
|
data_cls = torch.tensor
|
||||||
|
dtype_64 = torch.int64
|
||||||
|
dtype_32 = torch.int32
|
||||||
|
elif return_tensors == "np":
|
||||||
|
data_cls = np.array
|
||||||
|
dtype_64 = np.int64
|
||||||
|
dtype_32 = np.int32
|
||||||
|
else:
|
||||||
|
raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not suported')
|
||||||
|
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k in self._batch_dim_keys:
|
||||||
|
v = [v]
|
||||||
|
# Flash attention max_len_{q,k} are python ints
|
||||||
|
if k not in self._py_int_keys:
|
||||||
|
batch[k] = data_cls(v, dtype=dtype_64 if k in self._int_64_keys else dtype_32)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from transformers import (
|
|||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
DataCollatorWithFlattening,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -4170,6 +4171,78 @@ class ModelTesterMixin:
|
|||||||
tol = torch.finfo(torch.float16).eps
|
tol = torch.finfo(torch.float16).eps
|
||||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||||
|
if not self.has_attentions:
|
||||||
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
|
max_new_tokens = 30
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||||
|
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||||
|
|
||||||
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
|
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||||
|
dummy_input = dummy_input.to(torch.float16)
|
||||||
|
|
||||||
|
# make sure that all models have enough positions for generation
|
||||||
|
if hasattr(config, "max_position_embeddings"):
|
||||||
|
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||||
|
self.skipTest("Model does not support position_ids")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
# ensure left padding, to adapt for some models
|
||||||
|
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||||
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
|
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||||
|
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||||
|
|
||||||
|
model = (
|
||||||
|
model_class.from_pretrained(
|
||||||
|
tmpdirname,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
)
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
features = [
|
||||||
|
{"input_ids": i[a.bool()].tolist()}
|
||||||
|
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||||
|
]
|
||||||
|
|
||||||
|
# add position_ids + fa_kwargs
|
||||||
|
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
|
||||||
|
batch = data_collator(features)
|
||||||
|
batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
|
||||||
|
|
||||||
|
res_padded = model(**inputs_dict)
|
||||||
|
res_padfree = model(**batch_cuda)
|
||||||
|
|
||||||
|
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||||
|
logits_padfree = res_padfree.logits[0]
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||||
|
# acceptable numerical instability
|
||||||
|
tol = torch.finfo(torch.float16).eps
|
||||||
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
|
|||||||
@@ -126,6 +126,104 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
batch = data_collator(features)
|
batch = data_collator(features)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
||||||
|
|
||||||
|
def test_data_collator_with_flattening(self):
|
||||||
|
features = [
|
||||||
|
{"input_ids": [10, 11, 12]},
|
||||||
|
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||||
|
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||||
|
]
|
||||||
|
|
||||||
|
data_collator = DataCollatorWithFlattening(return_tensors="pt")
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
for unexpected_key in [
|
||||||
|
"attention_mask",
|
||||||
|
"cu_seq_lens_k",
|
||||||
|
"cu_seq_lens_q",
|
||||||
|
"max_length_k",
|
||||||
|
"max_length_q",
|
||||||
|
"seq_idx",
|
||||||
|
]:
|
||||||
|
self.assertNotIn(unexpected_key, batch)
|
||||||
|
self.assertIn("position_ids", batch)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
|
||||||
|
self.assertEqual(
|
||||||
|
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||||
|
)
|
||||||
|
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
|
||||||
|
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||||
|
|
||||||
|
def test_data_collator_with_flattening_flash_attn_kwargs(self):
|
||||||
|
features = [
|
||||||
|
{"input_ids": [10, 11, 12]},
|
||||||
|
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||||
|
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||||
|
]
|
||||||
|
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
for unexpected_key in [
|
||||||
|
"attention_mask",
|
||||||
|
"seq_idx",
|
||||||
|
]:
|
||||||
|
self.assertNotIn(unexpected_key, batch)
|
||||||
|
for expected_key in [
|
||||||
|
"position_ids",
|
||||||
|
"cu_seq_lens_k",
|
||||||
|
"cu_seq_lens_q",
|
||||||
|
"max_length_k",
|
||||||
|
"max_length_q",
|
||||||
|
]:
|
||||||
|
self.assertIn(expected_key, batch)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
|
||||||
|
self.assertEqual(
|
||||||
|
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||||
|
)
|
||||||
|
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
|
||||||
|
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||||
|
|
||||||
|
self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
|
||||||
|
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
|
||||||
|
self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
|
||||||
|
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
|
||||||
|
# The flash attn max_length_{k,q} are simple python ints
|
||||||
|
self.assertEqual(batch["max_length_k"], 7)
|
||||||
|
self.assertEqual(batch["max_length_q"], 7)
|
||||||
|
|
||||||
|
def test_data_collator_with_flattening_seq_idx(self):
|
||||||
|
features = [
|
||||||
|
{"input_ids": [10, 11, 12]},
|
||||||
|
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||||
|
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||||
|
]
|
||||||
|
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
for unexpected_key in [
|
||||||
|
"attention_mask",
|
||||||
|
"cu_seq_lens_k",
|
||||||
|
"cu_seq_lens_q",
|
||||||
|
"max_length_k",
|
||||||
|
"max_length_q",
|
||||||
|
]:
|
||||||
|
self.assertNotIn(unexpected_key, batch)
|
||||||
|
for expected_key in [
|
||||||
|
"position_ids",
|
||||||
|
"seq_idx",
|
||||||
|
]:
|
||||||
|
self.assertIn(expected_key, batch)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
|
||||||
|
self.assertEqual(
|
||||||
|
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||||
|
)
|
||||||
|
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
|
||||||
|
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||||
|
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
|
||||||
|
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
|
||||||
|
|
||||||
def test_data_collator_for_token_classification(self):
|
def test_data_collator_for_token_classification(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
features = [
|
features = [
|
||||||
@@ -1803,15 +1901,97 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
data_collator = DataCollatorWithFlattening(return_tensors="np")
|
data_collator = DataCollatorWithFlattening(return_tensors="np")
|
||||||
batch = data_collator(features)
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
for unexpected_key in [
|
||||||
|
"attention_mask",
|
||||||
|
"cu_seq_lens_k",
|
||||||
|
"cu_seq_lens_q",
|
||||||
|
"max_length_k",
|
||||||
|
"max_length_q",
|
||||||
|
"seq_idx",
|
||||||
|
]:
|
||||||
|
self.assertNotIn(unexpected_key, batch)
|
||||||
|
self.assertIn("position_ids", batch)
|
||||||
|
|
||||||
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||||
)
|
)
|
||||||
self.assertNotIn("attention_mask", batch)
|
|
||||||
self.assertIn("position_ids", batch)
|
|
||||||
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||||
|
|
||||||
|
def test_data_collator_with_flattening_flash_attn_kwargs(self):
|
||||||
|
features = [
|
||||||
|
{"input_ids": [10, 11, 12]},
|
||||||
|
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||||
|
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||||
|
]
|
||||||
|
|
||||||
|
data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
for unexpected_key in [
|
||||||
|
"attention_mask",
|
||||||
|
"seq_idx",
|
||||||
|
]:
|
||||||
|
self.assertNotIn(unexpected_key, batch)
|
||||||
|
for expected_key in [
|
||||||
|
"position_ids",
|
||||||
|
"cu_seq_lens_k",
|
||||||
|
"cu_seq_lens_q",
|
||||||
|
"max_length_k",
|
||||||
|
"max_length_q",
|
||||||
|
]:
|
||||||
|
self.assertIn(expected_key, batch)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||||
|
self.assertEqual(
|
||||||
|
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||||
|
)
|
||||||
|
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||||
|
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||||
|
|
||||||
|
self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
|
||||||
|
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
|
||||||
|
self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
|
||||||
|
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
|
||||||
|
# The flash attn max_length_{k,q} are simple python ints
|
||||||
|
self.assertEqual(batch["max_length_k"], 7)
|
||||||
|
self.assertEqual(batch["max_length_q"], 7)
|
||||||
|
|
||||||
|
def test_data_collator_with_flattening_seq_idx(self):
|
||||||
|
features = [
|
||||||
|
{"input_ids": [10, 11, 12]},
|
||||||
|
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||||
|
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||||
|
]
|
||||||
|
|
||||||
|
data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
for unexpected_key in [
|
||||||
|
"attention_mask",
|
||||||
|
"cu_seq_lens_k",
|
||||||
|
"cu_seq_lens_q",
|
||||||
|
"max_length_k",
|
||||||
|
"max_length_q",
|
||||||
|
]:
|
||||||
|
self.assertNotIn(unexpected_key, batch)
|
||||||
|
for expected_key in [
|
||||||
|
"position_ids",
|
||||||
|
"seq_idx",
|
||||||
|
]:
|
||||||
|
self.assertIn(expected_key, batch)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||||
|
self.assertEqual(
|
||||||
|
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||||
|
)
|
||||||
|
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||||
|
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||||
|
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
|
||||||
|
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
|
||||||
|
|
||||||
def test_data_collator_for_token_classification(self):
|
def test_data_collator_for_token_classification(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
features = [
|
features = [
|
||||||
|
|||||||
Reference in New Issue
Block a user