Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
@@ -25,17 +25,7 @@ import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
POOLING_BREAKDOWN = {
|
||||
1: (1, 1),
|
||||
2: (2, 1),
|
||||
3: (3, 1),
|
||||
4: (2, 2),
|
||||
5: (5, 1),
|
||||
6: (3, 2),
|
||||
7: (7, 1),
|
||||
8: (4, 2),
|
||||
9: (3, 3)
|
||||
}
|
||||
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module):
|
||||
return out # BxNx2048
|
||||
|
||||
|
||||
|
||||
class JsonlDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
|
||||
self.data = [json.loads(l) for l in open(data_path)]
|
||||
@@ -72,7 +61,7 @@ class JsonlDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
|
||||
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
|
||||
sentence = sentence[:self.max_seq_length]
|
||||
sentence = sentence[: self.max_seq_length]
|
||||
|
||||
label = torch.zeros(self.n_classes)
|
||||
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
|
||||
@@ -80,8 +69,13 @@ class JsonlDataset(Dataset):
|
||||
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
|
||||
image = self.transforms(image)
|
||||
|
||||
return {"image_start_token": start_token, "image_end_token": end_token,
|
||||
"sentence": sentence, "image": image, "label": label}
|
||||
return {
|
||||
"image_start_token": start_token,
|
||||
"image_end_token": end_token,
|
||||
"sentence": sentence,
|
||||
"image": image,
|
||||
"label": label,
|
||||
}
|
||||
|
||||
def get_label_frequencies(self):
|
||||
label_freqs = Counter()
|
||||
@@ -110,10 +104,31 @@ def collate_fn(batch):
|
||||
|
||||
|
||||
def get_mmimdb_labels():
|
||||
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance',
|
||||
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure',
|
||||
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music',
|
||||
'Musical', 'Animation', 'Biography', 'Film-Noir']
|
||||
return [
|
||||
"Crime",
|
||||
"Drama",
|
||||
"Thriller",
|
||||
"Action",
|
||||
"Comedy",
|
||||
"Romance",
|
||||
"Documentary",
|
||||
"Short",
|
||||
"Mystery",
|
||||
"History",
|
||||
"Family",
|
||||
"Adventure",
|
||||
"Fantasy",
|
||||
"Sci-Fi",
|
||||
"Western",
|
||||
"Horror",
|
||||
"Sport",
|
||||
"War",
|
||||
"Music",
|
||||
"Musical",
|
||||
"Animation",
|
||||
"Biography",
|
||||
"Film-Noir",
|
||||
]
|
||||
|
||||
|
||||
def get_image_transforms():
|
||||
@@ -122,9 +137,6 @@ def get_image_transforms():
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.46777044, 0.44531429, 0.40661017],
|
||||
std=[0.12221994, 0.12145835, 0.14380469],
|
||||
),
|
||||
transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user