accelerate support for RoBERTa family (#19906)
This commit is contained in:
@@ -728,7 +728,11 @@ class CamembertLMHead(nn.Module):
|
|||||||
|
|
||||||
def _tie_weights(self):
|
def _tie_weights(self):
|
||||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
self.bias = self.decoder.bias
|
# For accelerate compatibility and to not break backward compatibility
|
||||||
|
if self.decoder.bias.device.type == "meta":
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
else:
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -752,6 +756,7 @@ class CamembertModel(CamembertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
|
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
|
||||||
def __init__(self, config, add_pooling_layer=True):
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
|
|||||||
@@ -584,6 +584,7 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = Data2VecTextConfig
|
config_class = Data2VecTextConfig
|
||||||
base_model_prefix = "data2vec_text"
|
base_model_prefix = "data2vec_text"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
@@ -1147,7 +1148,11 @@ class Data2VecTextLMHead(nn.Module):
|
|||||||
|
|
||||||
def _tie_weights(self):
|
def _tie_weights(self):
|
||||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
self.bias = self.decoder.bias
|
# For accelerate compatibility and to not break backward compatibility
|
||||||
|
if self.decoder.bias.device.type == "meta":
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
else:
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = LiltConfig
|
config_class = LiltConfig
|
||||||
base_model_prefix = "lilt"
|
base_model_prefix = "lilt"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
|
|||||||
@@ -1412,7 +1412,11 @@ class LongformerLMHead(nn.Module):
|
|||||||
|
|
||||||
def _tie_weights(self):
|
def _tie_weights(self):
|
||||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
self.bias = self.decoder.bias
|
# For accelerate compatibility and to not break backward compatibility
|
||||||
|
if self.decoder.bias.device.type == "meta":
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
else:
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
class LongformerPreTrainedModel(PreTrainedModel):
|
class LongformerPreTrainedModel(PreTrainedModel):
|
||||||
@@ -1425,6 +1429,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "longformer"
|
base_model_prefix = "longformer"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_keys_to_ignore_on_load_unexpected = [r"position_ids"]
|
_keys_to_ignore_on_load_unexpected = [r"position_ids"]
|
||||||
|
_no_split_modules = ["LongformerSelfAttention"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -902,6 +902,7 @@ class LukePreTrainedModel(PreTrainedModel):
|
|||||||
config_class = LukeConfig
|
config_class = LukeConfig
|
||||||
base_model_prefix = "luke"
|
base_model_prefix = "luke"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
|
||||||
|
|
||||||
def _init_weights(self, module: nn.Module):
|
def _init_weights(self, module: nn.Module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
@@ -1264,7 +1265,11 @@ class LukeLMHead(nn.Module):
|
|||||||
|
|
||||||
def _tie_weights(self):
|
def _tie_weights(self):
|
||||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
self.bias = self.decoder.bias
|
# For accelerate compatibility and to not break backward compatibility
|
||||||
|
if self.decoder.bias.device.type == "meta":
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
else:
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -1746,9 +1751,15 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
|
|||||||
hidden_size = outputs.last_hidden_state.size(-1)
|
hidden_size = outputs.last_hidden_state.size(-1)
|
||||||
|
|
||||||
entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
|
entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
|
||||||
|
if entity_start_positions.device != outputs.last_hidden_state.device:
|
||||||
|
entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
|
||||||
start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
|
start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
|
||||||
|
|
||||||
entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
|
entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
|
||||||
|
if entity_end_positions.device != outputs.last_hidden_state.device:
|
||||||
|
entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
|
||||||
end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
|
end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
|
||||||
|
|
||||||
feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
|
feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
|
||||||
|
|
||||||
feature_vector = self.dropout(feature_vector)
|
feature_vector = self.dropout(feature_vector)
|
||||||
|
|||||||
@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = RobertaConfig
|
config_class = RobertaConfig
|
||||||
base_model_prefix = "roberta"
|
base_model_prefix = "roberta"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
@@ -1146,7 +1147,11 @@ class RobertaLMHead(nn.Module):
|
|||||||
|
|
||||||
def _tie_weights(self):
|
def _tie_weights(self):
|
||||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
self.bias = self.decoder.bias
|
# For accelerate compatibility and to not break backward compatibility
|
||||||
|
if self.decoder.bias.device.type == "meta":
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
else:
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = XLMRobertaConfig
|
config_class = XLMRobertaConfig
|
||||||
base_model_prefix = "roberta"
|
base_model_prefix = "roberta"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
@@ -1155,7 +1156,11 @@ class XLMRobertaLMHead(nn.Module):
|
|||||||
|
|
||||||
def _tie_weights(self):
|
def _tie_weights(self):
|
||||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
self.bias = self.decoder.bias
|
# For accelerate compatibility and to not break backward compatibility
|
||||||
|
if self.decoder.bias.device.type == "meta":
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
else:
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
|
|||||||
if model_class._no_split_modules is None:
|
if model_class._no_split_modules is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config).eval()
|
model = model_class(config).eval()
|
||||||
model = model.to(torch_device)
|
model = model.to(torch_device)
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
base_output = model(**inputs_dict)
|
base_output = model(**inputs_dict_class)
|
||||||
|
|
||||||
model_size = compute_module_sizes(model)[""]
|
model_size = compute_module_sizes(model)[""]
|
||||||
max_size = int(self.model_split_percents[0] * model_size)
|
max_size = int(self.model_split_percents[0] * model_size)
|
||||||
@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
new_output = new_model(**inputs_dict)
|
new_output = new_model(**inputs_dict_class)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||||
|
|
||||||
@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
|
|||||||
if model_class._no_split_modules is None:
|
if model_class._no_split_modules is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config).eval()
|
model = model_class(config).eval()
|
||||||
model = model.to(torch_device)
|
model = model.to(torch_device)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
base_output = model(**inputs_dict)
|
base_output = model(**inputs_dict_class)
|
||||||
|
|
||||||
model_size = compute_module_sizes(model)[""]
|
model_size = compute_module_sizes(model)[""]
|
||||||
# We test several splits of sizes to make sure it works.
|
# We test several splits of sizes to make sure it works.
|
||||||
@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
|
|||||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
new_output = new_model(**inputs_dict)
|
new_output = new_model(**inputs_dict_class)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||||
|
|
||||||
@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
|
|||||||
if model_class._no_split_modules is None:
|
if model_class._no_split_modules is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config).eval()
|
model = model_class(config).eval()
|
||||||
model = model.to(torch_device)
|
model = model.to(torch_device)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
base_output = model(**inputs_dict)
|
base_output = model(**inputs_dict_class)
|
||||||
|
|
||||||
model_size = compute_module_sizes(model)[""]
|
model_size = compute_module_sizes(model)[""]
|
||||||
# We test several splits of sizes to make sure it works.
|
# We test several splits of sizes to make sure it works.
|
||||||
@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
|
|||||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
new_output = new_model(**inputs_dict)
|
new_output = new_model(**inputs_dict_class)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user