Fix CI after killing archive maps (#4724)
Some checks failed
GitHub-hosted runner / check_code_quality (push) Has been cancelled

* 🐛 Fix model ids for BART and Flaubert
This commit is contained in:
Julien Chaumond
2020-06-02 10:21:09 -04:00
committed by GitHub
parent b43c78e5d3
commit b42586ea56
9 changed files with 39 additions and 34 deletions

View File

@@ -21,7 +21,7 @@ def generate_summaries(
):
fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
max_length = 140
min_length = 55
@@ -54,7 +54,7 @@ def run_generate():
"output_path", type=str, help="where to save summaries",
)
parser.add_argument(
"model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
"model_name", type=str, default="facebook/bart-large-cnn", help="like bart-large-cnn",
)
parser.add_argument(
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",

View File

@@ -129,7 +129,7 @@ class TestBartExamples(unittest.TestCase):
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
_dump_articles((tmp_dir / "train.source"), articles)
_dump_articles((tmp_dir / "train.target"), summaries)
tokenizer = BartTokenizer.from_pretrained("bart-large")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
trunc_target = 4