Remove dependency on examples/seq2seq from rag (#7395)
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -34,22 +34,23 @@ from transformers import logging as transformers_logging
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
||||
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
|
||||
from examples.rag.callbacks import ( # noqa: E402 # isort:skip
|
||||
get_checkpoint_callback,
|
||||
get_early_stopping_callback,
|
||||
Seq2SeqLoggingCallback,
|
||||
)
|
||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
||||
Seq2SeqDataset,
|
||||
calculate_exact_match,
|
||||
is_rag_model,
|
||||
set_extra_model_params,
|
||||
)
|
||||
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
|
||||
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
|
||||
flatten_list,
|
||||
get_git_info,
|
||||
is_rag_model,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
set_extra_model_params,
|
||||
Seq2SeqDataset,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -303,11 +304,6 @@ class GenerativeQAModule(BaseTransformer):
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
||||
sampler = dataset.make_sortish_sampler(batch_size)
|
||||
shuffle = False
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
@@ -315,7 +311,6 @@ class GenerativeQAModule(BaseTransformer):
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
@@ -379,7 +374,6 @@ class GenerativeQAModule(BaseTransformer):
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
|
||||
Reference in New Issue
Block a user