Update demo_camembert.py with new classes

This commit is contained in:
Louis MARTIN
2019-11-12 17:16:24 -08:00
committed by Julien Chaumond
parent f12e4d8da7
commit 3e20c2e871

View File

@@ -5,7 +5,7 @@ import urllib.request
import torch
from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_roberta import RobertaForMaskedLM
from transformers.modeling_camembert import CamembertForMaskedLM
def fill_mask(masked_input, model, tokenizer, topk=5):
@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
return topk_filled_outputs
model_path = Path('camembert.v0.pytorch')
if not model_path.exists():
compressed_path = model_path.with_suffix('.tar.gz')
url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
print('Downloading model...')
urllib.request.urlretrieve(url, compressed_path)
print('Extracting model...')
with tarfile.open(compressed_path) as f:
f.extractall(model_path.parent)
assert model_path.exists()
tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForMaskedLM.from_pretrained('camembert-base')
model.eval()
masked_input = "Le camembert est <mask> :)"