Black 20 release
This commit is contained in:
@@ -391,7 +391,11 @@ class EncoderDecoderMixin:
|
||||
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
||||
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||
with torch.no_grad():
|
||||
outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,)
|
||||
outputs = model_2(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
@@ -401,7 +405,9 @@ class EncoderDecoderMixin:
|
||||
model_1.to(torch_device)
|
||||
|
||||
after_outputs = model_1(
|
||||
input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
|
||||
Reference in New Issue
Block a user