From 3e20c2e871db82f81c3b2b814265a481be15273c Mon Sep 17 00:00:00 2001 From: Louis MARTIN Date: Tue, 12 Nov 2019 17:16:24 -0800 Subject: [PATCH] Update demo_camembert.py with new classes --- examples/contrib/demo_camembert.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/examples/contrib/demo_camembert.py b/examples/contrib/demo_camembert.py index df28f4f267..28144d5167 100644 --- a/examples/contrib/demo_camembert.py +++ b/examples/contrib/demo_camembert.py @@ -5,7 +5,7 @@ import urllib.request import torch from transformers.tokenization_camembert import CamembertTokenizer -from transformers.modeling_roberta import RobertaForMaskedLM +from transformers.modeling_camembert import CamembertForMaskedLM def fill_mask(masked_input, model, tokenizer, topk=5): @@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5): return topk_filled_outputs -model_path = Path('camembert.v0.pytorch') -if not model_path.exists(): - compressed_path = model_path.with_suffix('.tar.gz') - url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz' - print('Downloading model...') - urllib.request.urlretrieve(url, compressed_path) - print('Extracting model...') - with tarfile.open(compressed_path) as f: - f.extractall(model_path.parent) - assert model_path.exists() -tokenizer_path = model_path / 'sentencepiece.bpe.model' -tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path) -model = RobertaForMaskedLM.from_pretrained(model_path) +tokenizer = CamembertTokenizer.from_pretrained('camembert-base') +model = CamembertForMaskedLM.from_pretrained('camembert-base') model.eval() masked_input = "Le camembert est :)"