Update demo_camembert.py with new classes
This commit is contained in:
committed by
Julien Chaumond
parent
f12e4d8da7
commit
3e20c2e871
@@ -5,7 +5,7 @@ import urllib.request
|
||||
import torch
|
||||
|
||||
from transformers.tokenization_camembert import CamembertTokenizer
|
||||
from transformers.modeling_roberta import RobertaForMaskedLM
|
||||
from transformers.modeling_camembert import CamembertForMaskedLM
|
||||
|
||||
|
||||
def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||
@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||
return topk_filled_outputs
|
||||
|
||||
|
||||
model_path = Path('camembert.v0.pytorch')
|
||||
if not model_path.exists():
|
||||
compressed_path = model_path.with_suffix('.tar.gz')
|
||||
url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
|
||||
print('Downloading model...')
|
||||
urllib.request.urlretrieve(url, compressed_path)
|
||||
print('Extracting model...')
|
||||
with tarfile.open(compressed_path) as f:
|
||||
f.extractall(model_path.parent)
|
||||
assert model_path.exists()
|
||||
tokenizer_path = model_path / 'sentencepiece.bpe.model'
|
||||
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
|
||||
model = RobertaForMaskedLM.from_pretrained(model_path)
|
||||
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
|
||||
model = CamembertForMaskedLM.from_pretrained('camembert-base')
|
||||
model.eval()
|
||||
|
||||
masked_input = "Le camembert est <mask> :)"
|
||||
|
||||
Reference in New Issue
Block a user