ProphetNet (#7157)
* add new model prophetnet prophetnet modified modify codes as suggested v1 add prophetnet test files * still bugs, because of changed output formats of encoder and decoder * move prophetnet into the latest version * clean integration tests * clean tokenizers * add xlm config to init * correct typo in init * further refactoring * continue refactor * save parallel * add decoder_attention_mask * fix use_cache vs. past_key_values * fix common tests * change decoder output logits * fix xlm tests * make common tests pass * change model architecture * add tokenizer tests * finalize model structure * no weight mapping * correct n-gram stream attention mask as discussed with qweizhen * remove unused import * fix index.rst * fix tests * delete unnecessary code * add fast integration test * rename weights * final weight remapping * save intermediate * Descriptions for Prophetnet Config File * finish all models * finish new model outputs * delete unnecessary files * refactor encoder layer * add dummy docs * code quality * fix tests * add model pages to doctree * further refactor * more refactor, more tests * finish code refactor and tests * remove unnecessary files * further clean up * add docstring template * finish tokenizer doc * finish prophetnet * fix copies * fix typos * fix tf tests * fix fp16 * fix tf test 2nd try * fix code quality * add test for each model * merge new tests to branch * Update model_cards/microsoft/prophetnet-large-uncased-cnndm/README.md Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update model_cards/microsoft/prophetnet-large-uncased-cnndm/README.md Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update src/transformers/modeling_prophetnet.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update utils/check_repo.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * apply sams and sylvains comments * make style * remove unnecessary code * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/configuration_prophetnet.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * implement lysandres comments * correct docs * fix isort * fix tokenizers * fix copies Co-authored-by: weizhen <weizhen@mail.ustc.edu.cn> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -490,12 +490,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
|
||||
if hasattr(self, self.base_model_prefix):
|
||||
self = getattr(self, self.base_model_prefix)
|
||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
||||
|
||||
@staticmethod
|
||||
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
||||
uninitialized_encoder_weights: List[str] = []
|
||||
assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
|
||||
if decoder.__class__ != encoder.__class__:
|
||||
logger.info(
|
||||
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
|
||||
)
|
||||
|
||||
def tie_encoder_to_decoder_recursively(
|
||||
decoder_pointer: nn.Module,
|
||||
@@ -528,7 +533,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if name.isdigit():
|
||||
encoder_name = str(int(name) + encoder_layer_pos)
|
||||
decoder_name = name
|
||||
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
|
||||
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
|
||||
encoder_modules
|
||||
) != len(decoder_modules):
|
||||
# this can happen if the name corresponds to the position in a list module list of layers
|
||||
# in this case the decoder has added a cross-attention that the encoder does not have
|
||||
# thus skip this step and substract one layer pos from encoder
|
||||
|
||||
Reference in New Issue
Block a user