[tokenizers] Updates data processors, docstring, examples and model cards to the new API (#5308)

* remove references to old API in docstring - update data processors

* style

* fix tests - better type checking error messages

* better type checking

* include awesome fix by @LysandreJik for #5310

* updated doc and examples
This commit is contained in:
Thomas Wolf
2020-06-26 19:48:14 +02:00
committed by GitHub
parent fd405e9a93
commit 601d4d699c
73 changed files with 180 additions and 138 deletions

View File

@@ -193,12 +193,12 @@ def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from
def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_toks = tokenizer(q_ls, max_length=max_len, padding="max_length", truncation=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)
a_toks = tokenizer(a_ls, max_length=max_len, padding="max_length", truncation=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
@@ -375,12 +375,12 @@ def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_toks = tokenizer(q_ls, max_length=max_len, padding="max_length", truncation=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
a_toks = tokenizer(a_ls, max_length=min(max_len, max_a_len), padding="max_length", truncation=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
@@ -531,7 +531,7 @@ def qa_s2s_generate(
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
a_toks = tokenizer(passages, max_length=max_length, padding="max_length", truncation=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
@@ -542,7 +542,7 @@ def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=12
def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
q_toks = tokenizer(q_ls, max_length=128, padding="max_length", truncation=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),