Add BigBirdPegasus (#10991)
* init bigbird pegasus * add debugging nb ; update config * init conversion * update conversion script * complete conversion script * init forward() * complete forward() * add tokenizer * add some slow tests * commit current * fix copies * add docs * add conversion script for bigbird-roberta-summarization * remove TODO * small fixups * correct tokenizer * add bigbird core for now * fix config * fix more * revert pegasus-tokenizer back * make style * everything working for pubmed; yayygit status * complete tests finally * remove bigbird pegasus tok * correct tokenizer * correct tests * add tokenizer files * finish make style * fix test * update * make style * fix tok utils base file * make fix-copies * clean a bit * small update * fix some suggestions * add to readme * fix a bit, clean tests * fix more tests * Update src/transformers/__init__.py * Update src/transformers/__init__.py * make fix-copies * complete attn switching, auto-padding left * make style * fix auto-padding test * make style * fix batched attention tests * put tolerance at 1e-1 for stand-alone decoder test * fix docs * fix tests * correct slow tokenizer conversion * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * complete remaining suggestions * fix test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -310,19 +310,18 @@ class GenerationTesterMixin:
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.no_grad():
|
||||
output_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
output_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
return output_sample, output_generate
|
||||
|
||||
def _beam_search_generate(
|
||||
|
||||
Reference in New Issue
Block a user