fixing run_generation
This commit is contained in:
@@ -156,7 +156,7 @@ def main():
|
|||||||
parser.add_argument("--length", type=int, default=20)
|
parser.add_argument("--length", type=int, default=20)
|
||||||
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling")
|
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
|
||||||
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
||||||
parser.add_argument("--k", type=int, default=0)
|
parser.add_argument("--k", type=int, default=0)
|
||||||
parser.add_argument("--p", type=float, default=0.9)
|
parser.add_argument("--p", type=float, default=0.9)
|
||||||
@@ -187,7 +187,6 @@ def main():
|
|||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||||
model = model_class.from_pretrained(args.model_name_or_path)
|
model = model_class.from_pretrained(args.model_name_or_path)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
args.length = adjust_length_to_model(
|
args.length = adjust_length_to_model(
|
||||||
args.length, max_sequence_length=model.config.max_position_embeddings
|
args.length, max_sequence_length=model.config.max_position_embeddings
|
||||||
@@ -202,11 +201,11 @@ def main():
|
|||||||
if requires_preprocessing:
|
if requires_preprocessing:
|
||||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||||
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
||||||
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
||||||
|
|
||||||
output_sequences = model.generate(
|
output_sequences = model.generate(
|
||||||
intput_ids=encoded_prompt,
|
input_ids=encoded_prompt,
|
||||||
length=args.length,
|
max_length=args.length,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_k=args.k,
|
top_k=args.k,
|
||||||
top_p=args.p,
|
top_p=args.p,
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ class PretrainedConfig(object):
|
|||||||
self.bos_token_id = kwargs.pop('bos_token_id', 0)
|
self.bos_token_id = kwargs.pop('bos_token_id', 0)
|
||||||
self.pad_token_id = kwargs.pop('pad_token_id', 0)
|
self.pad_token_id = kwargs.pop('pad_token_id', 0)
|
||||||
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
|
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
|
||||||
self.batch_size = kwargs.pop('batch_size', 1)
|
|
||||||
self.length_penalty = kwargs.pop('length_penalty', 1.)
|
self.length_penalty = kwargs.pop('length_penalty', 1.)
|
||||||
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
|
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
|
||||||
|
|
||||||
|
|||||||
@@ -485,9 +485,10 @@ class PreTrainedModel(nn.Module):
|
|||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
return {"input_ids": input_ids}
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
|
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
|
||||||
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
|
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
|
||||||
bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
|
bos_token_id=None, pad_token_id=None, eos_token_ids=None,
|
||||||
length_penalty=None, num_return_sequences=None, **model_kwargs):
|
length_penalty=None, num_return_sequences=None, **model_kwargs):
|
||||||
""" Sequence generator for models with a LM head.
|
""" Sequence generator for models with a LM head.
|
||||||
|
|
||||||
@@ -530,19 +531,20 @@ class PreTrainedModel(nn.Module):
|
|||||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||||
batch_size = batch_size if batch_size is not None else self.config.batch_size
|
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||||
num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
||||||
|
else:
|
||||||
|
batch_size = 1
|
||||||
if isinstance(eos_token_ids, int):
|
if isinstance(eos_token_ids, int):
|
||||||
eos_token_ids = [eos_token_ids]
|
eos_token_ids = [eos_token_ids]
|
||||||
|
|
||||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||||
assert temperature > 0, "`temperature` should be strictely positive."
|
# assert temperature > 0, "`temperature` should be strictely positive."
|
||||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||||
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
||||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||||
@@ -550,7 +552,6 @@ class PreTrainedModel(nn.Module):
|
|||||||
assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
|
assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
|
||||||
assert isinstance(eos_token_ids, (list, tuple)) and (e >= 0 for e in eos_token_ids), \
|
assert isinstance(eos_token_ids, (list, tuple)) and (e >= 0 for e in eos_token_ids), \
|
||||||
"`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
"`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||||
assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictely positive integer."
|
|
||||||
assert length_penalty > 0, "`length_penalty` should be strictely positive."
|
assert length_penalty > 0, "`length_penalty` should be strictely positive."
|
||||||
assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer."
|
assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer."
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user