Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
Aymeric Augustin
2019-12-21 15:46:46 +01:00
parent 63e3827c6b
commit fa84ae26d6
200 changed files with 17452 additions and 12594 deletions

View File

@@ -25,17 +25,7 @@ import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
POOLING_BREAKDOWN = {
1: (1, 1),
2: (2, 1),
3: (3, 1),
4: (2, 2),
5: (5, 1),
6: (3, 2),
7: (7, 1),
8: (4, 2),
9: (3, 3)
}
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
class ImageEncoder(nn.Module):
@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module):
return out # BxNx2048
class JsonlDataset(Dataset):
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
self.data = [json.loads(l) for l in open(data_path)]
@@ -72,7 +61,7 @@ class JsonlDataset(Dataset):
def __getitem__(self, index):
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
sentence = sentence[:self.max_seq_length]
sentence = sentence[: self.max_seq_length]
label = torch.zeros(self.n_classes)
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
@@ -80,8 +69,13 @@ class JsonlDataset(Dataset):
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
image = self.transforms(image)
return {"image_start_token": start_token, "image_end_token": end_token,
"sentence": sentence, "image": image, "label": label}
return {
"image_start_token": start_token,
"image_end_token": end_token,
"sentence": sentence,
"image": image,
"label": label,
}
def get_label_frequencies(self):
label_freqs = Counter()
@@ -110,10 +104,31 @@ def collate_fn(batch):
def get_mmimdb_labels():
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance',
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure',
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music',
'Musical', 'Animation', 'Biography', 'Film-Noir']
return [
"Crime",
"Drama",
"Thriller",
"Action",
"Comedy",
"Romance",
"Documentary",
"Short",
"Mystery",
"History",
"Family",
"Adventure",
"Fantasy",
"Sci-Fi",
"Western",
"Horror",
"Sport",
"War",
"Music",
"Musical",
"Animation",
"Biography",
"Film-Noir",
]
def get_image_transforms():
@@ -122,9 +137,6 @@ def get_image_transforms():
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.46777044, 0.44531429, 0.40661017],
std=[0.12221994, 0.12145835, 0.14380469],
),
transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
]
)