[s2s] add support for overriding config params (#6149)
This commit is contained in:
@@ -70,6 +70,13 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.config: PretrainedConfig = config
|
self.config: PretrainedConfig = config
|
||||||
|
|
||||||
|
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||||
|
for p in extra_model_params:
|
||||||
|
if getattr(self.hparams, p, None):
|
||||||
|
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
|
||||||
|
setattr(self.config, p, getattr(self.hparams, p))
|
||||||
|
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||||
@@ -182,6 +189,22 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
type=str,
|
type=str,
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder_layerdrop",
|
||||||
|
type=float,
|
||||||
|
help="Encoder layer dropout probability (Optional). Goes into model.config",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder_layerdrop",
|
||||||
|
type=float,
|
||||||
|
help="Decoder layer dropout probability (Optional). Goes into model.config",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dropout", type=float, help="Dropout probability (Optional). Goes into model.config",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
|
||||||
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
|
|||||||
@@ -66,6 +66,19 @@ Summarization Tips:
|
|||||||
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
|
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
|
||||||
A new dataset is needed to support multilingual tasks.
|
A new dataset is needed to support multilingual tasks.
|
||||||
|
|
||||||
|
### Finetuning Training Params
|
||||||
|
|
||||||
|
To override the pretrained model's training params, you can pass them to `./finetune.sh`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./finetune.sh \
|
||||||
|
[...]
|
||||||
|
--encoder_layerdrop 0.1 \
|
||||||
|
--decoder_layerdrop 0.1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--attention_dropout 0.1
|
||||||
|
```
|
||||||
|
|
||||||
### Summarization Finetuning
|
### Summarization Finetuning
|
||||||
Run/modify `finetune.sh`
|
Run/modify `finetune.sh`
|
||||||
|
|
||||||
|
|||||||
@@ -10,4 +10,8 @@ python finetune.py \
|
|||||||
--do_predict \
|
--do_predict \
|
||||||
--n_val 1000 \
|
--n_val 1000 \
|
||||||
--val_check_interval 0.1 \
|
--val_check_interval 0.1 \
|
||||||
|
--encoder_layerdrop 0.1 \
|
||||||
|
--decoder_layerdrop 0.1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--attention_dropout 0.1 \
|
||||||
$@
|
$@
|
||||||
|
|||||||
@@ -277,6 +277,55 @@ def test_finetune(model):
|
|||||||
assert bart.decoder.embed_tokens == bart.shared
|
assert bart.decoder.embed_tokens == bart.shared
|
||||||
|
|
||||||
|
|
||||||
|
def test_finetune_extra_model_args():
    """Check that extra model hparams given to ``main`` end up in ``model.config``.

    Two scenarios share the same cheap summarization setup:

    1. ``BART_TINY``: its config includes the extra model args, so every one
       of them set to 0.5 must be readable back from ``model.config``.
    2. ``T5_TINY``: its config doesn't include them, so ``main`` must raise
       with the "model config doesn't have a ... attribute" message.
    """
    overridable = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")

    # Arguments common to both scenarios, layered on top of the cheap defaults.
    shared_args: dict = CHEAP_ARGS.copy()
    shared_args.update(
        data_dir=make_test_data_dir(),
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task="summarization",
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # Scenario 1: a model whose config includes the extra_model_args.
    bart_args = dict(
        shared_args,
        model_name_or_path=BART_TINY,
        output_dir=tempfile.mkdtemp(prefix="output_1_"),
    )
    bart_args.update(dict.fromkeys(overridable, 0.5))
    bart_namespace = argparse.Namespace(**bart_args)
    finetuned = main(bart_namespace)
    for name in overridable:
        assert getattr(finetuned.config, name) == 0.5, f"failed to override the model config for param {name}"

    # Scenario 2: a model whose config doesn't include the extra_model_args.
    unsupported_param = "encoder_layerdrop"
    t5_args = dict(
        shared_args,
        model_name_or_path=T5_TINY,
        output_dir=tempfile.mkdtemp(prefix="output_2_"),
    )
    t5_args[unsupported_param] = 0.5
    t5_namespace = argparse.Namespace(**t5_args)
    with pytest.raises(Exception) as excinfo:
        main(t5_namespace)
    # Checked after the `with` block: excinfo.value is only populated once the block exits.
    assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||||
|
|
||||||
|
|
||||||
def test_pack_dataset():
|
def test_pack_dataset():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user