"""Demo: fill a masked token in a French sentence with a CamemBERT model.

Provenance: examples/demo_camembert.py, added by the patch "Add
demo_camembert.py" (commit 6e72fd094c98901cc90d146d3fe3cd5a0e879911,
Louis MARTIN, 2019-11-08).

Running this file downloads and extracts the model archive into the current
working directory (if not already present), loads the tokenizer and model,
and prints the top predictions for one masked sentence.
"""
from pathlib import Path
import tarfile
import urllib.request

import torch

from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_roberta import RobertaForMaskedLM


def fill_mask(masked_input, model, tokenizer, topk=5):
    """Return the ``topk`` most probable fillings for the single mask token.

    Adapted from
    https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py

    Args:
        masked_input: Text containing exactly one occurrence of
            ``tokenizer.mask_token``.
        model: Callable model; ``model(input_ids)[0]`` is indexed as
            ``[batch, position, vocabulary]`` scores.
        tokenizer: Tokenizer providing ``encode``, ``convert_ids_to_tokens``,
            ``mask_token`` and ``mask_token_id``.
        topk: Number of candidate fillings to return.

    Returns:
        A list of ``(filled_text, probability, predicted_token)`` tuples,
        ordered from most to least probable.

    Raises:
        ValueError: If ``masked_input`` does not contain exactly one mask
            token.
    """
    mask_token = tokenizer.mask_token
    # Counting the empty string would yield len(masked_input) + 1, so the
    # check must be made against the tokenizer's actual mask token.
    if masked_input.count(mask_token) != 1:
        raise ValueError(
            "masked_input must contain exactly one {0!r} token".format(mask_token)
        )

    # Batch size 1
    input_ids = torch.tensor(
        tokenizer.encode(masked_input, add_special_tokens=True)
    ).unsqueeze(0)

    # Inference only: no need to track gradients.
    with torch.no_grad():
        # The first output element is used as the per-token vocabulary scores.
        logits = model(input_ids)[0]

    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    prob = logits[0, masked_index, :].softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)

    # A predicted piece that starts a word carries its own leading space
    # (from the U+2581 marker), so when the mask is preceded by a space that
    # space is swallowed together with the mask to avoid doubling it.
    spaced_mask = ' {0}'.format(mask_token)
    target = spaced_mask if spaced_mask in masked_input else mask_token

    topk_filled_outputs = []
    for value, index in zip(values, indices):
        predicted_token_bpe = tokenizer.convert_ids_to_tokens(index.item())
        predicted_token = predicted_token_bpe.replace('\u2581', ' ')
        topk_filled_outputs.append((
            masked_input.replace(target, predicted_token),
            value.item(),
            predicted_token,
        ))
    return topk_filled_outputs


model_path = Path('camembert.v0.pytorch')
if not model_path.exists():
    compressed_path = model_path.with_suffix('.tar.gz')
    url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
    print('Downloading model...')
    urllib.request.urlretrieve(url, compressed_path)
    print('Extracting model...')
    # NOTE(review): extractall() trusts every member path in the downloaded
    # archive; only use this with an archive source you trust.
    with tarfile.open(compressed_path) as f:
        f.extractall(model_path.parent)
    assert model_path.exists()
tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
model.eval()

# The sentence must contain the tokenizer's mask token exactly once.
masked_input = "Le camembert est {0} :)".format(tokenizer.mask_token)
print(fill_mask(masked_input, model, tokenizer, topk=3))