From d735b074d70370da57eea55c0644910a6672f5b5 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 4 Jan 2021 16:06:28 +0100 Subject: [PATCH] Fix Flaubert (#9292) --- .../models/flaubert/modeling_tf_flaubert.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index c1711b7f73..7dae97f645 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -17,6 +17,7 @@ """ import itertools +import random from dataclasses import dataclass from typing import Optional, Tuple @@ -596,15 +597,15 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): tensor = tensor * mask[..., tf.newaxis] # hidden_states and attentions cannot be None in graph mode. - hidden_states = () - attentions = () + hidden_states = () if inputs["output_hidden_states"] else None + attentions = () if inputs["output_attentions"] else None # transformer layers for i in range(self.n_layers): # LayerDrop - dropout_probability = tf.random.uniform([1], 0, 1) + dropout_probability = random.uniform(0, 1) - if inputs["training"] and tf.less(dropout_probability, self.layerdrop): + if inputs["training"] and (dropout_probability < self.layerdrop): continue if inputs["output_hidden_states"]: @@ -642,7 +643,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): ) attn = attn_outputs[0] - if output_attentions: + if inputs["output_attentions"]: attentions = attentions + (attn_outputs[1],) attn = self.dropout(attn, training=inputs["training"]) @@ -676,10 +677,6 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): # move back sequence length to dimension 0 # tensor = tensor.transpose(0, 1) - # Set to None here if the output booleans are at False - hidden_states = hidden_states if inputs["output_hidden_states"] else None - attentions = attentions if inputs["output_attentions"] else None - if not inputs["return_dict"]: return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)