fixed tests
This commit is contained in:
@@ -594,7 +594,7 @@ class SQuADHead(nn.Module):
|
||||
"""
|
||||
outputs = ()
|
||||
|
||||
start_logits = self.start_logits(hidden_states, p_mask)
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, let's remove the dimension added by batch splitting
|
||||
|
||||
Reference in New Issue
Block a user