adding tests

This commit is contained in:
thomwolf
2019-08-05 18:14:07 +02:00
parent b90e29d52c
commit ed4e542260
4 changed files with 83 additions and 3 deletions

View File

@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module):
super(SequenceSummary, self).__init__()
self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
if config.summary_type == 'attn':
if self.summary_type == 'attn':
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0