Update demo_camembert.py with new classes
This commit is contained in:
Committed by: Julien Chaumond
Parent: f12e4d8da7
Commit: 3e20c2e871
@@ -5,7 +5,7 @@ import urllib.request
 import torch
 
 from transformers.tokenization_camembert import CamembertTokenizer
-from transformers.modeling_roberta import RobertaForMaskedLM
+from transformers.modeling_camembert import CamembertForMaskedLM
 
 
 def fill_mask(masked_input, model, tokenizer, topk=5):
@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
     return topk_filled_outputs
 
 
-model_path = Path('camembert.v0.pytorch')
+tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
-if not model_path.exists():
+model = CamembertForMaskedLM.from_pretrained('camembert-base')
-    compressed_path = model_path.with_suffix('.tar.gz')
-    url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
-    print('Downloading model...')
-    urllib.request.urlretrieve(url, compressed_path)
-    print('Extracting model...')
-    with tarfile.open(compressed_path) as f:
-        f.extractall(model_path.parent)
-    assert model_path.exists()
-tokenizer_path = model_path / 'sentencepiece.bpe.model'
-tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
-model = RobertaForMaskedLM.from_pretrained(model_path)
 model.eval()
 
 masked_input = "Le camembert est <mask> :)"
Reference in New Issue
Block a user