fixed tests

This commit is contained in:
thomwolf
2019-07-15 12:32:19 +02:00
parent e28d8bde0d
commit f7cd7392fd
7 changed files with 63 additions and 38 deletions

View File

@@ -594,7 +594,7 @@ class SQuADHead(nn.Module):
"""
outputs = ()
start_logits = self.start_logits(hidden_states, p_mask)
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, let's remove the dimension added by batch splitting