[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
|
||||
)
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
"Compute validation"
|
||||
|
||||
"""Compute validation""" ""
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if self.config.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
|
||||
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user