Update demo_camembert.py with new classes

This commit is contained in:
Louis MARTIN
2019-11-12 17:16:24 -08:00
committed by Julien Chaumond
parent f12e4d8da7
commit 3e20c2e871

View File

@@ -5,7 +5,7 @@ import urllib.request
 import torch
 
 from transformers.tokenization_camembert import CamembertTokenizer
-from transformers.modeling_roberta import RobertaForMaskedLM
+from transformers.modeling_camembert import CamembertForMaskedLM
 
 
 def fill_mask(masked_input, model, tokenizer, topk=5):
@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
     return topk_filled_outputs
 
 
-model_path = Path('camembert.v0.pytorch')
-if not model_path.exists():
-    compressed_path = model_path.with_suffix('.tar.gz')
-    url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
-    print('Downloading model...')
-    urllib.request.urlretrieve(url, compressed_path)
-    print('Extracting model...')
-    with tarfile.open(compressed_path) as f:
-        f.extractall(model_path.parent)
-    assert model_path.exists()
-tokenizer_path = model_path / 'sentencepiece.bpe.model'
-tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
-model = RobertaForMaskedLM.from_pretrained(model_path)
+tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
+model = CamembertForMaskedLM.from_pretrained('camembert-base')
 model.eval()
 
 masked_input = "Le camembert est <mask> :)"