Fix TF DPR (#9283)
* Fix DPR * Keep usual models * Apply style * Address Sylvain's comments
This commit is contained in:
@@ -144,18 +144,18 @@ class TFDPRReaderOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
class TFDPREncoder(TFPreTrainedModel):
|
||||
class TFDPREncoderLayer(tf.keras.layers.Layer):
|
||||
|
||||
base_model_prefix = "bert_model"
|
||||
|
||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
def __init__(self, config: DPRConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# resolve name conflict with TFBertMainLayer instead of TFBertModel
|
||||
self.bert_model = TFBertMainLayer(config, name="bert_model")
|
||||
self.bert_model.config = config
|
||||
self.config = config
|
||||
|
||||
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
|
||||
assert self.config.hidden_size > 0, "Encoder hidden_size can't be zero"
|
||||
self.projection_dim = config.projection_dim
|
||||
if self.projection_dim > 0:
|
||||
self.encode_proj = tf.keras.layers.Dense(
|
||||
@@ -220,13 +220,14 @@ class TFDPREncoder(TFPreTrainedModel):
|
||||
return self.bert_model.config.hidden_size
|
||||
|
||||
|
||||
class TFDPRSpanPredictor(TFPreTrainedModel):
|
||||
class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
|
||||
|
||||
base_model_prefix = "encoder"
|
||||
|
||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.encoder = TFDPREncoder(config, name="encoder")
|
||||
def __init__(self, config: DPRConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.encoder = TFDPREncoderLayer(config, name="encoder")
|
||||
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
@@ -299,6 +300,97 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class TFDPRSpanPredictor(TFPreTrainedModel):
|
||||
|
||||
base_model_prefix = "encoder"
|
||||
|
||||
def __init__(self, config: DPRConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.encoder = TFDPRSpanPredictorLayer(config)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids: tf.Tensor = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = False,
|
||||
training: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFDPREncoder(TFPreTrainedModel):
|
||||
base_model_prefix = "encoder"
|
||||
|
||||
def __init__(self, config: DPRConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.encoder = TFDPREncoderLayer(config)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids: tf.Tensor = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = False,
|
||||
training: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
##################
|
||||
# PreTrainedModel
|
||||
##################
|
||||
@@ -465,8 +557,7 @@ TF_DPR_READER_INPUTS_DOCSTRING = r"""
|
||||
class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.config = config
|
||||
self.ctx_encoder = TFDPREncoder(config, name="ctx_encoder")
|
||||
self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.ctx_encoder.bert_model.get_input_embeddings()
|
||||
@@ -541,6 +632,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return outputs[1:]
|
||||
|
||||
return TFDPRContextEncoderOutput(
|
||||
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
@@ -553,8 +645,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||
class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.config = config
|
||||
self.question_encoder = TFDPREncoder(config, name="question_encoder")
|
||||
self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.question_encoder.bert_model.get_input_embeddings()
|
||||
@@ -641,8 +732,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
||||
class TFDPRReader(TFDPRPretrainedReader):
|
||||
def __init__(self, config: DPRConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.config = config
|
||||
self.span_predictor = TFDPRSpanPredictor(config, name="span_predictor")
|
||||
self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
||||
|
||||
Reference in New Issue
Block a user