Merge remote-tracking branch 'huggingface/master' into run_multiple_choice_merge
# Conflicts: # examples/contrib/run_swag.py
This commit is contained in:
@@ -527,7 +527,7 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertModel(BertPreTrainedModel):
|
||||
r"""
|
||||
|
||||
@@ -394,7 +394,7 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
|
||||
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
|
||||
class DistilBertModel(DistilBertPreTrainedModel):
|
||||
r"""
|
||||
|
||||
@@ -314,7 +314,7 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
|
||||
class GPT2Model(GPT2PreTrainedModel):
|
||||
r"""
|
||||
|
||||
@@ -324,7 +324,7 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare OpenAI GPT transformer model outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
|
||||
OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
r"""
|
||||
|
||||
@@ -124,7 +124,7 @@ ROBERTA_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
|
||||
class RobertaModel(BertModel):
|
||||
r"""
|
||||
|
||||
@@ -820,7 +820,7 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
|
||||
class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
r"""
|
||||
|
||||
@@ -313,7 +313,7 @@ XLM_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMModel(XLMPreTrainedModel):
|
||||
r"""
|
||||
|
||||
@@ -546,7 +546,7 @@ XLNET_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
class XLNetModel(XLNetPreTrainedModel):
|
||||
r"""
|
||||
@@ -743,8 +743,9 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
|
||||
if data_mask is not None:
|
||||
# all mems can be attended to
|
||||
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
|
||||
data_mask = torch.cat([mems_mask, data_mask], dim=1)
|
||||
if mlen > 0:
|
||||
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
|
||||
data_mask = torch.cat([mems_mask, data_mask], dim=1)
|
||||
if attn_mask is None:
|
||||
attn_mask = data_mask[:, :, :, None]
|
||||
else:
|
||||
@@ -755,7 +756,8 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
|
||||
if attn_mask is not None:
|
||||
non_tgt_mask = -torch.eye(qlen).to(attn_mask)
|
||||
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
|
||||
if mlen > 0:
|
||||
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
|
||||
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
|
||||
else:
|
||||
non_tgt_mask = None
|
||||
@@ -775,8 +777,11 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
##### Segment embedding
|
||||
if token_type_ids is not None:
|
||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
|
||||
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
|
||||
if mlen > 0:
|
||||
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
|
||||
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
|
||||
else:
|
||||
cat_ids = token_type_ids
|
||||
|
||||
# `1` indicates not in the same segment [qlen x klen x bsz]
|
||||
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
|
||||
|
||||
Reference in New Issue
Block a user