Remove dependency on examples/seq2seq from rag (#7395)
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -1,15 +1,20 @@
|
||||
import itertools
|
||||
import json
|
||||
import linecache
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import socket
|
||||
import string
|
||||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
|
||||
import git
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from examples.seq2seq.utils import SortishSampler, trim_batch
|
||||
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
@@ -27,6 +32,19 @@ def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=Tru
|
||||
)
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
if attention_mask is None:
|
||||
return input_ids[:, keep_column_mask]
|
||||
else:
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
||||
|
||||
class Seq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -114,13 +132,52 @@ class Seq2SeqDataset(Dataset):
|
||||
}
|
||||
return batch
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.src_lens, batch_size)
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
|
||||
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
"""Save git information to output_dir/git_log.json"""
|
||||
repo_infos = get_git_info()
|
||||
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||
|
||||
|
||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||
|
||||
|
||||
def load_json(path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_git_info():
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
"hostname": str(socket.gethostname()),
|
||||
}
|
||||
return repo_infos
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
def pickle_save(obj, path):
|
||||
"""pickle.dump(obj, path)"""
|
||||
with open(path, "wb") as f:
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user