[bart-tiny-random] Put a 5MB model on S3 to allow faster examples test (#3488)
This commit is contained in:
@@ -16,15 +16,17 @@ def chunks(lst, n):
|
|||||||
yield lst[i : i + n]
|
yield lst[i : i + n]
|
||||||
|
|
||||||
|
|
||||||
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
def generate_summaries(
|
||||||
|
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
|
||||||
|
):
|
||||||
fout = Path(out_file).open("w")
|
fout = Path(out_file).open("w")
|
||||||
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
|
model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
|
||||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||||
|
|
||||||
max_length = 140
|
max_length = 140
|
||||||
min_length = 55
|
min_length = 55
|
||||||
|
|
||||||
for batch in tqdm(list(chunks(lns, batch_size))):
|
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
||||||
summaries = model.generate(
|
summaries = model.generate(
|
||||||
input_ids=dct["input_ids"].to(device),
|
input_ids=dct["input_ids"].to(device),
|
||||||
@@ -51,6 +53,9 @@ def _run_generate():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"output_path", type=str, help="where to save summaries",
|
"output_path", type=str, help="where to save summaries",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
|
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
|
||||||
)
|
)
|
||||||
@@ -58,8 +63,8 @@ def _run_generate():
|
|||||||
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
|
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
lns = [" " + x.rstrip() for x in open(args.source_path).readlines()]
|
examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
|
||||||
generate_summaries(lns, args.output_path, batch_size=args.bs, device=args.device)
|
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ class TestBartExamples(unittest.TestCase):
|
|||||||
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
||||||
with tmp.open("w") as f:
|
with tmp.open("w") as f:
|
||||||
f.write("\n".join(articles))
|
f.write("\n".join(articles))
|
||||||
testargs = ["evaluate_cnn.py", str(tmp), output_file_name]
|
|
||||||
|
testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"]
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
_run_generate()
|
_run_generate()
|
||||||
self.assertTrue(Path(output_file_name).exists())
|
self.assertTrue(Path(output_file_name).exists())
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoTokenizer,
|
||||||
BartModel,
|
BartModel,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
@@ -183,6 +185,15 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_tiny_model(self):
|
||||||
|
model_name = "sshleifer/bart-tiny-random"
|
||||||
|
tiny = AutoModel.from_pretrained(model_name) # same vocab size
|
||||||
|
tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
|
||||||
|
inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
tiny(**inputs_dict)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BartHeadTests(unittest.TestCase):
|
class BartHeadTests(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user