[s2s] add support for overriding config params (#6149)
This commit is contained in:
@@ -70,6 +70,13 @@ class BaseTransformer(pl.LightningModule):
|
||||
)
|
||||
else:
|
||||
self.config: PretrainedConfig = config
|
||||
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
if getattr(self.hparams, p, None):
|
||||
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
|
||||
setattr(self.config, p, getattr(self.hparams, p))
|
||||
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
@@ -182,6 +189,22 @@ class BaseTransformer(pl.LightningModule):
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_layerdrop",
|
||||
type=float,
|
||||
help="Encoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_layerdrop",
|
||||
type=float,
|
||||
help="Decoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout", type=float, help="Dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
|
||||
@@ -66,6 +66,19 @@ Summarization Tips:
|
||||
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
|
||||
A new dataset is needed to support multilingual tasks.
|
||||
|
||||
### Finetuning Training Params
|
||||
|
||||
To override the pretrained model's training params, you can pass them to `./finetune.sh`:
|
||||
|
||||
```bash
|
||||
./finetune.sh \
|
||||
[...]
|
||||
--encoder_layerdrop 0.1 \
|
||||
--decoder_layerdrop 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
```
|
||||
|
||||
### Summarization Finetuning
|
||||
Run/modify `finetune.sh`
|
||||
|
||||
|
||||
@@ -10,4 +10,8 @@ python finetune.py \
|
||||
--do_predict \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.1 \
|
||||
--encoder_layerdrop 0.1 \
|
||||
--decoder_layerdrop 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
$@
|
||||
|
||||
@@ -277,6 +277,55 @@ def test_finetune(model):
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
|
||||
|
||||
def test_finetune_extra_model_args():
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir()
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# test models whose config includes the extra_model_args
|
||||
model = BART_TINY
|
||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
||||
args_d1 = args_d.copy()
|
||||
args_d1.update(
|
||||
model_name_or_path=model, output_dir=output_dir,
|
||||
)
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
args_d1[p] = 0.5
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
for p in extra_model_params:
|
||||
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
|
||||
|
||||
# test models whose config doesn't include the extra_model_args
|
||||
model = T5_TINY
|
||||
output_dir = tempfile.mkdtemp(prefix="output_2_")
|
||||
args_d2 = args_d.copy()
|
||||
args_d2.update(
|
||||
model_name_or_path=model, output_dir=output_dir,
|
||||
)
|
||||
unsupported_param = "encoder_layerdrop"
|
||||
args_d2[unsupported_param] = 0.5
|
||||
args = argparse.Namespace(**args_d2)
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
model = main(args)
|
||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||
|
||||
|
||||
def test_pack_dataset():
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user