Fix CI after killing archive maps (#4724)
Some checks failed
GitHub-hosted runner / check_code_quality (push) Has been cancelled
Some checks failed
GitHub-hosted runner / check_code_quality (push) Has been cancelled
* 🐛 Fix model ids for BART and Flaubert
This commit is contained in:
@@ -22,7 +22,7 @@ Implementation Notes
|
||||
- The forward pass of ``BartModel`` will create decoder inputs (using the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``) if they are not passed. This is different than some other modeling APIs.
|
||||
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to ``fairseq.encode`` starts with a space.
|
||||
- ``BartForConditionalGeneration.generate`` should be used for conditional generation tasks like summarization, see the example in that docstrings
|
||||
- Models that load the ``"bart-large-cnn"`` weights will not have a ``mask_token_id``, or be able to perform mask filling tasks.
|
||||
- Models that load the ``"facebook/bart-large-cnn"`` weights will not have a ``mask_token_id``, or be able to perform mask filling tasks.
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ def generate_summaries(
|
||||
):
|
||||
fout = Path(out_file).open("w")
|
||||
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
|
||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
||||
|
||||
max_length = 140
|
||||
min_length = 55
|
||||
@@ -54,7 +54,7 @@ def run_generate():
|
||||
"output_path", type=str, help="where to save summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
|
||||
"model_name", type=str, default="facebook/bart-large-cnn", help="like bart-large-cnn",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
|
||||
|
||||
@@ -129,7 +129,7 @@ class TestBartExamples(unittest.TestCase):
|
||||
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
_dump_articles((tmp_dir / "train.source"), articles)
|
||||
_dump_articles((tmp_dir / "train.target"), summaries)
|
||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
|
||||
trunc_target = 4
|
||||
|
||||
@@ -22,7 +22,7 @@ from .configuration_utils import PretrainedConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"}
|
||||
|
||||
|
||||
class CTRLConfig(PretrainedConfig):
|
||||
|
||||
@@ -124,7 +124,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
fairseq_output = bart.extract_features(tokens)
|
||||
if hf_checkpoint_name == "bart-large":
|
||||
if hf_checkpoint_name == "facebook/bart-large":
|
||||
model = BartModel(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
new_model_outputs = model(tokens).model[0]
|
||||
|
||||
@@ -1633,7 +1633,7 @@ SUPPORTED_TASKS = {
|
||||
"impl": SummarizationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "bart-large-cnn", "tf": "t5-small"}, "config": None, "tokenizer": None},
|
||||
"default": {"model": {"pt": "facebook/bart-large-cnn", "tf": "t5-small"}, "config": None, "tokenizer": None},
|
||||
},
|
||||
"translation_en_to_fr": {
|
||||
"impl": TranslationPipeline,
|
||||
|
||||
@@ -25,7 +25,12 @@ logger = logging.getLogger(__name__)
|
||||
# vocab and merges same as roberta
|
||||
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
|
||||
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
|
||||
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]
|
||||
_all_bart_models = [
|
||||
"facebook/bart-large",
|
||||
"facebook/bart-large-mnli",
|
||||
"facebook/bart-large-cnn",
|
||||
"facebook/bart-large-xsum",
|
||||
]
|
||||
|
||||
|
||||
class BartTokenizer(RobertaTokenizer):
|
||||
@@ -37,7 +42,7 @@ class BartTokenizer(RobertaTokenizer):
|
||||
}
|
||||
|
||||
|
||||
_all_mbart_models = ["mbart-large-en-ro"]
|
||||
_all_mbart_models = ["facebook/mbart-large-en-ro"]
|
||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||
|
||||
|
||||
|
||||
@@ -32,31 +32,31 @@ VOCAB_FILES_NAMES = {
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/vocab.json",
|
||||
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/vocab.json",
|
||||
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/vocab.json",
|
||||
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/vocab.json",
|
||||
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/vocab.json",
|
||||
"flaubert/flaubert_base_uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/vocab.json",
|
||||
"flaubert/flaubert_base_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/vocab.json",
|
||||
"flaubert/flaubert_large_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/vocab.json",
|
||||
},
|
||||
"merges_file": {
|
||||
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/merges.txt",
|
||||
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/merges.txt",
|
||||
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/merges.txt",
|
||||
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/merges.txt",
|
||||
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/merges.txt",
|
||||
"flaubert/flaubert_base_uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/merges.txt",
|
||||
"flaubert/flaubert_base_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/merges.txt",
|
||||
"flaubert/flaubert_large_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/merges.txt",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"flaubert-small-cased": 512,
|
||||
"flaubert-base-uncased": 512,
|
||||
"flaubert-base-cased": 512,
|
||||
"flaubert-large-cased": 512,
|
||||
"flaubert/flaubert_small_cased": 512,
|
||||
"flaubert/flaubert_base_uncased": 512,
|
||||
"flaubert/flaubert_base_cased": 512,
|
||||
"flaubert/flaubert_large_cased": 512,
|
||||
}
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"flaubert-small-cased": {"do_lowercase": False},
|
||||
"flaubert-base-uncased": {"do_lowercase": True},
|
||||
"flaubert-base-cased": {"do_lowercase": False},
|
||||
"flaubert-large-cased": {"do_lowercase": False},
|
||||
"flaubert/flaubert_small_cased": {"do_lowercase": False},
|
||||
"flaubert/flaubert_base_uncased": {"do_lowercase": True},
|
||||
"flaubert/flaubert_base_cased": {"do_lowercase": False},
|
||||
"flaubert/flaubert_large_cased": {"do_lowercase": False},
|
||||
}
|
||||
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user