From ec54d70e162f2f081f17621c5344df165559bedd Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 4 Jan 2021 17:26:56 +0100 Subject: [PATCH] Fix TF DPR (#9283) * Fix DPR * Keep usual models * Apply style * Address Sylvain's comments --- .../models/dpr/modeling_tf_dpr.py | 120 +++++++++++++++--- 1 file changed, 105 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/dpr/modeling_tf_dpr.py b/src/transformers/models/dpr/modeling_tf_dpr.py index 68c0a02f07..03033d7792 100644 --- a/src/transformers/models/dpr/modeling_tf_dpr.py +++ b/src/transformers/models/dpr/modeling_tf_dpr.py @@ -144,18 +144,18 @@ class TFDPRReaderOutput(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None -class TFDPREncoder(TFPreTrainedModel): +class TFDPREncoderLayer(tf.keras.layers.Layer): base_model_prefix = "bert_model" - def __init__(self, config: DPRConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(**kwargs) # resolve name conflict with TFBertMainLayer instead of TFBertModel self.bert_model = TFBertMainLayer(config, name="bert_model") - self.bert_model.config = config + self.config = config - assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero" + assert self.config.hidden_size > 0, "Encoder hidden_size can't be zero" self.projection_dim = config.projection_dim if self.projection_dim > 0: self.encode_proj = tf.keras.layers.Dense( @@ -220,13 +220,14 @@ class TFDPREncoder(TFPreTrainedModel): return self.bert_model.config.hidden_size -class TFDPRSpanPredictor(TFPreTrainedModel): +class TFDPRSpanPredictorLayer(tf.keras.layers.Layer): base_model_prefix = "encoder" - def __init__(self, config: DPRConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.encoder = TFDPREncoder(config, name="encoder") + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.encoder = TFDPREncoderLayer(config, name="encoder") self.qa_outputs = tf.keras.layers.Dense( 2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" @@ -299,6 +300,97 @@ class TFDPRSpanPredictor(TFPreTrainedModel): ) +class TFDPRSpanPredictor(TFPreTrainedModel): + + base_model_prefix = "encoder" + + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(config, **kwargs) + self.encoder = TFDPRSpanPredictorLayer(config) + + def call( + self, + input_ids: tf.Tensor = None, + attention_mask: Optional[tf.Tensor] = None, + token_type_ids: Optional[tf.Tensor] = None, + inputs_embeds: Optional[tf.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + training: bool = False, + **kwargs, + ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]: + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + outputs = self.encoder( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + inputs_embeds=inputs["inputs_embeds"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + + return outputs + + +class TFDPREncoder(TFPreTrainedModel): + base_model_prefix = "encoder" + + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(config, **kwargs) + + self.encoder = TFDPREncoderLayer(config) + + def call( + self, + input_ids: tf.Tensor = None, + attention_mask: Optional[tf.Tensor] = None, + token_type_ids: Optional[tf.Tensor] = None, + inputs_embeds: Optional[tf.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + training: bool = False, + **kwargs, + ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]: + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + outputs = self.encoder( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + inputs_embeds=inputs["inputs_embeds"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + return outputs + + ################## # PreTrainedModel ################## @@ -465,8 +557,7 @@ TF_DPR_READER_INPUTS_DOCSTRING = r""" class TFDPRContextEncoder(TFDPRPretrainedContextEncoder): def __init__(self, config: DPRConfig, *args, **kwargs): super().__init__(config, *args, **kwargs) - self.config = config - self.ctx_encoder = TFDPREncoder(config, name="ctx_encoder") + self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder") def get_input_embeddings(self): return self.ctx_encoder.bert_model.get_input_embeddings() @@ -541,6 +632,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder): if not inputs["return_dict"]: return outputs[1:] + return TFDPRContextEncoderOutput( pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions ) @@ -553,8 +645,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder): class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder): def __init__(self, config: DPRConfig, *args, **kwargs): super().__init__(config, *args, **kwargs) - self.config = config - self.question_encoder = TFDPREncoder(config, name="question_encoder") + self.question_encoder = TFDPREncoderLayer(config, name="question_encoder") def get_input_embeddings(self): return self.question_encoder.bert_model.get_input_embeddings() @@ -641,8 +732,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder): class TFDPRReader(TFDPRPretrainedReader): def __init__(self, config: DPRConfig, *args, **kwargs): super().__init__(config, *args, **kwargs) - self.config = config - self.span_predictor = TFDPRSpanPredictor(config, name="span_predictor") + self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor") def get_input_embeddings(self): return self.span_predictor.encoder.bert_model.get_input_embeddings()