fix conflicts
This commit is contained in:
@@ -81,6 +81,7 @@ class PretrainedConfig(object):
|
|||||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||||
self.eos_token_ids = kwargs.pop("eos_token_ids", None)
|
self.eos_token_ids = kwargs.pop("eos_token_ids", None)
|
||||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||||
|
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import ipdb
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
if getattr(output_embeddings, "bias", None) is not None:
|
if getattr(output_embeddings, "bias", None) is not None:
|
||||||
output_embeddings.bias.data = torch.nn.functional.pad(
|
output_embeddings.bias.data = torch.nn.functional.pad(
|
||||||
output_embeddings.bias.data,
|
output_embeddings.bias.data,
|
||||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
|
||||||
"constant",
|
"constant",
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
@@ -411,7 +411,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
"Error no file named {} found in directory {} or `from_tf` set to False".format(
|
"Error no file named {} found in directory {} or `from_tf` set to False".format(
|
||||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], pretrained_model_name_or_path
|
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index",],
|
||||||
|
pretrained_model_name_or_path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
@@ -425,7 +426,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
archive_file = hf_bucket_url(
|
archive_file = hf_bucket_url(
|
||||||
pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME)
|
pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
|
||||||
)
|
)
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
@@ -520,7 +521,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
def load(module: nn.Module, prefix=""):
|
def load(module: nn.Module, prefix=""):
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
module._load_from_state_dict(
|
module._load_from_state_dict(
|
||||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
|
||||||
)
|
)
|
||||||
for name, child in module._modules.items():
|
for name, child in module._modules.items():
|
||||||
if child is not None:
|
if child is not None:
|
||||||
@@ -620,6 +621,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
pad_token_id=None,
|
pad_token_id=None,
|
||||||
eos_token_ids=None,
|
eos_token_ids=None,
|
||||||
length_penalty=None,
|
length_penalty=None,
|
||||||
|
no_repeat_ngram_size=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
):
|
):
|
||||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||||
@@ -725,6 +727,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||||
|
no_repeat_ngram_size = (
|
||||||
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||||
|
)
|
||||||
num_return_sequences = (
|
num_return_sequences = (
|
||||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||||
)
|
)
|
||||||
@@ -754,6 +759,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||||
|
assert (
|
||||||
|
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
||||||
|
), "`no_repeat_ngram_size` should be a positive integer."
|
||||||
assert (
|
assert (
|
||||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||||
), "`num_return_sequences` should be a strictly positive integer."
|
), "`num_return_sequences` should be a strictly positive integer."
|
||||||
@@ -764,7 +772,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
|
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
|
||||||
)
|
)
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device
|
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
||||||
@@ -811,23 +819,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# TODO (PVP): check eos_token_id
|
# TODO (PVP): check eos_token_id
|
||||||
# TODO (PVP): probably not the best way to check whether model is encoder decoder
|
# TODO (PVP): probably not the best way to check whether model is encoder decoder
|
||||||
is_encoder_decoder = (
|
is_encoder_decoder = (
|
||||||
hasattr(self, "model")
|
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
|
||||||
and hasattr(self.model, "decoder")
|
|
||||||
and hasattr(self.model, "encoder")
|
|
||||||
)
|
)
|
||||||
if is_encoder_decoder:
|
if is_encoder_decoder:
|
||||||
eos_token_id = eos_token_ids[0]
|
eos_token_id = eos_token_ids[0]
|
||||||
assert (
|
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
||||||
bos_token_id is not None
|
assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
|
||||||
), "Encoder Decoder Models need to have a bos_token_id"
|
|
||||||
assert (
|
|
||||||
eos_token_id is not None
|
|
||||||
), "Encoder Decoder Models need to have a eos_token_id"
|
|
||||||
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
|
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
|
||||||
encoder_inputs = input_ids
|
encoder_inputs = input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no?
|
# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no?
|
||||||
bos_token_id,
|
bos_token_id,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
@@ -849,6 +851,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
effective_batch_size,
|
effective_batch_size,
|
||||||
@@ -869,6 +872,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
effective_batch_size,
|
effective_batch_size,
|
||||||
@@ -888,6 +892,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -902,9 +907,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
past = None
|
past = None
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs)
|
||||||
input_ids, past=past, encoder_inputs=encoder_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
@@ -917,9 +920,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
||||||
|
|
||||||
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
|
if no_repeat_ngram_size > 0:
|
||||||
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
next_token_logits[
|
||||||
|
batch_idx, banned_tokens[batch_idx]
|
||||||
|
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||||
|
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
for eos_token_id in eos_token_ids:
|
||||||
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
next_token_logits[
|
||||||
|
:, eos_token_id
|
||||||
|
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
@@ -981,6 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -993,9 +1008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
""" Generate sequences for each example with beam search.
|
""" Generate sequences for each example with beam search.
|
||||||
"""
|
"""
|
||||||
is_encoder_decoder = (
|
is_encoder_decoder = (
|
||||||
hasattr(self, "model")
|
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
|
||||||
and hasattr(self.model, "decoder")
|
|
||||||
and hasattr(self.model, "encoder")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# generated hypotheses
|
# generated hypotheses
|
||||||
@@ -1017,9 +1030,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs)
|
||||||
input_ids, past=past, encoder_inputs=encoder_inputs
|
|
||||||
)
|
|
||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -1030,12 +1041,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
self.enforce_repetition_penalty_(
|
self.enforce_repetition_penalty_(
|
||||||
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
|
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
|
if no_repeat_ngram_size > 0:
|
||||||
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
next_token_logits[
|
||||||
|
batch_idx, banned_tokens[batch_idx]
|
||||||
|
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||||
|
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
for eos_token_id in eos_token_ids:
|
||||||
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
next_token_logits[
|
||||||
|
:, eos_token_id
|
||||||
|
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
@@ -1070,14 +1092,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# do greedy beam search
|
# do greedy beam search
|
||||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
if is_encoder_decoder: # TODO(PVP) to be refactored later
|
if is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
|
||||||
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
||||||
# scores[:, pad_token_id] = -math.inf => seems very hacky here
|
# scores[:, pad_token_id] = -math.inf => seems very hacky here
|
||||||
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
||||||
# if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
# if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
||||||
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
||||||
if cur_len == max_length - 1: # FORCE EOS to be chosen
|
if cur_len == max_length - 1: # FORCE EOS to be chosen
|
||||||
all_but_eos_mask = torch.tensor([x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long, device=next(self.parameters()).device)
|
all_but_eos_mask = torch.tensor(
|
||||||
|
[x for x in range(vocab_size) if x not in eos_token_ids],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=next(self.parameters()).device,
|
||||||
|
)
|
||||||
scores[:, all_but_eos_mask] = -10000.0
|
scores[:, all_but_eos_mask] = -10000.0
|
||||||
|
|
||||||
assert scores.size() == (batch_size * num_beams, vocab_size)
|
assert scores.size() == (batch_size * num_beams, vocab_size)
|
||||||
@@ -1175,7 +1201,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
assert torch.all(
|
assert torch.all(
|
||||||
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
||||||
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
|
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
|
||||||
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx]
|
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
# need to add best num_beams hypotheses to generated hyps
|
# need to add best num_beams hypotheses to generated hyps
|
||||||
@@ -1218,7 +1244,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
assert (len(hypo) == max_length for hypo in best)
|
assert (len(hypo) == max_length for hypo in best)
|
||||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||||
|
|
||||||
|
if is_encoder_decoder:
|
||||||
|
# do not return first <BOS> token
|
||||||
return decoded[:, 1:]
|
return decoded[:, 1:]
|
||||||
|
return decoded
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
@@ -1235,6 +1264,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
return past
|
return past
|
||||||
|
|
||||||
|
|
||||||
|
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, step):
|
||||||
|
# Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||||
|
if step + 2 < no_repeat_ngram_size:
|
||||||
|
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||||
|
return [[] for _ in range(num_hypos)]
|
||||||
|
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||||
|
for idx in range(num_hypos):
|
||||||
|
gen_tokens = prev_input_ids[idx].tolist()
|
||||||
|
generated_ngram = generated_ngrams[idx]
|
||||||
|
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
||||||
|
prev_ngram_tuple = tuple(ngram[:-1])
|
||||||
|
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||||
|
|
||||||
|
def _get_generated_ngrams(hypo_idx):
|
||||||
|
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||||
|
start_idx = step + 2 - no_repeat_ngram_size
|
||||||
|
end_idx = step + 1
|
||||||
|
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:end_idx].tolist())
|
||||||
|
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||||
|
|
||||||
|
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||||
|
return banned_tokens
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||||
Args:
|
Args:
|
||||||
@@ -1508,7 +1561,7 @@ class SQuADHead(nn.Module):
|
|||||||
self.answer_class = PoolerAnswerClass(config)
|
self.answer_class = PoolerAnswerClass(config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None
|
self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
|
||||||
):
|
):
|
||||||
outputs = ()
|
outputs = ()
|
||||||
|
|
||||||
@@ -1567,7 +1620,7 @@ class SQuADHead(nn.Module):
|
|||||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
||||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
||||||
|
|
||||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
|
||||||
|
|
||||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||||
# or (if labels are provided) (total_loss,)
|
# or (if labels are provided) (total_loss,)
|
||||||
@@ -1636,7 +1689,7 @@ class SequenceSummary(nn.Module):
|
|||||||
output = hidden_states.mean(dim=1)
|
output = hidden_states.mean(dim=1)
|
||||||
elif self.summary_type == "cls_index":
|
elif self.summary_type == "cls_index":
|
||||||
if cls_index is None:
|
if cls_index is None:
|
||||||
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long)
|
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
|
||||||
else:
|
else:
|
||||||
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
|
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
|
||||||
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
|
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -54,13 +54,13 @@ class ModelTesterMixin:
|
|||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
all_generative_model_classes = ()
|
all_generative_model_classes = ()
|
||||||
_A_test_torchscript = True
|
test_torchscript = True
|
||||||
_A_test_pruning = True
|
test_pruning = True
|
||||||
_A_test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
_A_test_head_masking = True
|
test_head_masking = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _A_test_save_load(self):
|
def test_save_load(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -85,7 +85,7 @@ class ModelTesterMixin:
|
|||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def _A_test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config)
|
configs_no_init = _config_zero_init(config)
|
||||||
@@ -99,7 +99,7 @@ class ModelTesterMixin:
|
|||||||
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _A_test_determinism(self):
|
def test_determinism(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -116,7 +116,7 @@ class ModelTesterMixin:
|
|||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def _A_test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||||
@@ -179,25 +179,25 @@ class ModelTesterMixin:
|
|||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _A_test_torchscript(self):
|
def test_torchscript(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
self._create_and_check_torchscript(config, inputs_dict)
|
self._create_and_check_torchscript(config, inputs_dict)
|
||||||
|
|
||||||
def _A_test_torchscript_output_attentions(self):
|
def test_torchscript_output_attentions(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
self._create_and_check_torchscript(config, inputs_dict)
|
self._create_and_check_torchscript(config, inputs_dict)
|
||||||
|
|
||||||
def _A_test_torchscript_output_hidden_state(self):
|
def test_torchscript_output_hidden_state(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
self._create_and_check_torchscript(config, inputs_dict)
|
self._create_and_check_torchscript(config, inputs_dict)
|
||||||
|
|
||||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
if not self._A_test_torchscript:
|
if not self.test_torchscript:
|
||||||
return
|
return
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
@@ -245,8 +245,8 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
def _A_test_headmasking(self):
|
def test_headmasking(self):
|
||||||
if not self._A_test_head_masking:
|
if not self.test_head_masking:
|
||||||
return
|
return
|
||||||
|
|
||||||
global_rng.seed(42)
|
global_rng.seed(42)
|
||||||
@@ -299,8 +299,8 @@ class ModelTesterMixin:
|
|||||||
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
||||||
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||||
|
|
||||||
def _A_test_head_pruning(self):
|
def test_head_pruning(self):
|
||||||
if not self._A_test_pruning:
|
if not self.test_pruning:
|
||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -328,8 +328,8 @@ class ModelTesterMixin:
|
|||||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
def _A_test_head_pruning_save_load_from_pretrained(self):
|
def test_head_pruning_save_load_from_pretrained(self):
|
||||||
if not self._A_test_pruning:
|
if not self.test_pruning:
|
||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -361,8 +361,8 @@ class ModelTesterMixin:
|
|||||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
def _A_test_head_pruning_save_load_from_config_init(self):
|
def test_head_pruning_save_load_from_config_init(self):
|
||||||
if not self._A_test_pruning:
|
if not self.test_pruning:
|
||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -392,8 +392,8 @@ class ModelTesterMixin:
|
|||||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
def _A_test_head_pruning_integration(self):
|
def test_head_pruning_integration(self):
|
||||||
if not self._A_test_pruning:
|
if not self.test_pruning:
|
||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -449,7 +449,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
|
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
|
||||||
|
|
||||||
def _A_test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -474,9 +474,9 @@ class ModelTesterMixin:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _A_test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self._A_test_resize_embeddings:
|
if not self.test_resize_embeddings:
|
||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -516,7 +516,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
def _A_test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -594,7 +594,7 @@ class ModelTesterMixin:
|
|||||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||||
|
|
||||||
def _A_test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self.is_encoder_decoder:
|
if not self.is_encoder_decoder:
|
||||||
@@ -711,7 +711,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
|||||||
@require_torch
|
@require_torch
|
||||||
class ModelUtilsTest(unittest.TestCase):
|
class ModelUtilsTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def _A_test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
config = BertConfig.from_pretrained(model_name)
|
config = BertConfig.from_pretrained(model_name)
|
||||||
@@ -736,7 +736,7 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
class UtilsFunctionsTest(unittest.TestCase):
|
class UtilsFunctionsTest(unittest.TestCase):
|
||||||
|
|
||||||
# tests whether the top_k_top_p function behaves as expected
|
# tests whether the top_k_top_p function behaves as expected
|
||||||
def _A_test_top_k_top_p_filtering(self):
|
def test_top_k_top_p_filtering(self):
|
||||||
logits = torch.tensor(
|
logits = torch.tensor(
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user