fixing run_generation
This commit is contained in:
@@ -156,7 +156,7 @@ def main():
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
||||
parser.add_argument("--k", type=int, default=0)
|
||||
parser.add_argument("--p", type=float, default=0.9)
|
||||
@@ -187,7 +187,6 @@ def main():
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
args.length = adjust_length_to_model(
|
||||
args.length, max_sequence_length=model.config.max_position_embeddings
|
||||
@@ -202,11 +201,11 @@ def main():
|
||||
if requires_preprocessing:
|
||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
||||
|
||||
output_sequences = model.generate(
|
||||
intput_ids=encoded_prompt,
|
||||
length=args.length,
|
||||
input_ids=encoded_prompt,
|
||||
max_length=args.length,
|
||||
temperature=args.temperature,
|
||||
top_k=args.k,
|
||||
top_p=args.p,
|
||||
|
||||
Reference in New Issue
Block a user