update formating - make flake8 happy
This commit is contained in:
@@ -19,7 +19,6 @@ from __future__ import absolute_import, division, print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -56,18 +55,14 @@ from transformers import (
|
|||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from transformers import glue_compute_metrics as compute_metrics
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
except:
|
except ImportError:
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = sum(
|
||||||
@@ -374,7 +369,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
# Required parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_dir",
|
"--data_dir",
|
||||||
default=None,
|
default=None,
|
||||||
@@ -411,7 +406,7 @@ def main():
|
|||||||
help="The output directory where the model predictions and checkpoints will be written.",
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
)
|
)
|
||||||
|
|
||||||
## Other parameters
|
# Other parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import csv
|
import csv
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
class InputExample(object):
|
class InputExample(object):
|
||||||
@@ -118,7 +117,5 @@ class DataProcessor(object):
|
|||||||
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
|
||||||
line = list(unicode(cell, "utf-8") for cell in line)
|
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|||||||
Reference in New Issue
Block a user