From 76296569266fbc238defbe68d06ad720b3a544e3 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 26 Oct 2022 22:41:53 +0200 Subject: [PATCH] `accelerate` support for `RoBERTa` family (#19906) --- .../models/camembert/modeling_camembert.py | 7 ++++++- .../models/data2vec/modeling_data2vec_text.py | 7 ++++++- src/transformers/models/lilt/modeling_lilt.py | 1 + .../models/longformer/modeling_longformer.py | 7 ++++++- src/transformers/models/luke/modeling_luke.py | 13 ++++++++++++- .../models/roberta/modeling_roberta.py | 7 ++++++- .../models/xlm_roberta/modeling_xlm_roberta.py | 7 ++++++- tests/test_modeling_common.py | 18 +++++++++--------- 8 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 58ea054985..514566596c 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -728,7 +728,11 @@ class CamembertLMHead(nn.Module): def _tie_weights(self): # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias @add_start_docstrings( @@ -752,6 +756,7 @@ class CamembertModel(CamembertPreTrainedModel): """ _keys_to_ignore_on_load_missing = [r"position_ids"] + _no_split_modules = [] # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert def __init__(self, config, add_pooling_layer=True): diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 56b78d90be..543e5ee367 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -584,6 +584,7 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): config_class = Data2VecTextConfig base_model_prefix = "data2vec_text" supports_gradient_checkpointing = True + _no_split_modules = [] def _init_weights(self, module): """Initialize the weights""" @@ -1147,7 +1148,11 @@ class Data2VecTextLMHead(nn.Module): def _tie_weights(self): # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias @add_start_docstrings( diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index c78490f4b4..6859aff7e6 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel): config_class = LiltConfig base_model_prefix = "lilt" supports_gradient_checkpointing = True + _no_split_modules = [] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 00cd227a68..7cbeb36f54 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1412,7 +1412,11 @@ class LongformerLMHead(nn.Module): def _tie_weights(self): # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias class LongformerPreTrainedModel(PreTrainedModel): @@ -1425,6 +1429,7 @@ class LongformerPreTrainedModel(PreTrainedModel): base_model_prefix = "longformer" supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [r"position_ids"] + _no_split_modules = ["LongformerSelfAttention"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index befeaccd55..a1f9f3cbd9 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -902,6 +902,7 @@ class LukePreTrainedModel(PreTrainedModel): config_class = LukeConfig base_model_prefix = "luke" supports_gradient_checkpointing = True + _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"] def _init_weights(self, module: nn.Module): """Initialize the weights""" @@ -1264,7 +1265,11 @@ class LukeLMHead(nn.Module): def _tie_weights(self): # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias @add_start_docstrings( @@ -1746,9 +1751,15 @@ class LukeForEntitySpanClassification(LukePreTrainedModel): hidden_size = outputs.last_hidden_state.size(-1) entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size) + if entity_start_positions.device != outputs.last_hidden_state.device: + entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device) start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions) + entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size) + if entity_end_positions.device != outputs.last_hidden_state.device: + entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device) end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions) + feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2) feature_vector = self.dropout(feature_vector) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 4ace886338..0e0f822d41 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True + _no_split_modules = [] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -1146,7 +1147,11 @@ class RobertaLMHead(nn.Module): def _tie_weights(self): # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias @add_start_docstrings( diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 9e5bdfe4fc..5393961508 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): config_class = XLMRobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True + _no_split_modules = [] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -1155,7 +1156,11 @@ class XLMRobertaLMHead(nn.Module): def _tie_weights(self): # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias @add_start_docstrings( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6ff31a4de8..5d299f5c83 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2312,11 +2312,11 @@ class ModelTesterMixin: if model_class._no_split_modules is None: continue - inputs_dict = self._prepare_for_class(inputs_dict, model_class) + inputs_dict_class = self._prepare_for_class(inputs_dict, model_class) model = model_class(config).eval() model = model.to(torch_device) torch.manual_seed(0) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict_class) model_size = compute_module_sizes(model)[""] max_size = int(self.model_split_percents[0] * model_size) @@ -2334,7 +2334,7 @@ class ModelTesterMixin: self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) - new_output = new_model(**inputs_dict) + new_output = new_model(**inputs_dict_class) self.assertTrue(torch.allclose(base_output[0], new_output[0])) @@ -2347,12 +2347,12 @@ class ModelTesterMixin: if model_class._no_split_modules is None: continue - inputs_dict = self._prepare_for_class(inputs_dict, model_class) + inputs_dict_class = self._prepare_for_class(inputs_dict, model_class) model = model_class(config).eval() model = model.to(torch_device) torch.manual_seed(0) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict_class) model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works. @@ -2369,7 +2369,7 @@ class ModelTesterMixin: self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) - new_output = new_model(**inputs_dict) + new_output = new_model(**inputs_dict_class) self.assertTrue(torch.allclose(base_output[0], new_output[0])) @@ -2382,12 +2382,12 @@ class ModelTesterMixin: if model_class._no_split_modules is None: continue - inputs_dict = self._prepare_for_class(inputs_dict, model_class) + inputs_dict_class = self._prepare_for_class(inputs_dict, model_class) model = model_class(config).eval() model = model.to(torch_device) torch.manual_seed(0) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict_class) model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works. @@ -2404,7 +2404,7 @@ class ModelTesterMixin: self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) - new_output = new_model(**inputs_dict) + new_output = new_model(**inputs_dict_class) self.assertTrue(torch.allclose(base_output[0], new_output[0]))