fix QA example for PT (#6890)

This commit is contained in:
Patrick von Platen
2020-09-02 09:53:09 +02:00
committed by GitHub
parent d822ab636b
commit 1889e96c8c

View File

@@ -303,14 +303,15 @@ PT_QUESTION_ANSWERING_SAMPLE = r"""
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True) >>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='pt')
>>> start_positions = torch.tensor([1]) >>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3]) >>> end_positions = torch.tensor([3])
>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions) >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
>>> loss = outputs.loss >>> loss = outputs.loss
>>> start_scores = outputs.start_scores >>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_scores >>> end_scores = outputs.end_logits
""" """
PT_SEQUENCE_CLASSIFICATION_SAMPLE = r""" PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""