Fix TF DPR (#9283)
* Fix DPR
* Keep usual models
* Apply style
* Address Sylvain's comments
This commit is contained in:
@@ -144,18 +144,18 @@ class TFDPRReaderOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
class TFDPREncoder(TFPreTrainedModel):
|
class TFDPREncoderLayer(tf.keras.layers.Layer):
|
||||||
|
|
||||||
base_model_prefix = "bert_model"
|
base_model_prefix = "bert_model"
|
||||||
|
|
||||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
def __init__(self, config: DPRConfig, **kwargs):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# resolve name conflict with TFBertMainLayer instead of TFBertModel
|
# resolve name conflict with TFBertMainLayer instead of TFBertModel
|
||||||
self.bert_model = TFBertMainLayer(config, name="bert_model")
|
self.bert_model = TFBertMainLayer(config, name="bert_model")
|
||||||
self.bert_model.config = config
|
self.config = config
|
||||||
|
|
||||||
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
|
assert self.config.hidden_size > 0, "Encoder hidden_size can't be zero"
|
||||||
self.projection_dim = config.projection_dim
|
self.projection_dim = config.projection_dim
|
||||||
if self.projection_dim > 0:
|
if self.projection_dim > 0:
|
||||||
self.encode_proj = tf.keras.layers.Dense(
|
self.encode_proj = tf.keras.layers.Dense(
|
||||||
@@ -220,13 +220,14 @@ class TFDPREncoder(TFPreTrainedModel):
|
|||||||
return self.bert_model.config.hidden_size
|
return self.bert_model.config.hidden_size
|
||||||
|
|
||||||
|
|
||||||
class TFDPRSpanPredictor(TFPreTrainedModel):
|
class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
|
||||||
|
|
||||||
base_model_prefix = "encoder"
|
base_model_prefix = "encoder"
|
||||||
|
|
||||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
def __init__(self, config: DPRConfig, **kwargs):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(**kwargs)
|
||||||
self.encoder = TFDPREncoder(config, name="encoder")
|
self.config = config
|
||||||
|
self.encoder = TFDPREncoderLayer(config, name="encoder")
|
||||||
|
|
||||||
self.qa_outputs = tf.keras.layers.Dense(
|
self.qa_outputs = tf.keras.layers.Dense(
|
||||||
2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||||
@@ -299,6 +300,97 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFDPRSpanPredictor(TFPreTrainedModel):
    """``TFPreTrainedModel`` wrapper around a ``TFDPRSpanPredictorLayer``.

    The wrapped layer is stored on ``self.encoder``; ``call`` normalises its
    arguments with ``input_processing`` and delegates to that layer.
    """

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig, **kwargs):
        """Initialise the parent model and build the wrapped span-predictor layer.

        Args:
            config: DPR configuration handed to both the parent constructor
                and the wrapped layer.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, **kwargs)
        self.encoder = TFDPRSpanPredictorLayer(config)

    def call(
        self,
        input_ids: tf.Tensor = None,
        attention_mask: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
        """Normalise the inputs and run the wrapped span-predictor layer.

        Every argument is passed through ``input_processing``; the processed
        values are then handed to ``self.encoder`` by keyword.

        Returns:
            Whatever ``self.encoder`` returns for the processed inputs.
        """
        # self.config is not assigned in this class -- presumably set by the
        # TFPreTrainedModel constructor; verify against the base class.
        processed = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
            kwargs_call=kwargs,
        )

        # Only these processed entries reach the layer; token_type_ids is
        # accepted by this signature but is not forwarded.
        forwarded_keys = (
            "input_ids",
            "attention_mask",
            "inputs_embeds",
            "output_attentions",
            "output_hidden_states",
            "return_dict",
            "training",
        )
        layer_kwargs = {key: processed[key] for key in forwarded_keys}

        return self.encoder(**layer_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class TFDPREncoder(TFPreTrainedModel):
    """``TFPreTrainedModel`` wrapper around a ``TFDPREncoderLayer``.

    The wrapped layer is stored on ``self.encoder``; ``call`` normalises its
    arguments with ``input_processing`` and delegates to that layer.
    """

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig, **kwargs):
        """Initialise the parent model and build the wrapped encoder layer.

        Args:
            config: DPR configuration handed to both the parent constructor
                and the wrapped layer.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, **kwargs)

        self.encoder = TFDPREncoderLayer(config)

    # NOTE(review): the return annotation names TFDPRReaderOutput even though
    # this class wraps an encoder layer -- confirm the intended output type.
    def call(
        self,
        input_ids: tf.Tensor = None,
        attention_mask: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
        """Normalise the inputs and run the wrapped encoder layer.

        Every argument is passed through ``input_processing``; the processed
        values are then handed to ``self.encoder`` by keyword.

        Returns:
            Whatever ``self.encoder`` returns for the processed inputs.
        """
        # self.config is not assigned in this class -- presumably set by the
        # TFPreTrainedModel constructor; verify against the base class.
        processed = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
            kwargs_call=kwargs,
        )

        # Only these processed entries reach the layer; token_type_ids is
        # accepted by this signature but is not forwarded.
        forwarded_keys = (
            "input_ids",
            "attention_mask",
            "inputs_embeds",
            "output_attentions",
            "output_hidden_states",
            "return_dict",
            "training",
        )
        layer_kwargs = {key: processed[key] for key in forwarded_keys}

        return self.encoder(**layer_kwargs)
|
||||||
|
|
||||||
|
|
||||||
##################
|
##################
|
||||||
# PreTrainedModel
|
# PreTrainedModel
|
||||||
##################
|
##################
|
||||||
@@ -465,8 +557,7 @@ TF_DPR_READER_INPUTS_DOCSTRING = r"""
|
|||||||
class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
self.config = config
|
self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")
|
||||||
self.ctx_encoder = TFDPREncoder(config, name="ctx_encoder")
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.ctx_encoder.bert_model.get_input_embeddings()
|
return self.ctx_encoder.bert_model.get_input_embeddings()
|
||||||
@@ -541,6 +632,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
|||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return outputs[1:]
|
return outputs[1:]
|
||||||
|
|
||||||
return TFDPRContextEncoderOutput(
|
return TFDPRContextEncoderOutput(
|
||||||
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||||
)
|
)
|
||||||
@@ -553,8 +645,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
|||||||
class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
||||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
self.config = config
|
self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")
|
||||||
self.question_encoder = TFDPREncoder(config, name="question_encoder")
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.question_encoder.bert_model.get_input_embeddings()
|
return self.question_encoder.bert_model.get_input_embeddings()
|
||||||
@@ -641,8 +732,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
|||||||
class TFDPRReader(TFDPRPretrainedReader):
|
class TFDPRReader(TFDPRPretrainedReader):
|
||||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
self.config = config
|
self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")
|
||||||
self.span_predictor = TFDPRSpanPredictor(config, name="span_predictor")
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
||||||
|
|||||||
Reference in New Issue
Block a user