Remove dependency on examples/seq2seq from rag (#7395)

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Ola Piktus
2020-09-25 17:20:49 +01:00
committed by GitHub
parent ad39271ae8
commit fe326bd5cf
3 changed files with 157 additions and 20 deletions

View File

@@ -34,22 +34,23 @@ from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
from examples.rag.callbacks import ( # noqa: E402 # isort:skip
get_checkpoint_callback,
get_early_stopping_callback,
Seq2SeqLoggingCallback,
)
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from examples.rag.utils import ( # noqa: E402 # isort:skip
Seq2SeqDataset,
calculate_exact_match,
is_rag_model,
set_extra_model_params,
)
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
flatten_list,
get_git_info,
is_rag_model,
lmap,
pickle_save,
save_git_info,
save_json,
set_extra_model_params,
Seq2SeqDataset,
)
logging.basicConfig(level=logging.INFO)
@@ -303,11 +304,6 @@ class GenerativeQAModule(BaseTransformer):
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
assert self.hparams.gpus <= 1 # TODO: assert earlier
sampler = dataset.make_sortish_sampler(batch_size)
shuffle = False
dataloader = DataLoader(
dataset,
@@ -315,7 +311,6 @@ class GenerativeQAModule(BaseTransformer):
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
sampler=sampler,
)
return dataloader
@@ -379,7 +374,6 @@ class GenerativeQAModule(BaseTransformer):
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")