Updated quick-start example with BertForMaskedLM

As `convert_ids_to_tokens` returns a list, the code in the README currently throws an `AssertionError`, so I propose I quick fix.
This commit is contained in:
Davide Fiocco
2018-11-28 14:53:46 +01:00
committed by GitHub
parent 21f0196412
commit ec2c339b53

View File

@@ -142,7 +142,7 @@ predictions = model(tokens_tensor, segments_tensors)
# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
```