accelerate support for RoBERTa family (#19906)

This commit is contained in:
Younes Belkada
2022-10-26 22:41:53 +02:00
committed by GitHub
parent 6d023270f6
commit 7629656926
8 changed files with 52 additions and 15 deletions

View File

@@ -728,7 +728,11 @@ class CamembertLMHead(nn.Module):
def _tie_weights(self): def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized) # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias # For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
@add_start_docstrings( @add_start_docstrings(
@@ -752,6 +756,7 @@ class CamembertModel(CamembertPreTrainedModel):
""" """
_keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
def __init__(self, config, add_pooling_layer=True): def __init__(self, config, add_pooling_layer=True):

View File

@@ -584,6 +584,7 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
config_class = Data2VecTextConfig config_class = Data2VecTextConfig
base_model_prefix = "data2vec_text" base_model_prefix = "data2vec_text"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
@@ -1147,7 +1148,11 @@ class Data2VecTextLMHead(nn.Module):
def _tie_weights(self): def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized) # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias # For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
@add_start_docstrings( @add_start_docstrings(

View File

@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
config_class = LiltConfig config_class = LiltConfig
base_model_prefix = "lilt" base_model_prefix = "lilt"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):

View File

@@ -1412,7 +1412,11 @@ class LongformerLMHead(nn.Module):
def _tie_weights(self): def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized) # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias # For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
class LongformerPreTrainedModel(PreTrainedModel): class LongformerPreTrainedModel(PreTrainedModel):
@@ -1425,6 +1429,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
base_model_prefix = "longformer" base_model_prefix = "longformer"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"position_ids"] _keys_to_ignore_on_load_unexpected = [r"position_ids"]
_no_split_modules = ["LongformerSelfAttention"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -902,6 +902,7 @@ class LukePreTrainedModel(PreTrainedModel):
config_class = LukeConfig config_class = LukeConfig
base_model_prefix = "luke" base_model_prefix = "luke"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
def _init_weights(self, module: nn.Module): def _init_weights(self, module: nn.Module):
"""Initialize the weights""" """Initialize the weights"""
@@ -1264,7 +1265,11 @@ class LukeLMHead(nn.Module):
def _tie_weights(self): def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized) # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias # For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
@add_start_docstrings( @add_start_docstrings(
@@ -1746,9 +1751,15 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
hidden_size = outputs.last_hidden_state.size(-1) hidden_size = outputs.last_hidden_state.size(-1)
entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size) entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
if entity_start_positions.device != outputs.last_hidden_state.device:
entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions) start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size) entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
if entity_end_positions.device != outputs.last_hidden_state.device:
entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions) end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2) feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
feature_vector = self.dropout(feature_vector) feature_vector = self.dropout(feature_vector)

View File

@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
config_class = RobertaConfig config_class = RobertaConfig
base_model_prefix = "roberta" base_model_prefix = "roberta"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
@@ -1146,7 +1147,11 @@ class RobertaLMHead(nn.Module):
def _tie_weights(self): def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized) # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias # For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
@add_start_docstrings( @add_start_docstrings(

View File

@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
config_class = XLMRobertaConfig config_class = XLMRobertaConfig
base_model_prefix = "roberta" base_model_prefix = "roberta"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
@@ -1155,7 +1156,11 @@ class XLMRobertaLMHead(nn.Module):
def _tie_weights(self): def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized) # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias # For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
@add_start_docstrings( @add_start_docstrings(

View File

@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
if model_class._no_split_modules is None: if model_class._no_split_modules is None:
continue continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class) inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval() model = model_class(config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size) max_size = int(self.model_split_percents[0] * model_size)
@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0])) self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
if model_class._no_split_modules is None: if model_class._no_split_modules is None:
continue continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class) inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval() model = model_class(config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works. # We test several splits of sizes to make sure it works.
@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0])) self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
if model_class._no_split_modules is None: if model_class._no_split_modules is None:
continue continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class) inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval() model = model_class(config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works. # We test several splits of sizes to make sure it works.
@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0])) self.assertTrue(torch.allclose(base_output[0], new_output[0]))