Tf longformer for sequence classification (#8231)
* working on LongformerForSequenceClassification * add TFLongformerForMultipleChoice * add TFLongformerForTokenClassification * use add_start_docstrings_to_model_forward * test TFLongformerForSequenceClassification * test TFLongformerForMultipleChoice * test TFLongformerForTokenClassification * remove test from repo * add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice * add requested classes to modeling_tf_auto.py update dummy_tf_objects fix tests fix bugs in requested classes * pass all tests except test_inputs_embeds * sync with master * pass all tests except test_inputs_embeds * pass all tests * pass all tests * work on test_inputs_embeds * fix style and quality * make multi choice work * fix TFLongformerForTokenClassification signature * fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature * fix mult choice * fix mc hint * fix input embeds * fix input embeds * refactor input embeds * fix copy issue * apply sylvains changes and clean more Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, features):
|
||||
x = self.dense(features)
|
||||
x = self.act(x)
|
||||
x = self.layer_norm(x)
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
# project back to size of vocabulary with bias
|
||||
x = self.decoder(x, mode="linear") + self.bias
|
||||
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
|
||||
|
||||
return x
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||
|
||||
Reference in New Issue
Block a user