[tokenizers] Updates data processors, docstring, examples and model cards to the new API (#5308)

* remove references to old API in docstring - update data processors

* style

* fix tests - better type checking error messages

* better type checking

* include awesome fix by @LysandreJik for #5310

* updated doc and examples
This commit is contained in:
Thomas Wolf
2020-06-26 19:48:14 +02:00
committed by GitHub
parent fd405e9a93
commit 601d4d699c
73 changed files with 180 additions and 138 deletions

View File

@@ -626,9 +626,9 @@ class BartModelIntegrationTests(unittest.TestCase):
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
-dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",).to(
-torch_device
-)
+dct = tok.batch_encode_plus(
+[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
+).to(torch_device)
hypotheses_batch = model.generate(
input_ids=dct["input_ids"],
@@ -672,7 +672,8 @@ class BartModelIntegrationTests(unittest.TestCase):
dct = tok.batch_encode_plus(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
max_length=1024,
-pad_to_max_length=True,
+padding="max_length",
+truncation=True,
return_tensors="pt",
)

View File

@@ -375,10 +375,11 @@ class T5ModelIntegrationTests(unittest.TestCase):
summarization_config = task_specific_config.get("summarization", {})
model.config.update(summarization_config)
-dct = tok.batch_encode_plus(
+dct = tok(
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
max_length=512,
-pad_to_max_length=True,
+padding="max_length",
+truncation=True,
return_tensors="pt",
)
self.assertEqual(512, dct["input_ids"].shape[1])

View File

@@ -276,10 +276,11 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
summarization_config = task_specific_config.get("summarization", {})
model.config.update(summarization_config)
-dct = tok.batch_encode_plus(
+dct = tok(
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
max_length=512,
-pad_to_max_length=True,
+padding="max_length",
+truncation=True,
return_tensors="tf",
)
self.assertEqual(512, dct["input_ids"].shape[1])