Update quality tooling for formatting (#21480)
* Result of black 23.1
* Update target to Python 3.7
* Switch flake8 to ruff
* Configure isort
* Configure isort
* Apply isort with line limit
* Put the right black version
* adapt black in check copies
* Fix copies
This commit is contained in:
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -35,7 +36,6 @@ from transformers import (
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,8 +20,8 @@ from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import (
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
@@ -134,7 +134,6 @@ if is_torch_available():
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
|
||||
@@ -25,14 +25,14 @@ import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
|
||||
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
|
||||
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
@@ -173,7 +173,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
@@ -263,7 +262,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
||||
|
||||
if args.model_type == "albert":
|
||||
model.albert.set_regression_threshold(args.regression_threshold)
|
||||
model.albert.set_patience(patience)
|
||||
@@ -736,7 +734,6 @@ def main():
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_with_pabee
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from model_bertabs import BertAbsSummarizer
|
||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
|
||||
@@ -24,10 +24,10 @@ import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from configuration_bertabs import BertAbsConfig
|
||||
from torch import nn
|
||||
from torch.nn.init import xavier_uniform_
|
||||
|
||||
from configuration_bertabs import BertAbsConfig
|
||||
from transformers import BertConfig, BertModel, PreTrainedModel
|
||||
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ import sys
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from .utils_summarization import (
|
||||
@@ -45,7 +45,6 @@ def evaluate(args):
|
||||
generated_summaries = []
|
||||
|
||||
import nltk
|
||||
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
|
||||
@@ -3,8 +3,8 @@ from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from datasets import ClassLabel, DatasetDict, load_dataset
|
||||
|
||||
from evaluate import load
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from arguments import TokenizerTrainingArguments
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from arguments import TokenizerTrainingArguments
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
|
||||
|
||||
@@ -6,16 +6,16 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from arguments import TrainingArguments
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from arguments import TrainingArguments
|
||||
from huggingface_hub import Repository
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||
|
||||
|
||||
|
||||
@@ -5,15 +5,15 @@ import re
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from arguments import HumanEvalArguments
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from arguments import HumanEvalArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from arguments import InitializationArguments
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
|
||||
@@ -6,10 +6,9 @@ from functools import partial
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasketch import MinHash, MinHashLSH
|
||||
from dpu_utils.utils.iterators import ThreadedIterator
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
NON_ALPHA = re.compile("[^A-Za-z_0-9]")
|
||||
|
||||
@@ -9,10 +9,10 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from arguments import PreprocessingArguments
|
||||
from datasets import load_dataset
|
||||
from minhash_deduplication import deduplicate_dataset
|
||||
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
from arguments import PretokenizationArguments
|
||||
from datasets import load_dataset
|
||||
|
||||
from arguments import PretokenizationArguments
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from arguments import EvaluationArguments
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from accelerate import Accelerator
|
||||
from arguments import EvaluationArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import gym
|
||||
from mujoco_py import GlfwContext
|
||||
|
||||
from transformers import DecisionTransformerModel
|
||||
|
||||
|
||||
|
||||
@@ -229,7 +229,10 @@ class DeeBertModel(BertPreTrainedModel):
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
outputs = (
|
||||
sequence_output,
|
||||
pooled_output,
|
||||
) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
@@ -19,7 +19,6 @@ from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayExc
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaModel(DeeBertModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
@@ -36,7 +35,6 @@ class DeeRobertaModel(DeeBertModel):
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_deebert
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, require_torch_non_multi_gpu, slow
|
||||
|
||||
|
||||
@@ -45,7 +46,6 @@ class DeeBertTests(TestCasePlus):
|
||||
@slow
|
||||
@require_torch_non_multi_gpu
|
||||
def test_glue_deebert_train(self):
|
||||
|
||||
train_args = """
|
||||
--model_type roberta
|
||||
--model_name_or_path roberta-base
|
||||
|
||||
@@ -21,14 +21,14 @@ import time
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from torch import nn
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from utils import logger
|
||||
|
||||
|
||||
@@ -189,7 +189,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
|
||||
@@ -24,9 +24,9 @@ import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from distiller import Distiller
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
|
||||
@@ -5,13 +5,13 @@ import copy
|
||||
import logging
|
||||
import random
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
import joblib
|
||||
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
@@ -119,7 +119,6 @@ def recopy_gpt2(orig_model, device, max_steps):
|
||||
|
||||
|
||||
def intermittent_save(contexts, real_perps, past_perps, filename):
|
||||
|
||||
"""
|
||||
save the perplexity differences to filename
|
||||
|
||||
@@ -152,7 +151,6 @@ def collect_objective_set(
|
||||
filename="dev.jbl",
|
||||
recopy_model=recopy_gpt2,
|
||||
):
|
||||
|
||||
"""
|
||||
Collect individual IGF values from pre-trained transformer model
|
||||
max_steps samples of training data to train secondary model
|
||||
@@ -271,7 +269,6 @@ def generate_datasets(
|
||||
def train_secondary_learner(
|
||||
secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
|
||||
):
|
||||
|
||||
"""
|
||||
Train the secondary learner (igf_model)
|
||||
|
||||
|
||||
@@ -28,11 +28,9 @@ Last, a plot is generated to compare the performance of IGF to standard fine-tun
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
import joblib
|
||||
from igf.igf import (
|
||||
SecondaryLearner,
|
||||
collect_objective_set,
|
||||
@@ -43,6 +41,8 @@ from igf.igf import (
|
||||
set_seed,
|
||||
train_secondary_learner,
|
||||
)
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
|
||||
@@ -55,7 +55,6 @@ def generate_n_pairs(
|
||||
data_file="data/tokenized_stories_train_wikitext103.jbl",
|
||||
igf_data_file="igf_context_pairs.jbl",
|
||||
):
|
||||
|
||||
"""
|
||||
Collecting *n* pairs for training the secondary learner
|
||||
Args:
|
||||
|
||||
@@ -4,8 +4,6 @@ from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -16,6 +14,8 @@ from flax import jax_utils, struct, traverse_util
|
||||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
|
||||
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
|
||||
|
||||
@@ -98,7 +98,6 @@ class Args:
|
||||
|
||||
@dataclass
|
||||
class DataCollator:
|
||||
|
||||
pad_id: int
|
||||
max_length: int = 4096 # no dynamic padding on TPUs
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from datasets import load_from_disk
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from bigbird_flax import FlaxBigBirdForNaturalQuestions
|
||||
from datasets import load_from_disk
|
||||
|
||||
from transformers import BigBirdTokenizerFast
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import jsonlines
|
||||
|
||||
|
||||
DOC_STRIDE = 2048
|
||||
MAX_LENGTH = 4096
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import os
|
||||
from dataclasses import replace
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
import jax
|
||||
import wandb
|
||||
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils
|
||||
|
||||
from transformers import BigBirdTokenizerFast
|
||||
|
||||
|
||||
|
||||
@@ -32,17 +32,17 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
|
||||
@@ -20,6 +20,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from configuration_hybrid_clip import HybridCLIPConfig
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
|
||||
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
|
||||
@@ -132,7 +133,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||
|
||||
@@ -32,22 +32,22 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import torch
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.io import ImageReadMode, read_image
|
||||
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import transformers
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
|
||||
|
||||
|
||||
|
||||
@@ -28,19 +28,19 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from flax.core.frozen_dict import freeze, unfreeze
|
||||
from flax.training.common_utils import onehot, stack_forest
|
||||
from jax.experimental.maps import mesh
|
||||
from jax.experimental.pjit import pjit
|
||||
from partitions import set_partitions
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
|
||||
@@ -6,18 +6,18 @@ from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import librosa
|
||||
import numpy as np
|
||||
import optax
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
FlaxWav2Vec2ForPreTraining,
|
||||
HfArgumentParser,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import datasets
|
||||
import faiss
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import faiss
|
||||
import transformers
|
||||
from eli5_utils import (
|
||||
embed_questions_for_retrieval,
|
||||
make_qa_s2s_model,
|
||||
@@ -13,6 +11,8 @@ from eli5_utils import (
|
||||
query_es_index,
|
||||
query_qa_dense_index,
|
||||
)
|
||||
|
||||
import transformers
|
||||
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from random import choice, randint
|
||||
from time import time
|
||||
|
||||
import datasets # noqa: F401
|
||||
import faiss # noqa: F401
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -15,7 +16,6 @@ from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
import faiss # noqa: F401
|
||||
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
|
||||
@@ -27,14 +27,14 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from datasets import ClassLabel, load_dataset, load_metric
|
||||
from huggingface_hub import Repository
|
||||
from luke_utils import DataCollatorForLukeTokenClassification, is_punctuation, padding_tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from huggingface_hub import Repository
|
||||
from luke_utils import DataCollatorForLukeTokenClassification, is_punctuation, padding_tensor
|
||||
from transformers import (
|
||||
AdamW,
|
||||
LukeConfig,
|
||||
|
||||
@@ -9,9 +9,9 @@ from collections import OrderedDict
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modeling_frcnn import GeneralizedRCNN
|
||||
from processing_image import Preprocess
|
||||
|
||||
from utils import Config
|
||||
|
||||
|
||||
|
||||
@@ -169,7 +169,6 @@ def get_norm(norm, out_channels):
|
||||
|
||||
|
||||
def _create_grid_offsets(size: List[int], stride: int, offset: float, device):
|
||||
|
||||
grid_height, grid_width = size
|
||||
shifts_x = torch.arange(
|
||||
offset * stride,
|
||||
@@ -390,7 +389,6 @@ def assign_boxes_to_levels(
|
||||
canonical_box_size: int,
|
||||
canonical_level: int,
|
||||
):
|
||||
|
||||
box_sizes = torch.sqrt(torch.cat([boxes.area() for boxes in box_lists]))
|
||||
# Eqn.(1) in FPN paper
|
||||
level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
|
||||
@@ -1708,9 +1706,10 @@ class GeneralizedRCNN(nn.Module):
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
assert from_tf, (
|
||||
"We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
|
||||
.format(pretrained_model_name_or_path + ".index")
|
||||
assert (
|
||||
from_tf
|
||||
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
||||
pretrained_model_name_or_path + ".index"
|
||||
)
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
|
||||
@@ -34,14 +34,13 @@ from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
import wget
|
||||
from filelock import FileLock
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from yaml import Loader, dump, load
|
||||
|
||||
|
||||
@@ -181,7 +180,6 @@ class Config:
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
@@ -225,14 +223,13 @@ class Config:
|
||||
|
||||
# quick compare tensors
|
||||
def compare(in_tensor):
|
||||
|
||||
out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
|
||||
n1 = in_tensor.numpy()
|
||||
n2 = out_tensor.numpy()[0]
|
||||
print(n1.shape, n1[0, 0, :5])
|
||||
print(n2.shape, n2[0, 0, :5])
|
||||
assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
|
||||
f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
|
||||
f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x is False])/len(n1.flatten())*100:.4f} %"
|
||||
" element-wise mismatch"
|
||||
)
|
||||
raise Exception("tensors are all good")
|
||||
@@ -300,7 +297,6 @@ def get_from_cache(
|
||||
user_agent=None,
|
||||
local_files_only=False,
|
||||
):
|
||||
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
@@ -355,7 +351,6 @@ def get_from_cache(
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
# If the download just completed while the lock was activated.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
# Even if returning early like here, the lock will be released.
|
||||
@@ -406,7 +401,6 @@ def get_from_cache(
|
||||
|
||||
|
||||
def url_to_filename(url, etag=None):
|
||||
|
||||
url_bytes = url.encode("utf-8")
|
||||
url_hash = sha256(url_bytes)
|
||||
filename = url_hash.hexdigest()
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
import colorsys
|
||||
import io
|
||||
|
||||
import cv2
|
||||
import matplotlib as mpl
|
||||
import matplotlib.colors as mplc
|
||||
import matplotlib.figure as mplfigure
|
||||
@@ -25,7 +26,6 @@ import numpy as np
|
||||
import torch
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
import cv2
|
||||
from utils import img_tensorize
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from typing import List
|
||||
|
||||
from ltp import LTP
|
||||
|
||||
from transformers.models.bert.tokenization_bert import BertTokenizer
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -43,7 +44,6 @@ from transformers import (
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -22,7 +22,6 @@ import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class MaskedBertConfig(PretrainedConfig):
|
||||
pruning_method="topK",
|
||||
mask_init="constant",
|
||||
mask_scale=0.0,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
|
||||
@@ -649,7 +649,10 @@ class MaskedBertModel(MaskedBertPreTrainedModel):
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
outputs = (
|
||||
sequence_output,
|
||||
pooled_output,
|
||||
) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
@@ -24,12 +24,12 @@ import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
@@ -228,7 +228,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
|
||||
@@ -25,12 +25,12 @@ import timeit
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
@@ -236,7 +236,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
|
||||
@@ -264,7 +264,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
||||
|
||||
past: List[torch.Tensor] = []
|
||||
while cur_len < max_length:
|
||||
|
||||
logits, past = self._decoder_forward(input_ids, encoder_output, attention_mask, past)
|
||||
next_token_logits = logits[:, -1, :]
|
||||
|
||||
@@ -303,7 +302,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
||||
decoder_start_token_id,
|
||||
bos_token_id: Optional[int] = None,
|
||||
) -> torch.LongTensor:
|
||||
|
||||
decoder_input_ids = (
|
||||
torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
|
||||
* decoder_start_token_id
|
||||
@@ -633,7 +631,6 @@ class BARTBeamSearchGenerator(BARTGenerator):
|
||||
def beam_search(
|
||||
self, input_ids, encoder_output, attention_mask, num_beams, max_length, pad_token_id: int, eos_token_id: int
|
||||
):
|
||||
|
||||
batch_size = self.beam_scorer.batch_size
|
||||
|
||||
num_beams = self.beam_scorer.num_beams
|
||||
|
||||
@@ -5,7 +5,6 @@ Code to remove duplicate initializers to reduce ONNX model size.
|
||||
import os
|
||||
|
||||
import numpy
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
|
||||
@@ -22,12 +22,12 @@ import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import onnxruntime
|
||||
import transformers
|
||||
import torch
|
||||
from bart_onnx.generation_onnx import BARTBeamSearchGenerator
|
||||
from bart_onnx.reduce_onnx_size import remove_dup_initializers
|
||||
|
||||
import transformers
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
|
||||
|
||||
@@ -15,13 +15,13 @@
|
||||
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
from modeling_flax_performer_utils import make_fast_softmax_attention
|
||||
|
||||
from transformers.file_utils import add_start_docstrings
|
||||
from transformers.modeling_flax_utils import ACT2FN
|
||||
from transformers.models.bert.configuration_bert import BertConfig
|
||||
@@ -366,7 +366,6 @@ class FlaxPerformerModel(FlaxBertPreTrainedModel):
|
||||
|
||||
# SelfAttention needs also to replace "weight" by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
@@ -443,7 +442,6 @@ class FlaxPerformerModel(FlaxBertPreTrainedModel):
|
||||
def __call__(
|
||||
self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
|
||||
):
|
||||
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
@@ -30,11 +30,10 @@ import abc
|
||||
import functools
|
||||
from collections.abc import Iterable # pylint: disable=g-importing-member
|
||||
|
||||
import numpy as onp
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
from absl import logging
|
||||
from jax import lax, random
|
||||
|
||||
|
||||
@@ -524,7 +523,6 @@ class FastAttentionviaLowRankDecomposition(FastAttention):
|
||||
deterministic=False,
|
||||
precision=None,
|
||||
):
|
||||
|
||||
assert key.shape[:-1] == value.shape[:-1]
|
||||
assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
|
||||
if axis is None:
|
||||
|
||||
@@ -28,18 +28,18 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils
|
||||
from flax.optim import Adam
|
||||
from flax.training import common_utils
|
||||
from flax.training.common_utils import get_metrics
|
||||
from jax.nn import log_softmax
|
||||
from modeling_flax_performer import FlaxPerformerForMaskedLM
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoTokenizer,
|
||||
@@ -632,7 +632,6 @@ if __name__ == "__main__":
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
# ======================== Training ================================
|
||||
# Create sampling rng
|
||||
rng, training_rng, eval_rng = jax.random.split(rng, 3)
|
||||
|
||||
@@ -30,10 +30,10 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from torch import nn
|
||||
from tqdm import trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
from transformers.file_utils import cached_path
|
||||
|
||||
@@ -345,7 +345,7 @@ def full_text_generation(
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
repetition_penalty=1.0,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
classifier, class_id = get_classifier(discrim, class_label, device)
|
||||
|
||||
@@ -463,7 +463,6 @@ def generate_text_pplm(
|
||||
unpert_discrim_loss = 0
|
||||
loss_in_time = []
|
||||
for i in trange(length, ascii=True):
|
||||
|
||||
# Get past/probs for current output, except for last word
|
||||
# Note that GPT takes 2 inputs: past + current_token
|
||||
|
||||
@@ -547,7 +546,6 @@ def generate_text_pplm(
|
||||
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = (pert_probs**gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
|
||||
@@ -26,12 +26,12 @@ import torch
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as data
|
||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from torch import nn
|
||||
from torchtext import data as torchtext_data
|
||||
from torchtext import datasets
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
|
||||
|
||||
@@ -21,19 +21,19 @@ import timeit
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from absl import logging as absl_logging
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import pycuda.autoinit # noqa: F401
|
||||
import pycuda.driver as cuda
|
||||
import tensorrt as trt
|
||||
import transformers
|
||||
import torch
|
||||
from absl import logging as absl_logging
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import DataLoader
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, EvalPrediction, default_data_collator, set_seed
|
||||
from transformers.trainer_pt_utils import nested_concat, nested_truncate
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
@@ -395,7 +395,6 @@ logger.info("Loading ONNX model %s for evaluation", args.onnx_model_path)
|
||||
with open(engine_name, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(
|
||||
f.read()
|
||||
) as engine, engine.create_execution_context() as context:
|
||||
|
||||
# setup for TRT inferrence
|
||||
for i in range(len(input_names)):
|
||||
context.set_binding_shape(i, INPUT_SHAPE)
|
||||
@@ -427,7 +426,6 @@ with open(engine_name, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.d
|
||||
|
||||
all_preds = None
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
|
||||
outputs, infer_time = model_infer(batch, context, d_inputs, h_output0, h_output1, d_output0, d_output1, stream)
|
||||
total_time += infer_time
|
||||
niter += 1
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
|
||||
@@ -16,10 +16,9 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
import pytorch_quantization
|
||||
import pytorch_quantization.nn as quant_nn
|
||||
import torch
|
||||
from pytorch_quantization import calib
|
||||
from pytorch_quantization.tensor_quant import QuantDescriptor
|
||||
|
||||
|
||||
@@ -26,11 +26,12 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import quant_trainer
|
||||
import transformers
|
||||
from datasets import load_dataset, load_metric
|
||||
from trainer_quant_qa import QuestionAnsweringTrainer
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
@@ -46,7 +47,6 @@ from transformers import (
|
||||
from transformers.trainer_utils import SchedulerType, get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
||||
@@ -20,10 +20,10 @@ A subclass of `Trainer` specific to Question-Answering tasks
|
||||
import logging
|
||||
import os
|
||||
|
||||
import quant_trainer
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import quant_trainer
|
||||
from transformers import Trainer, is_torch_tpu_available
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils_rag import save_json
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import random
|
||||
|
||||
import ray
|
||||
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
@@ -166,7 +167,6 @@ class RagRayDistributedRetriever(RagRetriever):
|
||||
)
|
||||
|
||||
def re_load(self):
|
||||
|
||||
logger.info("re-loading the new dataset with embeddings")
|
||||
# access from the training loop
|
||||
|
||||
|
||||
@@ -252,14 +252,12 @@ class GenerativeQAModule(BaseTransformer):
|
||||
raise NotImplementedError("pad not implemented")
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
|
||||
global isEmUpdateBusy # use to check whether the entire embedding update process is finished or not
|
||||
global isAddIndexBusy # use to check whether the entire indexing process is finished or not
|
||||
global processes # use to keep threads embedding update processes
|
||||
global threadHandle_index # use to keep thread in embedding indexing processes
|
||||
|
||||
if (self.trainer.global_rank == 0) and (self.custom_config.end2end):
|
||||
|
||||
if (not batch_idx == 0) and (batch_idx % self.custom_config.indexing_freq == 0):
|
||||
free_gpu_list = []
|
||||
nvmlInit()
|
||||
@@ -282,7 +280,6 @@ class GenerativeQAModule(BaseTransformer):
|
||||
has_free_gpus = False
|
||||
|
||||
if (not isEmUpdateBusy) and has_free_gpus:
|
||||
|
||||
model_copy = type(self.model.rag.ctx_encoder)(
|
||||
self.config_dpr
|
||||
) # get a new instance #this will be load in the CPU
|
||||
@@ -336,10 +333,8 @@ class GenerativeQAModule(BaseTransformer):
|
||||
|
||||
# check when index building has started
|
||||
if isAddIndexBusy:
|
||||
|
||||
# check still the index_building process is happening
|
||||
if not threadHandle_index.is_alive():
|
||||
|
||||
logger.info("Merging the dataset shards")
|
||||
saved_dataset_shards = []
|
||||
|
||||
@@ -494,7 +489,6 @@ class GenerativeQAModule(BaseTransformer):
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
if self.custom_config.end2end:
|
||||
|
||||
modified_state_dict = self.model.state_dict()
|
||||
for key in self.model.state_dict().keys():
|
||||
if key.split(".")[1] == "ctx_encoder":
|
||||
@@ -803,7 +797,6 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
multiprocessing.set_start_method("spawn")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
from functools import partial
|
||||
from glob import glob
|
||||
|
||||
import faiss
|
||||
from datasets import Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
|
||||
|
||||
import faiss
|
||||
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ def split_documents(documents):
|
||||
|
||||
|
||||
def embed_update(ctx_encoder, total_processes, device, process_num, shard_dir, csv_path):
|
||||
|
||||
kb_dataset = load_dataset(
|
||||
"csv", data_files=[csv_path], split="train", delimiter="\t", column_names=["title", "text"]
|
||||
)
|
||||
|
||||
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
**config_kwargs,
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
@@ -365,7 +365,7 @@ def generic_train(
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
**extra_train_kwargs,
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import faiss
|
||||
import torch
|
||||
from datasets import Features, Sequence, Value, load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast, HfArgumentParser
|
||||
|
||||
|
||||
@@ -49,7 +49,6 @@ def main(
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
@@ -5,6 +5,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import finetune_rag
|
||||
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils_rag import save_json
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ def consolidate(
|
||||
generator_tokenizer_name_or_path: str = None,
|
||||
question_encoder_tokenizer_name_or_path: str = None,
|
||||
):
|
||||
|
||||
if config_name_or_path is None:
|
||||
config_name_or_path = "facebook/rag-token-base" if model_type == "rag_token" else "facebook/rag-sequence-base"
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import random
|
||||
|
||||
import ray
|
||||
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
**config_kwargs,
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
@@ -356,7 +356,7 @@ def generic_train(
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
**extra_train_kwargs,
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.integrations import is_ray_available
|
||||
|
||||
@@ -6,10 +6,10 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import faiss
|
||||
import torch
|
||||
from datasets import Features, Sequence, Value, load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
DPRContextEncoder,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
@@ -56,7 +56,6 @@ def main(
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
@@ -36,7 +36,6 @@ def log_results(result: Dataset, args: Dict[str, str]):
|
||||
target_file = f"log_{dataset_id}_targets.txt"
|
||||
|
||||
with open(pred_file, "w") as p, open(target_file, "w") as t:
|
||||
|
||||
# mapping function to write output
|
||||
def write_to_file(batch, i):
|
||||
p.write(f"{i}" + "\n")
|
||||
|
||||
@@ -25,12 +25,12 @@ import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset, load_metric
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@@ -717,7 +717,6 @@ def main():
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
# use last checkpoint if exist
|
||||
if last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
|
||||
@@ -622,7 +622,6 @@ def main():
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
# use last checkpoint if exist
|
||||
if last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
|
||||
@@ -23,12 +23,12 @@ import shutil
|
||||
from typing import List, Optional
|
||||
|
||||
import datasets
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from finetuning import finetune
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from finetuning import finetune
|
||||
from transformers import AutoConfig, set_seed
|
||||
from transformers.trainer_utils import IntervalStrategy
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ from unittest.mock import patch
|
||||
import pytorch_lightning as pl
|
||||
import timeout_decorator
|
||||
import torch
|
||||
|
||||
from distillation import SummarizationDistiller, distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
|
||||
from transformers import MarianMTModel
|
||||
from transformers.file_utils import cached_path
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
|
||||
|
||||
@@ -2,6 +2,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from make_student import create_student_by_copying_alternating_layers
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
@@ -5,18 +5,18 @@ import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import lightning_base
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from huggingface_hub import list_models
|
||||
from parameterized import parameterized
|
||||
from run_eval import generate_summaries_or_translations
|
||||
from torch import nn
|
||||
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
||||
from utils import label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
@@ -98,7 +98,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu(self):
|
||||
|
||||
updates = dict(
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
|
||||
@@ -9,11 +9,11 @@ from typing import List # noqa: F401
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
|
||||
from torch import nn
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
|
||||
|
||||
@@ -13,10 +13,10 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
|
||||
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
**config_kwargs,
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
@@ -346,7 +346,7 @@ def generic_train(
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
**extra_train_kwargs,
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def create_student_by_copying_alternating_layers(
|
||||
copy_first_teacher_layers=False,
|
||||
e_layers_to_copy=None,
|
||||
d_layers_to_copy=None,
|
||||
**extra_config_kwargs
|
||||
**extra_config_kwargs,
|
||||
) -> Tuple[PreTrainedModel, List[int], List[int]]:
|
||||
"""Make a student by copying alternating layers from a teacher, save it to save_path.
|
||||
Args:
|
||||
@@ -107,7 +107,6 @@ def create_student_by_copying_alternating_layers(
|
||||
AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
|
||||
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
|
||||
else:
|
||||
|
||||
assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
|
||||
init_kwargs = teacher.config.to_diff_dict()
|
||||
|
||||
|
||||
@@ -15,10 +15,10 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from sacrebleu import corpus_bleu
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
@@ -115,7 +115,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
prefix="",
|
||||
**dataset_kwargs
|
||||
**dataset_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
|
||||
@@ -32,9 +32,10 @@ import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from filelock import FileLock
|
||||
from wikisql_utils import _TYPE_CONVERTER, retrieve_wikisql_query_answer_tapas
|
||||
|
||||
import transformers
|
||||
from filelock import FileLock
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
BartForConditionalGeneration,
|
||||
@@ -48,7 +49,6 @@ from transformers import (
|
||||
from transformers.file_utils import is_offline_mode
|
||||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
||||
from transformers.utils import check_min_version
|
||||
from wikisql_utils import _TYPE_CONVERTER, retrieve_wikisql_query_answer_tapas
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
||||
@@ -31,9 +31,9 @@ import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from filelock import FileLock
|
||||
|
||||
import transformers
|
||||
from filelock import FileLock
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
BartForConditionalGeneration,
|
||||
|
||||
@@ -9,9 +9,9 @@ from collections import OrderedDict
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modeling_frcnn import GeneralizedRCNN
|
||||
from processing_image import Preprocess
|
||||
|
||||
from utils import Config
|
||||
|
||||
|
||||
|
||||
@@ -169,7 +169,6 @@ def get_norm(norm, out_channels):
|
||||
|
||||
|
||||
def _create_grid_offsets(size: List[int], stride: int, offset: float, device):
|
||||
|
||||
grid_height, grid_width = size
|
||||
shifts_x = torch.arange(
|
||||
offset * stride,
|
||||
@@ -390,7 +389,6 @@ def assign_boxes_to_levels(
|
||||
canonical_box_size: int,
|
||||
canonical_level: int,
|
||||
):
|
||||
|
||||
box_sizes = torch.sqrt(torch.cat([boxes.area() for boxes in box_lists]))
|
||||
# Eqn.(1) in FPN paper
|
||||
level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
|
||||
@@ -1708,9 +1706,10 @@ class GeneralizedRCNN(nn.Module):
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
assert from_tf, (
|
||||
"We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
|
||||
.format(pretrained_model_name_or_path + ".index")
|
||||
assert (
|
||||
from_tf
|
||||
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
||||
pretrained_model_name_or_path + ".index"
|
||||
)
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
|
||||
@@ -34,14 +34,13 @@ from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
import wget
|
||||
from filelock import FileLock
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from yaml import Loader, dump, load
|
||||
|
||||
|
||||
@@ -181,7 +180,6 @@ class Config:
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
@@ -225,14 +223,13 @@ class Config:
|
||||
|
||||
# quick compare tensors
|
||||
def compare(in_tensor):
|
||||
|
||||
out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
|
||||
n1 = in_tensor.numpy()
|
||||
n2 = out_tensor.numpy()[0]
|
||||
print(n1.shape, n1[0, 0, :5])
|
||||
print(n2.shape, n2[0, 0, :5])
|
||||
assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
|
||||
f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
|
||||
f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x is False])/len(n1.flatten())*100:.4f} %"
|
||||
" element-wise mismatch"
|
||||
)
|
||||
raise Exception("tensors are all good")
|
||||
@@ -300,7 +297,6 @@ def get_from_cache(
|
||||
user_agent=None,
|
||||
local_files_only=False,
|
||||
):
|
||||
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
@@ -355,7 +351,6 @@ def get_from_cache(
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
# If the download just completed while the lock was activated.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
# Even if returning early like here, the lock will be released.
|
||||
@@ -406,7 +401,6 @@ def get_from_cache(
|
||||
|
||||
|
||||
def url_to_filename(url, etag=None):
|
||||
|
||||
url_bytes = url.encode("utf-8")
|
||||
url_hash = sha256(url_bytes)
|
||||
filename = url_hash.hexdigest()
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
import colorsys
|
||||
import io
|
||||
|
||||
import cv2
|
||||
import matplotlib as mpl
|
||||
import matplotlib.colors as mplc
|
||||
import matplotlib.figure as mplfigure
|
||||
@@ -25,7 +26,6 @@ import numpy as np
|
||||
import torch
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
import cv2
|
||||
from utils import img_tensorize
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
|
||||
import imageio
|
||||
import wandb
|
||||
from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
|
||||
from loaders import load_vqgan
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
|
||||
from transformers import CLIPModel, CLIPTokenizerFast
|
||||
from utils import get_device, get_timestamp, show_pil
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from taming.models.vqgan import VQModel
|
||||
|
||||
@@ -176,7 +176,6 @@ class Wav2Vec2Aligner:
|
||||
out_align.write(str(seg) + "\n")
|
||||
|
||||
def align_data(self, wav_dir, text_file, output_dir):
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
|
||||
@@ -7,13 +7,13 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import datasets
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from lang_trans import arabic
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
import librosa
|
||||
from lang_trans import arabic
|
||||
from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
|
||||
@@ -4,12 +4,12 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
import librosa
|
||||
from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user