Remove dependency on examples/seq2seq from rag (#7395)

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Ola Piktus
2020-09-25 17:20:49 +01:00
committed by GitHub
parent ad39271ae8
commit fe326bd5cf
3 changed files with 157 additions and 20 deletions

View File

@@ -1,15 +1,20 @@
import itertools
import json
import linecache
import os
import pickle
import re
import socket
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Dict, List
from typing import Callable, Dict, Iterable, List
import git
import torch
from torch.utils.data import Dataset
from examples.seq2seq.utils import SortishSampler, trim_batch
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
@@ -27,6 +32,19 @@ def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=Tru
)
def trim_batch(
input_ids,
pad_token_id,
attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class Seq2SeqDataset(Dataset):
def __init__(
self,
@@ -114,13 +132,52 @@ class Seq2SeqDataset(Dataset):
}
return batch
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
logger = getLogger(__name__)
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
def save_git_info(folder_path: str) -> None:
"""Save git information to output_dir/git_log.json"""
repo_infos = get_git_info()
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):
with open(path) as f:
return json.load(f)
def get_git_info():
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
"hostname": str(socket.gethostname()),
}
return repo_infos
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def pickle_save(obj, path):
"""pickle.dump(obj, path)"""
with open(path, "wb") as f:
return pickle.dump(obj, f)
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""