Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -21,19 +21,19 @@ import timeit
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from absl import logging as absl_logging
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import pycuda.autoinit # noqa: F401
|
||||
import pycuda.driver as cuda
|
||||
import tensorrt as trt
|
||||
import transformers
|
||||
import torch
|
||||
from absl import logging as absl_logging
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import DataLoader
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, EvalPrediction, default_data_collator, set_seed
|
||||
from transformers.trainer_pt_utils import nested_concat, nested_truncate
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
@@ -395,7 +395,6 @@ logger.info("Loading ONNX model %s for evaluation", args.onnx_model_path)
|
||||
with open(engine_name, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(
|
||||
f.read()
|
||||
) as engine, engine.create_execution_context() as context:
|
||||
|
||||
# setup for TRT inferrence
|
||||
for i in range(len(input_names)):
|
||||
context.set_binding_shape(i, INPUT_SHAPE)
|
||||
@@ -427,7 +426,6 @@ with open(engine_name, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.d
|
||||
|
||||
all_preds = None
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
|
||||
outputs, infer_time = model_infer(batch, context, d_inputs, h_output0, h_output1, d_output0, d_output1, stream)
|
||||
total_time += infer_time
|
||||
niter += 1
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
|
||||
@@ -16,10 +16,9 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
import pytorch_quantization
|
||||
import pytorch_quantization.nn as quant_nn
|
||||
import torch
|
||||
from pytorch_quantization import calib
|
||||
from pytorch_quantization.tensor_quant import QuantDescriptor
|
||||
|
||||
|
||||
@@ -26,11 +26,12 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import quant_trainer
|
||||
import transformers
|
||||
from datasets import load_dataset, load_metric
|
||||
from trainer_quant_qa import QuestionAnsweringTrainer
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
@@ -46,7 +47,6 @@ from transformers import (
|
||||
from transformers.trainer_utils import SchedulerType, get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
||||
@@ -20,10 +20,10 @@ A subclass of `Trainer` specific to Question-Answering tasks
|
||||
import logging
|
||||
import os
|
||||
|
||||
import quant_trainer
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import quant_trainer
|
||||
from transformers import Trainer, is_torch_tpu_available
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
|
||||
|
||||
Reference in New Issue
Block a user