adding TF 2.0 model

This commit is contained in:
thomwolf
2019-10-09 11:07:43 +02:00
parent 45dc04f33d
commit c56d921dda
6 changed files with 430 additions and 213 deletions

View File

@@ -351,7 +351,7 @@ class CTRLModel(CTRLPreTrainedModel):
x = self.w(input_ids)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_ids.shape[1]
seq_len = input_ids.shape[-1]
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
x *= np.sqrt(self.d_model_size)