Trainer (#3800)
* doc
* [tests] Add sample files for a regression task
* [HUGE] Trainer
* Feedback from @sshleifer
* Feedback from @thomwolf + logging tweak
* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes
* [glue] Use default max_seq_length of 128 like before
* [glue] move DataTrainingArguments around
* [ner] Change interface of InputExample, and align run_{tf,pl}
* Re-align the pl scripts a little bit
* ner
* [ner] Add integration test
* Fix language_modeling with API tweak
* [ci] Tweak loss target
* Don't break console output
* amp.initialize: model must be on right device before
* [multiple-choice] update for Trainer
* Re-align to 827d6d6ef0
This commit is contained in:
@@ -35,8 +35,8 @@ class GLUETransformer(BaseTransformer):
|
||||
def training_step(self, batch, batch_idx):
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
|
||||
if self.hparams.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = batch[2] if self.hparams.model_type in ["bert", "xlnet", "albert"] else None
|
||||
if self.config.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None
|
||||
|
||||
outputs = self(**inputs)
|
||||
loss = outputs[0]
|
||||
@@ -95,8 +95,8 @@ class GLUETransformer(BaseTransformer):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
|
||||
if self.hparams.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = batch[2] if self.hparams.model_type in ["bert", "xlnet", "albert"] else None
|
||||
if self.config.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None
|
||||
|
||||
outputs = self(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
@@ -179,7 +179,7 @@ if __name__ == "__main__":
|
||||
|
||||
# If output_dir not provided, a folder will be generated in pwd
|
||||
if args.output_dir is None:
|
||||
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
||||
args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
model = GLUETransformer(args)
|
||||
|
||||
Reference in New Issue
Block a user