Merge branch 'master' into fix_top_k_top_p_filtering
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
|
||||
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
@@ -26,12 +26,14 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
|
||||
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
|
||||
|
||||
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
||||
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
|
||||
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
||||
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
||||
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
||||
from transformers import CTRLLMHeadModel, CTRLTokenizer
|
||||
from transformers import XLMWithLMHeadModel, XLMTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
@@ -41,13 +43,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
|
||||
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
|
||||
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
|
||||
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
|
||||
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
|
||||
}
|
||||
|
||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||
@@ -103,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
||||
return logits
|
||||
|
||||
|
||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
|
||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
|
||||
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
|
||||
context = torch.tensor(context, dtype=torch.long, device=device)
|
||||
context = context.unsqueeze(0).repeat(num_samples, 1)
|
||||
generated = context
|
||||
@@ -121,10 +126,27 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
||||
target_mapping[0, 0, -1] = 1.0 # predict last token
|
||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
||||
next_token_logits = outputs[0][:, -1, :] / temperature
|
||||
if is_xlm_mlm and xlm_mask_token:
|
||||
# XLM MLM models are direct models (predict same token, not next token)
|
||||
# => need one additional dummy token in the input (will be masked and guessed)
|
||||
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
|
||||
inputs = {'input_ids': input_ids}
|
||||
|
||||
if xlm_lang is not None:
|
||||
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
|
||||
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
|
||||
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
|
||||
|
||||
# reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
|
||||
for _ in set(generated.view(-1).tolist()):
|
||||
next_token_logits[_] /= repetition_penalty
|
||||
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
if temperature == 0: #greedy sampling:
|
||||
next_token = torch.argmax(filtered_logits).unsqueeze(0)
|
||||
else:
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
generated = torch.cat((generated, next_token), dim=1)
|
||||
return generated
|
||||
|
||||
@@ -137,15 +159,21 @@ def main():
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--padding_text", type=str, default="")
|
||||
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--num_samples", type=int, default=1)
|
||||
parser.add_argument("--temperature", type=float, default=1.0,
|
||||
help="temperature of 0 implies greedy sampling")
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1.0,
|
||||
help="primarily useful for CTRL model; in that case, use 1.2")
|
||||
parser.add_argument("--top_k", type=int, default=0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--stop_token', type=str, default=None,
|
||||
help="Token at which text generation is stopped")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
@@ -167,13 +195,39 @@ def main():
|
||||
elif args.length < 0:
|
||||
args.length = MAX_LENGTH # avoid infinite loop
|
||||
|
||||
print(args)
|
||||
logger.info(args)
|
||||
if args.model_type in ["ctrl"]:
|
||||
if args.temperature > 0.7:
|
||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||
|
||||
while True:
|
||||
xlm_lang = None
|
||||
# XLM Language usage detailed in the issues #1414
|
||||
if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
|
||||
and model.config.use_lang_emb:
|
||||
if args.xlm_lang:
|
||||
language = args.xlm_lang
|
||||
else:
|
||||
language = None
|
||||
while language not in tokenizer.lang2id.keys():
|
||||
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
|
||||
xlm_lang = tokenizer.lang2id[language]
|
||||
|
||||
# XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
|
||||
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
|
||||
if is_xlm_mlm:
|
||||
xlm_mask_token = tokenizer.mask_token_id
|
||||
else:
|
||||
xlm_mask_token = None
|
||||
|
||||
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||
if args.model_type in ["transfo-xl", "xlnet"]:
|
||||
# Models with memory likes to have a long prompt for short inputs.
|
||||
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
||||
context_tokens = tokenizer.encode(raw_text)
|
||||
context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
|
||||
if args.model_type == "ctrl":
|
||||
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
|
||||
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||
out = sample_sequence(
|
||||
model=model,
|
||||
context=context_tokens,
|
||||
@@ -182,13 +236,20 @@ def main():
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
device=args.device,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
is_xlnet=bool(args.model_type == "xlnet"),
|
||||
is_xlm_mlm=is_xlm_mlm,
|
||||
xlm_mask_token=xlm_mask_token,
|
||||
xlm_lang=xlm_lang,
|
||||
device=args.device,
|
||||
)
|
||||
out = out[:, len(context_tokens):].tolist()
|
||||
for o in out:
|
||||
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
|
||||
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||
|
||||
print(text)
|
||||
|
||||
if args.prompt:
|
||||
break
|
||||
return text
|
||||
|
||||
Reference in New Issue
Block a user