[tokenizers] Updates data processors, docstring, examples and model cards to the new API (#5308)
* remove references to old API in docstring - update data processors * style * fix tests - better type checking error messages * better type checking * include awesome fix by @LysandreJik for #5310 * updated doc and examples
This commit is contained in:
@@ -626,9 +626,9 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
|
||||
dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",).to(
|
||||
torch_device
|
||||
)
|
||||
dct = tok.batch_encode_plus(
|
||||
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
hypotheses_batch = model.generate(
|
||||
input_ids=dct["input_ids"],
|
||||
@@ -672,7 +672,8 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
dct = tok.batch_encode_plus(
|
||||
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
|
||||
max_length=1024,
|
||||
pad_to_max_length=True,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
|
||||
@@ -375,10 +375,11 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
summarization_config = task_specific_config.get("summarization", {})
|
||||
model.config.update(summarization_config)
|
||||
|
||||
dct = tok.batch_encode_plus(
|
||||
dct = tok(
|
||||
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
|
||||
max_length=512,
|
||||
pad_to_max_length=True,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertEqual(512, dct["input_ids"].shape[1])
|
||||
|
||||
@@ -276,10 +276,11 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
|
||||
summarization_config = task_specific_config.get("summarization", {})
|
||||
model.config.update(summarization_config)
|
||||
|
||||
dct = tok.batch_encode_plus(
|
||||
dct = tok(
|
||||
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
|
||||
max_length=512,
|
||||
pad_to_max_length=True,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="tf",
|
||||
)
|
||||
self.assertEqual(512, dct["input_ids"].shape[1])
|
||||
|
||||
Reference in New Issue
Block a user