Remove nested lxmert (#9440)
This commit is contained in:
@@ -1,5 +0,0 @@
|
|||||||
# LXMERT DEMO
|
|
||||||
|
|
||||||
1. make a virtualenv: ``virtualenv venv`` and activate ``source venv/bin/activate``
|
|
||||||
2. install reqs: ``pip install -r ./requirements.txt``
|
|
||||||
3. usage is as shown in demo.ipynb
|
|
||||||
File diff suppressed because one or more lines are too long
@@ -1,149 +0,0 @@
|
|||||||
import getopt
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
# import numpy as np
|
|
||||||
import sys
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from modeling_frcnn import GeneralizedRCNN
|
|
||||||
from processing_image import Preprocess
|
|
||||||
from utils import Config
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
USAGE:
|
|
||||||
``python extracting_data.py -i <img_dir> -o <dataset_file>.datasets <batch_size>``
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
TEST = False
|
|
||||||
CONFIG = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
|
|
||||||
DEFAULT_SCHEMA = datasets.Features(
|
|
||||||
OrderedDict(
|
|
||||||
{
|
|
||||||
"attr_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
|
||||||
"attr_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
|
||||||
"boxes": datasets.Array2D((CONFIG.MAX_DETECTIONS, 4), dtype="float32"),
|
|
||||||
"img_id": datasets.Value("int32"),
|
|
||||||
"obj_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
|
||||||
"obj_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
|
||||||
"roi_features": datasets.Array2D((CONFIG.MAX_DETECTIONS, 2048), dtype="float32"),
|
|
||||||
"sizes": datasets.Sequence(length=2, feature=datasets.Value("float32")),
|
|
||||||
"preds_per_image": datasets.Value(dtype="int32"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Extract:
|
|
||||||
def __init__(self, argv=sys.argv[1:]):
|
|
||||||
inputdir = None
|
|
||||||
outputfile = None
|
|
||||||
subset_list = None
|
|
||||||
batch_size = 1
|
|
||||||
opts, args = getopt.getopt(argv, "i:o:b:s", ["inputdir=", "outfile=", "batch_size=", "subset_list="])
|
|
||||||
for opt, arg in opts:
|
|
||||||
if opt in ("-i", "--inputdir"):
|
|
||||||
inputdir = arg
|
|
||||||
elif opt in ("-o", "--outfile"):
|
|
||||||
outputfile = arg
|
|
||||||
elif opt in ("-b", "--batch_size"):
|
|
||||||
batch_size = int(arg)
|
|
||||||
elif opt in ("-s", "--subset_list"):
|
|
||||||
subset_list = arg
|
|
||||||
|
|
||||||
assert inputdir is not None # and os.path.isdir(inputdir), f"{inputdir}"
|
|
||||||
assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
|
|
||||||
if subset_list is not None:
|
|
||||||
with open(os.path.realpath(subset_list)) as f:
|
|
||||||
self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
|
|
||||||
else:
|
|
||||||
self.subset_list = None
|
|
||||||
|
|
||||||
self.config = CONFIG
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.config.model.device = "cuda"
|
|
||||||
self.inputdir = os.path.realpath(inputdir)
|
|
||||||
self.outputfile = os.path.realpath(outputfile)
|
|
||||||
self.preprocess = Preprocess(self.config)
|
|
||||||
self.model = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.config)
|
|
||||||
self.batch = batch_size if batch_size != 0 else 1
|
|
||||||
self.schema = DEFAULT_SCHEMA
|
|
||||||
|
|
||||||
def _vqa_file_split(self, file):
|
|
||||||
img_id = int(file.split(".")[0].split("_")[-1])
|
|
||||||
filepath = os.path.join(self.inputdir, file)
|
|
||||||
return (img_id, filepath)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def file_generator(self):
|
|
||||||
batch = []
|
|
||||||
for i, file in enumerate(os.listdir(self.inputdir)):
|
|
||||||
if self.subset_list is not None and i not in self.subset_list:
|
|
||||||
continue
|
|
||||||
batch.append(self._vqa_file_split(file))
|
|
||||||
if len(batch) == self.batch:
|
|
||||||
temp = batch
|
|
||||||
batch = []
|
|
||||||
yield list(map(list, zip(*temp)))
|
|
||||||
|
|
||||||
for i in range(1):
|
|
||||||
yield list(map(list, zip(*batch)))
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
# make writer
|
|
||||||
if not TEST:
|
|
||||||
writer = datasets.ArrowWriter(features=self.schema, path=self.outputfile)
|
|
||||||
# do file generator
|
|
||||||
for i, (img_ids, filepaths) in enumerate(self.file_generator):
|
|
||||||
images, sizes, scales_yx = self.preprocess(filepaths)
|
|
||||||
output_dict = self.model(
|
|
||||||
images,
|
|
||||||
sizes,
|
|
||||||
scales_yx=scales_yx,
|
|
||||||
padding="max_detections",
|
|
||||||
max_detections=self.config.MAX_DETECTIONS,
|
|
||||||
pad_value=0,
|
|
||||||
return_tensors="np",
|
|
||||||
location="cpu",
|
|
||||||
)
|
|
||||||
output_dict["boxes"] = output_dict.pop("normalized_boxes")
|
|
||||||
if not TEST:
|
|
||||||
output_dict["img_id"] = np.array(img_ids)
|
|
||||||
batch = self.schema.encode_batch(output_dict)
|
|
||||||
writer.write_batch(batch)
|
|
||||||
if TEST:
|
|
||||||
break
|
|
||||||
# finalizer the writer
|
|
||||||
if not TEST:
|
|
||||||
num_examples, num_bytes = writer.finalize()
|
|
||||||
print(f"Success! You wrote {num_examples} entry(s) and {num_bytes >> 20} mb")
|
|
||||||
|
|
||||||
|
|
||||||
def tryload(stream):
|
|
||||||
try:
|
|
||||||
data = json.load(stream)
|
|
||||||
try:
|
|
||||||
data = list(data.keys())
|
|
||||||
except Exception:
|
|
||||||
data = [d["img_id"] for d in data]
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
data = eval(stream.read())
|
|
||||||
except Exception:
|
|
||||||
data = stream.read().split("\n")
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
extract = Extract(sys.argv[1:])
|
|
||||||
extract()
|
|
||||||
if not TEST:
|
|
||||||
dataset = datasets.Dataset.from_file(extract.outputfile)
|
|
||||||
# wala!
|
|
||||||
# print(np.array(dataset[0:2]["roi_features"]).shape)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,147 +0,0 @@
|
|||||||
"""
|
|
||||||
coding=utf-8
|
|
||||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
|
||||||
Adapted From Facebook Inc, Detectron2
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.import copy
|
|
||||||
"""
|
|
||||||
import sys
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from utils import img_tensorize
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeShortestEdge:
|
|
||||||
def __init__(self, short_edge_length, max_size=sys.maxsize):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
short_edge_length (list[min, max])
|
|
||||||
max_size (int): maximum allowed longest edge length.
|
|
||||||
"""
|
|
||||||
self.interp_method = "bilinear"
|
|
||||||
self.max_size = max_size
|
|
||||||
self.short_edge_length = short_edge_length
|
|
||||||
|
|
||||||
def __call__(self, imgs):
|
|
||||||
img_augs = []
|
|
||||||
for img in imgs:
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
# later: provide list and randomly choose index for resize
|
|
||||||
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
|
|
||||||
if size == 0:
|
|
||||||
return img
|
|
||||||
scale = size * 1.0 / min(h, w)
|
|
||||||
if h < w:
|
|
||||||
newh, neww = size, scale * w
|
|
||||||
else:
|
|
||||||
newh, neww = scale * h, size
|
|
||||||
if max(newh, neww) > self.max_size:
|
|
||||||
scale = self.max_size * 1.0 / max(newh, neww)
|
|
||||||
newh = newh * scale
|
|
||||||
neww = neww * scale
|
|
||||||
neww = int(neww + 0.5)
|
|
||||||
newh = int(newh + 0.5)
|
|
||||||
|
|
||||||
if img.dtype == np.uint8:
|
|
||||||
pil_image = Image.fromarray(img)
|
|
||||||
pil_image = pil_image.resize((neww, newh), Image.BILINEAR)
|
|
||||||
img = np.asarray(pil_image)
|
|
||||||
else:
|
|
||||||
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
|
|
||||||
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
|
|
||||||
img_augs.append(img)
|
|
||||||
|
|
||||||
return img_augs
|
|
||||||
|
|
||||||
|
|
||||||
class Preprocess:
|
|
||||||
def __init__(self, cfg):
|
|
||||||
self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
|
|
||||||
self.input_format = cfg.INPUT.FORMAT
|
|
||||||
self.size_divisibility = cfg.SIZE_DIVISIBILITY
|
|
||||||
self.pad_value = cfg.PAD_VALUE
|
|
||||||
self.max_image_size = cfg.INPUT.MAX_SIZE_TEST
|
|
||||||
self.device = cfg.MODEL.DEVICE
|
|
||||||
self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
|
|
||||||
self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
|
|
||||||
self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
|
|
||||||
|
|
||||||
def pad(self, images):
|
|
||||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
|
||||||
image_sizes = [im.shape[-2:] for im in images]
|
|
||||||
images = [
|
|
||||||
F.pad(
|
|
||||||
im,
|
|
||||||
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
|
|
||||||
value=self.pad_value,
|
|
||||||
)
|
|
||||||
for size, im in zip(image_sizes, images)
|
|
||||||
]
|
|
||||||
|
|
||||||
return torch.stack(images), torch.tensor(image_sizes)
|
|
||||||
|
|
||||||
def __call__(self, images, single_image=False):
|
|
||||||
with torch.no_grad():
|
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
if single_image:
|
|
||||||
assert len(images) == 1
|
|
||||||
for i in range(len(images)):
|
|
||||||
if isinstance(images[i], torch.Tensor):
|
|
||||||
images.insert(i, images.pop(i).to(self.device).float())
|
|
||||||
elif not isinstance(images[i], torch.Tensor):
|
|
||||||
images.insert(
|
|
||||||
i,
|
|
||||||
torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format))
|
|
||||||
.to(self.device)
|
|
||||||
.float(),
|
|
||||||
)
|
|
||||||
# resize smallest edge
|
|
||||||
raw_sizes = torch.tensor([im.shape[:2] for im in images])
|
|
||||||
images = self.aug(images)
|
|
||||||
# transpose images and convert to torch tensors
|
|
||||||
# images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images]
|
|
||||||
# now normalize before pad to avoid useless arithmetic
|
|
||||||
images = [self.normalizer(x) for x in images]
|
|
||||||
# now pad them to do the following operations
|
|
||||||
images, sizes = self.pad(images)
|
|
||||||
# Normalize
|
|
||||||
|
|
||||||
if self.size_divisibility > 0:
|
|
||||||
raise NotImplementedError()
|
|
||||||
# pad
|
|
||||||
scales_yx = torch.true_divide(raw_sizes, sizes)
|
|
||||||
if single_image:
|
|
||||||
return images[0], sizes[0], scales_yx[0]
|
|
||||||
else:
|
|
||||||
return images, sizes, scales_yx
|
|
||||||
|
|
||||||
|
|
||||||
def _scale_box(boxes, scale_yx):
|
|
||||||
boxes[:, 0::2] *= scale_yx[:, 1]
|
|
||||||
boxes[:, 1::2] *= scale_yx[:, 0]
|
|
||||||
return boxes
|
|
||||||
|
|
||||||
|
|
||||||
def _clip_box(tensor, box_size: Tuple[int, int]):
|
|
||||||
assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
|
|
||||||
h, w = box_size
|
|
||||||
tensor[:, 0].clamp_(min=0, max=w)
|
|
||||||
tensor[:, 1].clamp_(min=0, max=h)
|
|
||||||
tensor[:, 2].clamp_(min=0, max=w)
|
|
||||||
tensor[:, 3].clamp_(min=0, max=h)
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
appdirs==1.4.3
|
|
||||||
argon2-cffi==20.1.0
|
|
||||||
async-generator==1.10
|
|
||||||
attrs==20.2.0
|
|
||||||
backcall==0.2.0
|
|
||||||
bleach==3.1.5
|
|
||||||
CacheControl==0.12.6
|
|
||||||
certifi==2020.6.20
|
|
||||||
cffi==1.14.2
|
|
||||||
chardet==3.0.4
|
|
||||||
click==7.1.2
|
|
||||||
colorama==0.4.3
|
|
||||||
contextlib2==0.6.0
|
|
||||||
cycler==0.10.0
|
|
||||||
datasets==1.0.0
|
|
||||||
decorator==4.4.2
|
|
||||||
defusedxml==0.6.0
|
|
||||||
dill==0.3.2
|
|
||||||
distlib==0.3.0
|
|
||||||
distro==1.4.0
|
|
||||||
entrypoints==0.3
|
|
||||||
filelock==3.0.12
|
|
||||||
future==0.18.2
|
|
||||||
html5lib==1.0.1
|
|
||||||
idna==2.8
|
|
||||||
ipaddr==2.2.0
|
|
||||||
ipykernel==5.3.4
|
|
||||||
ipython
|
|
||||||
ipython-genutils==0.2.0
|
|
||||||
ipywidgets==7.5.1
|
|
||||||
jedi==0.17.2
|
|
||||||
Jinja2==2.11.2
|
|
||||||
joblib==0.16.0
|
|
||||||
jsonschema==3.2.0
|
|
||||||
jupyter==1.0.0
|
|
||||||
jupyter-client==6.1.7
|
|
||||||
jupyter-console==6.2.0
|
|
||||||
jupyter-core==4.6.3
|
|
||||||
jupyterlab-pygments==0.1.1
|
|
||||||
kiwisolver==1.2.0
|
|
||||||
lockfile==0.12.2
|
|
||||||
MarkupSafe==1.1.1
|
|
||||||
matplotlib==3.3.1
|
|
||||||
mistune==0.8.4
|
|
||||||
msgpack==0.6.2
|
|
||||||
nbclient==0.5.0
|
|
||||||
nbconvert==6.0.1
|
|
||||||
nbformat==5.0.7
|
|
||||||
nest-asyncio==1.4.0
|
|
||||||
notebook==6.1.5
|
|
||||||
numpy==1.19.2
|
|
||||||
opencv-python==4.4.0.42
|
|
||||||
packaging==20.3
|
|
||||||
pandas==1.1.2
|
|
||||||
pandocfilters==1.4.2
|
|
||||||
parso==0.7.1
|
|
||||||
pep517==0.8.2
|
|
||||||
pexpect==4.8.0
|
|
||||||
pickleshare==0.7.5
|
|
||||||
Pillow==7.2.0
|
|
||||||
progress==1.5
|
|
||||||
prometheus-client==0.8.0
|
|
||||||
prompt-toolkit==3.0.7
|
|
||||||
ptyprocess==0.6.0
|
|
||||||
pyaml==20.4.0
|
|
||||||
pyarrow==1.0.1
|
|
||||||
pycparser==2.20
|
|
||||||
Pygments==2.6.1
|
|
||||||
pyparsing==2.4.6
|
|
||||||
pyrsistent==0.16.0
|
|
||||||
python-dateutil==2.8.1
|
|
||||||
pytoml==0.1.21
|
|
||||||
pytz==2020.1
|
|
||||||
PyYAML==5.3.1
|
|
||||||
pyzmq==19.0.2
|
|
||||||
qtconsole==4.7.7
|
|
||||||
QtPy==1.9.0
|
|
||||||
regex==2020.7.14
|
|
||||||
requests==2.22.0
|
|
||||||
retrying==1.3.3
|
|
||||||
sacremoses==0.0.43
|
|
||||||
Send2Trash==1.5.0
|
|
||||||
sentencepiece==0.1.91
|
|
||||||
six==1.14.0
|
|
||||||
terminado==0.8.3
|
|
||||||
testpath==0.4.4
|
|
||||||
tokenizers==0.8.1rc2
|
|
||||||
torch==1.6.0
|
|
||||||
torchvision==0.7.0
|
|
||||||
tornado==6.0.4
|
|
||||||
tqdm==4.48.2
|
|
||||||
traitlets
|
|
||||||
transformers==3.5.1
|
|
||||||
urllib3==1.25.8
|
|
||||||
wcwidth==0.2.5
|
|
||||||
webencodings==0.5.1
|
|
||||||
wget==3.2
|
|
||||||
widgetsnbextension==3.5.1
|
|
||||||
xxhash==2.0.0
|
|
||||||
@@ -1,559 +0,0 @@
|
|||||||
"""
|
|
||||||
coding=utf-8
|
|
||||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :)
|
|
||||||
Adapted From Facebook Inc, Detectron2
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.import copy
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import fnmatch
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import pickle as pkl
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import tarfile
|
|
||||||
import tempfile
|
|
||||||
from collections import OrderedDict
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from functools import partial
|
|
||||||
from hashlib import sha256
|
|
||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from zipfile import ZipFile, is_zipfile
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import requests
|
|
||||||
import wget
|
|
||||||
from filelock import FileLock
|
|
||||||
from yaml import Loader, dump, load
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
_torch_available = True
|
|
||||||
except ImportError:
|
|
||||||
_torch_available = False
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.hub import _get_torch_home
|
|
||||||
|
|
||||||
torch_cache_home = _get_torch_home()
|
|
||||||
except ImportError:
|
|
||||||
torch_cache_home = os.path.expanduser(
|
|
||||||
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
|
||||||
)
|
|
||||||
|
|
||||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
|
||||||
|
|
||||||
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
|
||||||
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
|
||||||
PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1])
|
|
||||||
CONFIG = os.path.join(PATH, "config.yaml")
|
|
||||||
ATTRIBUTES = os.path.join(PATH, "attributes.txt")
|
|
||||||
OBJECTS = os.path.join(PATH, "objects.txt")
|
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
|
||||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
|
||||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
|
||||||
CONFIG_NAME = "config.yaml"
|
|
||||||
|
|
||||||
|
|
||||||
def load_labels(objs=OBJECTS, attrs=ATTRIBUTES):
|
|
||||||
vg_classes = []
|
|
||||||
with open(objs) as f:
|
|
||||||
for object in f.readlines():
|
|
||||||
vg_classes.append(object.split(",")[0].lower().strip())
|
|
||||||
|
|
||||||
vg_attrs = []
|
|
||||||
with open(attrs) as f:
|
|
||||||
for object in f.readlines():
|
|
||||||
vg_attrs.append(object.split(",")[0].lower().strip())
|
|
||||||
return vg_classes, vg_attrs
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(ckp):
|
|
||||||
r = OrderedDict()
|
|
||||||
with open(ckp, "rb") as f:
|
|
||||||
ckp = pkl.load(f)["model"]
|
|
||||||
for k in copy.deepcopy(list(ckp.keys())):
|
|
||||||
v = ckp.pop(k)
|
|
||||||
if isinstance(v, np.ndarray):
|
|
||||||
v = torch.tensor(v)
|
|
||||||
else:
|
|
||||||
assert isinstance(v, torch.tensor), type(v)
|
|
||||||
r[k] = v
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
_pointer = {}
|
|
||||||
|
|
||||||
def __init__(self, dictionary: dict, name: str = "root", level=0):
|
|
||||||
self._name = name
|
|
||||||
self._level = level
|
|
||||||
d = {}
|
|
||||||
for k, v in dictionary.items():
|
|
||||||
if v is None:
|
|
||||||
raise ValueError()
|
|
||||||
k = copy.deepcopy(k)
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
if isinstance(v, dict):
|
|
||||||
v = Config(v, name=k, level=level + 1)
|
|
||||||
d[k] = v
|
|
||||||
setattr(self, k, v)
|
|
||||||
|
|
||||||
self._pointer = d
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return str(list((self._pointer.keys())))
|
|
||||||
|
|
||||||
def __setattr__(self, key, val):
|
|
||||||
self.__dict__[key] = val
|
|
||||||
self.__dict__[key.upper()] = val
|
|
||||||
levels = key.split(".")
|
|
||||||
last_level = len(levels) - 1
|
|
||||||
pointer = self._pointer
|
|
||||||
if len(levels) > 1:
|
|
||||||
for i, l in enumerate(levels):
|
|
||||||
if hasattr(self, l) and isinstance(getattr(self, l), Config):
|
|
||||||
setattr(getattr(self, l), ".".join(levels[i:]), val)
|
|
||||||
if l == last_level:
|
|
||||||
pointer[l] = val
|
|
||||||
else:
|
|
||||||
pointer = pointer[l]
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return self._pointer
|
|
||||||
|
|
||||||
def dump_yaml(self, data, file_name):
|
|
||||||
with open(f"{file_name}", "w") as stream:
|
|
||||||
dump(data, stream)
|
|
||||||
|
|
||||||
def dump_json(self, data, file_name):
|
|
||||||
with open(f"{file_name}", "w") as stream:
|
|
||||||
json.dump(data, stream)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_yaml(config):
|
|
||||||
with open(config) as stream:
|
|
||||||
data = load(stream, Loader=Loader)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
t = " "
|
|
||||||
if self._name != "root":
|
|
||||||
r = f"{t * (self._level-1)}{self._name}:\n"
|
|
||||||
else:
|
|
||||||
r = ""
|
|
||||||
level = self._level
|
|
||||||
for i, (k, v) in enumerate(self._pointer.items()):
|
|
||||||
if isinstance(v, Config):
|
|
||||||
r += f"{t * (self._level)}{v}\n"
|
|
||||||
self._level += 1
|
|
||||||
else:
|
|
||||||
r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n"
|
|
||||||
self._level = level
|
|
||||||
return r[:-1]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
|
||||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
|
||||||
return cls(config_dict)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
|
|
||||||
|
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
|
||||||
force_download = kwargs.pop("force_download", False)
|
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
|
||||||
proxies = kwargs.pop("proxies", None)
|
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
|
||||||
|
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
|
||||||
config_file = pretrained_model_name_or_path
|
|
||||||
else:
|
|
||||||
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load from URL or cache if already cached
|
|
||||||
resolved_config_file = cached_path(
|
|
||||||
config_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
# Load config dict
|
|
||||||
if resolved_config_file is None:
|
|
||||||
raise EnvironmentError
|
|
||||||
|
|
||||||
config_file = Config.load_yaml(resolved_config_file)
|
|
||||||
|
|
||||||
except EnvironmentError:
|
|
||||||
msg = "Can't load config for"
|
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
if resolved_config_file == config_file:
|
|
||||||
print("loading configuration file from path")
|
|
||||||
else:
|
|
||||||
print("loading configuration file cache")
|
|
||||||
|
|
||||||
return Config.load_yaml(resolved_config_file), kwargs
|
|
||||||
|
|
||||||
|
|
||||||
# quick compare tensors
|
|
||||||
def compare(in_tensor):
|
|
||||||
|
|
||||||
out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
|
|
||||||
n1 = in_tensor.numpy()
|
|
||||||
n2 = out_tensor.numpy()[0]
|
|
||||||
print(n1.shape, n1[0, 0, :5])
|
|
||||||
print(n2.shape, n2[0, 0, :5])
|
|
||||||
assert np.allclose(
|
|
||||||
n1, n2, rtol=0.01, atol=0.1
|
|
||||||
), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
|
|
||||||
raise Exception("tensors are all good")
|
|
||||||
|
|
||||||
# Hugging face functions below
|
|
||||||
|
|
||||||
|
|
||||||
def is_remote_url(url_or_filename):
|
|
||||||
parsed = urlparse(url_or_filename)
|
|
||||||
return parsed.scheme in ("http", "https")
|
|
||||||
|
|
||||||
|
|
||||||
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
|
|
||||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
|
|
||||||
legacy_format = "/" not in model_id
|
|
||||||
if legacy_format:
|
|
||||||
return f"{endpoint}/{model_id}-{filename}"
|
|
||||||
else:
|
|
||||||
return f"{endpoint}/{model_id}/{filename}"
|
|
||||||
|
|
||||||
|
|
||||||
def http_get(
|
|
||||||
url,
|
|
||||||
temp_file,
|
|
||||||
proxies=None,
|
|
||||||
resume_size=0,
|
|
||||||
user_agent=None,
|
|
||||||
):
|
|
||||||
ua = "python/{}".format(sys.version.split()[0])
|
|
||||||
if _torch_available:
|
|
||||||
ua += "; torch/{}".format(torch.__version__)
|
|
||||||
if isinstance(user_agent, dict):
|
|
||||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
|
||||||
elif isinstance(user_agent, str):
|
|
||||||
ua += "; " + user_agent
|
|
||||||
headers = {"user-agent": ua}
|
|
||||||
if resume_size > 0:
|
|
||||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
|
||||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
|
||||||
if response.status_code == 416: # Range not satisfiable
|
|
||||||
return
|
|
||||||
content_length = response.headers.get("Content-Length")
|
|
||||||
total = resume_size + int(content_length) if content_length is not None else None
|
|
||||||
progress = tqdm(
|
|
||||||
unit="B",
|
|
||||||
unit_scale=True,
|
|
||||||
total=total,
|
|
||||||
initial=resume_size,
|
|
||||||
desc="Downloading",
|
|
||||||
)
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk: # filter out keep-alive new chunks
|
|
||||||
progress.update(len(chunk))
|
|
||||||
temp_file.write(chunk)
|
|
||||||
progress.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_from_cache(
|
|
||||||
url,
|
|
||||||
cache_dir=None,
|
|
||||||
force_download=False,
|
|
||||||
proxies=None,
|
|
||||||
etag_timeout=10,
|
|
||||||
resume_download=False,
|
|
||||||
user_agent=None,
|
|
||||||
local_files_only=False,
|
|
||||||
):
|
|
||||||
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
|
|
||||||
etag = None
|
|
||||||
if not local_files_only:
|
|
||||||
try:
|
|
||||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
|
||||||
if response.status_code == 200:
|
|
||||||
etag = response.headers.get("ETag")
|
|
||||||
except (EnvironmentError, requests.exceptions.Timeout):
|
|
||||||
# etag is already None
|
|
||||||
pass
|
|
||||||
|
|
||||||
filename = url_to_filename(url, etag)
|
|
||||||
|
|
||||||
# get cache path to put the file
|
|
||||||
cache_path = os.path.join(cache_dir, filename)
|
|
||||||
|
|
||||||
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
|
||||||
# try to get the last downloaded one
|
|
||||||
if etag is None:
|
|
||||||
if os.path.exists(cache_path):
|
|
||||||
return cache_path
|
|
||||||
else:
|
|
||||||
matching_files = [
|
|
||||||
file
|
|
||||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
|
||||||
if not file.endswith(".json") and not file.endswith(".lock")
|
|
||||||
]
|
|
||||||
if len(matching_files) > 0:
|
|
||||||
return os.path.join(cache_dir, matching_files[-1])
|
|
||||||
else:
|
|
||||||
# If files cannot be found and local_files_only=True,
|
|
||||||
# the models might've been found if local_files_only=False
|
|
||||||
# Notify the user about that
|
|
||||||
if local_files_only:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot find the requested files in the cached path and outgoing traffic has been"
|
|
||||||
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
|
||||||
" to False."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# From now on, etag is not None.
|
|
||||||
if os.path.exists(cache_path) and not force_download:
|
|
||||||
return cache_path
|
|
||||||
|
|
||||||
# Prevent parallel downloads of the same file with a lock.
|
|
||||||
lock_path = cache_path + ".lock"
|
|
||||||
with FileLock(lock_path):
|
|
||||||
|
|
||||||
# If the download just completed while the lock was activated.
|
|
||||||
if os.path.exists(cache_path) and not force_download:
|
|
||||||
# Even if returning early like here, the lock will be released.
|
|
||||||
return cache_path
|
|
||||||
|
|
||||||
if resume_download:
|
|
||||||
incomplete_path = cache_path + ".incomplete"
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _resumable_file_manager():
|
|
||||||
with open(incomplete_path, "a+b") as f:
|
|
||||||
yield f
|
|
||||||
|
|
||||||
temp_file_manager = _resumable_file_manager
|
|
||||||
if os.path.exists(incomplete_path):
|
|
||||||
resume_size = os.stat(incomplete_path).st_size
|
|
||||||
else:
|
|
||||||
resume_size = 0
|
|
||||||
else:
|
|
||||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
|
||||||
resume_size = 0
|
|
||||||
|
|
||||||
# Download to temporary file, then copy to cache dir once finished.
|
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
|
||||||
with temp_file_manager() as temp_file:
|
|
||||||
print(
|
|
||||||
"%s not found in cache or force_download set to True, downloading to %s",
|
|
||||||
url,
|
|
||||||
temp_file.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
http_get(
|
|
||||||
url,
|
|
||||||
temp_file,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_size=resume_size,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
os.replace(temp_file.name, cache_path)
|
|
||||||
|
|
||||||
meta = {"url": url, "etag": etag}
|
|
||||||
meta_path = cache_path + ".json"
|
|
||||||
with open(meta_path, "w") as meta_file:
|
|
||||||
json.dump(meta, meta_file)
|
|
||||||
|
|
||||||
return cache_path
|
|
||||||
|
|
||||||
|
|
||||||
def url_to_filename(url, etag=None):
|
|
||||||
|
|
||||||
url_bytes = url.encode("utf-8")
|
|
||||||
url_hash = sha256(url_bytes)
|
|
||||||
filename = url_hash.hexdigest()
|
|
||||||
|
|
||||||
if etag:
|
|
||||||
etag_bytes = etag.encode("utf-8")
|
|
||||||
etag_hash = sha256(etag_bytes)
|
|
||||||
filename += "." + etag_hash.hexdigest()
|
|
||||||
|
|
||||||
if url.endswith(".h5"):
|
|
||||||
filename += ".h5"
|
|
||||||
|
|
||||||
return filename
|
|
||||||
|
|
||||||
|
|
||||||
def cached_path(
|
|
||||||
url_or_filename,
|
|
||||||
cache_dir=None,
|
|
||||||
force_download=False,
|
|
||||||
proxies=None,
|
|
||||||
resume_download=False,
|
|
||||||
user_agent=None,
|
|
||||||
extract_compressed_file=False,
|
|
||||||
force_extract=False,
|
|
||||||
local_files_only=False,
|
|
||||||
):
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
|
||||||
if isinstance(url_or_filename, Path):
|
|
||||||
url_or_filename = str(url_or_filename)
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
if is_remote_url(url_or_filename):
|
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
|
||||||
output_path = get_from_cache(
|
|
||||||
url_or_filename,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
user_agent=user_agent,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
elif os.path.exists(url_or_filename):
|
|
||||||
# File, and it exists.
|
|
||||||
output_path = url_or_filename
|
|
||||||
elif urlparse(url_or_filename).scheme == "":
|
|
||||||
# File, but it doesn't exist.
|
|
||||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
|
||||||
else:
|
|
||||||
# Something unknown
|
|
||||||
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
|
||||||
|
|
||||||
if extract_compressed_file:
|
|
||||||
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
# Path where we extract compressed archives
|
|
||||||
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
|
|
||||||
output_dir, output_file = os.path.split(output_path)
|
|
||||||
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
|
|
||||||
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
|
|
||||||
|
|
||||||
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
|
|
||||||
return output_path_extracted
|
|
||||||
|
|
||||||
# Prevent parallel extractions
|
|
||||||
lock_path = output_path + ".lock"
|
|
||||||
with FileLock(lock_path):
|
|
||||||
shutil.rmtree(output_path_extracted, ignore_errors=True)
|
|
||||||
os.makedirs(output_path_extracted)
|
|
||||||
if is_zipfile(output_path):
|
|
||||||
with ZipFile(output_path, "r") as zip_file:
|
|
||||||
zip_file.extractall(output_path_extracted)
|
|
||||||
zip_file.close()
|
|
||||||
elif tarfile.is_tarfile(output_path):
|
|
||||||
tar_file = tarfile.open(output_path)
|
|
||||||
tar_file.extractall(output_path_extracted)
|
|
||||||
tar_file.close()
|
|
||||||
else:
|
|
||||||
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
|
|
||||||
|
|
||||||
return output_path_extracted
|
|
||||||
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def get_data(query, delim=","):
|
|
||||||
assert isinstance(query, str)
|
|
||||||
if os.path.isfile(query):
|
|
||||||
with open(query) as f:
|
|
||||||
data = eval(f.read())
|
|
||||||
else:
|
|
||||||
req = requests.get(query)
|
|
||||||
try:
|
|
||||||
data = requests.json()
|
|
||||||
except Exception:
|
|
||||||
data = req.content.decode()
|
|
||||||
assert data is not None, "could not connect"
|
|
||||||
try:
|
|
||||||
data = eval(data)
|
|
||||||
except Exception:
|
|
||||||
data = data.split("\n")
|
|
||||||
req.close()
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_from_url(url):
|
|
||||||
response = requests.get(url)
|
|
||||||
img = np.array(Image.open(BytesIO(response.content)))
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# to load legacy frcnn checkpoint from detectron
|
|
||||||
def load_frcnn_pkl_from_url(url):
|
|
||||||
fn = url.split("/")[-1]
|
|
||||||
if fn not in os.listdir(os.getcwd()):
|
|
||||||
wget.download(url)
|
|
||||||
with open(fn, "rb") as stream:
|
|
||||||
weights = pkl.load(stream)
|
|
||||||
model = weights.pop("model")
|
|
||||||
new = {}
|
|
||||||
for k, v in model.items():
|
|
||||||
new[k] = torch.from_numpy(v)
|
|
||||||
if "running_var" in k:
|
|
||||||
zero = torch.Tensor([0])
|
|
||||||
k2 = k.replace("running_var", "num_batches_tracked")
|
|
||||||
new[k2] = zero
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
def get_demo_path():
|
|
||||||
print(f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb")
|
|
||||||
|
|
||||||
|
|
||||||
def img_tensorize(im, input_format="RGB"):
|
|
||||||
assert isinstance(im, str)
|
|
||||||
if os.path.isfile(im):
|
|
||||||
img = cv2.imread(im)
|
|
||||||
else:
|
|
||||||
img = get_image_from_url(im)
|
|
||||||
assert img is not None, f"could not connect to: {im}"
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
if input_format == "RGB":
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(images, batch=1):
|
|
||||||
return (images[i : i + batch] for i in range(0, len(images), batch))
|
|
||||||
@@ -1,499 +0,0 @@
|
|||||||
"""
|
|
||||||
coding=utf-8
|
|
||||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
|
||||||
Adapted From Facebook Inc, Detectron2
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.import copy
|
|
||||||
"""
|
|
||||||
import colorsys
|
|
||||||
import io
|
|
||||||
|
|
||||||
import matplotlib as mpl
|
|
||||||
import matplotlib.colors as mplc
|
|
||||||
import matplotlib.figure as mplfigure
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from utils import img_tensorize
|
|
||||||
|
|
||||||
|
|
||||||
_SMALL_OBJ = 1000
|
|
||||||
|
|
||||||
|
|
||||||
class SingleImageViz:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img,
|
|
||||||
scale=1.2,
|
|
||||||
edgecolor="g",
|
|
||||||
alpha=0.5,
|
|
||||||
linestyle="-",
|
|
||||||
saveas="test_out.jpg",
|
|
||||||
rgb=True,
|
|
||||||
pynb=False,
|
|
||||||
id2obj=None,
|
|
||||||
id2attr=None,
|
|
||||||
pad=0.7,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
img: an RGB image of shape (H, W, 3).
|
|
||||||
"""
|
|
||||||
if isinstance(img, torch.Tensor):
|
|
||||||
img = img.numpy().astype("np.uint8")
|
|
||||||
if isinstance(img, str):
|
|
||||||
img = img_tensorize(img)
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
|
|
||||||
width, height = img.shape[1], img.shape[0]
|
|
||||||
fig = mplfigure.Figure(frameon=False)
|
|
||||||
dpi = fig.get_dpi()
|
|
||||||
width_in = (width * scale + 1e-2) / dpi
|
|
||||||
height_in = (height * scale + 1e-2) / dpi
|
|
||||||
fig.set_size_inches(width_in, height_in)
|
|
||||||
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
|
||||||
ax.axis("off")
|
|
||||||
ax.set_xlim(0.0, width)
|
|
||||||
ax.set_ylim(height)
|
|
||||||
|
|
||||||
self.saveas = saveas
|
|
||||||
self.rgb = rgb
|
|
||||||
self.pynb = pynb
|
|
||||||
self.img = img
|
|
||||||
self.edgecolor = edgecolor
|
|
||||||
self.alpha = 0.5
|
|
||||||
self.linestyle = linestyle
|
|
||||||
self.font_size = int(np.sqrt(min(height, width)) * scale // 3)
|
|
||||||
self.width = width
|
|
||||||
self.height = height
|
|
||||||
self.scale = scale
|
|
||||||
self.fig = fig
|
|
||||||
self.ax = ax
|
|
||||||
self.pad = pad
|
|
||||||
self.id2obj = id2obj
|
|
||||||
self.id2attr = id2attr
|
|
||||||
self.canvas = FigureCanvasAgg(fig)
|
|
||||||
|
|
||||||
def add_box(self, box, color=None):
|
|
||||||
if color is None:
|
|
||||||
color = self.edgecolor
|
|
||||||
(x0, y0, x1, y1) = box
|
|
||||||
width = x1 - x0
|
|
||||||
height = y1 - y0
|
|
||||||
self.ax.add_patch(
|
|
||||||
mpl.patches.Rectangle(
|
|
||||||
(x0, y0),
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
fill=False,
|
|
||||||
edgecolor=color,
|
|
||||||
linewidth=self.font_size // 3,
|
|
||||||
alpha=self.alpha,
|
|
||||||
linestyle=self.linestyle,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None):
|
|
||||||
if len(boxes.shape) > 2:
|
|
||||||
boxes = boxes[0]
|
|
||||||
if len(obj_ids.shape) > 1:
|
|
||||||
obj_ids = obj_ids[0]
|
|
||||||
if len(obj_scores.shape) > 1:
|
|
||||||
obj_scores = obj_scores[0]
|
|
||||||
if len(attr_ids.shape) > 1:
|
|
||||||
attr_ids = attr_ids[0]
|
|
||||||
if len(attr_scores.shape) > 1:
|
|
||||||
attr_scores = attr_scores[0]
|
|
||||||
if isinstance(boxes, torch.Tensor):
|
|
||||||
boxes = boxes.numpy()
|
|
||||||
if isinstance(boxes, list):
|
|
||||||
boxes = np.array(boxes)
|
|
||||||
assert isinstance(boxes, np.ndarray)
|
|
||||||
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
|
||||||
sorted_idxs = np.argsort(-areas).tolist()
|
|
||||||
boxes = boxes[sorted_idxs] if boxes is not None else None
|
|
||||||
obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None
|
|
||||||
obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None
|
|
||||||
attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None
|
|
||||||
attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None
|
|
||||||
|
|
||||||
assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))]
|
|
||||||
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
|
||||||
if obj_ids is not None:
|
|
||||||
labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores)
|
|
||||||
for i in range(len(boxes)):
|
|
||||||
color = assigned_colors[i]
|
|
||||||
self.add_box(boxes[i], color)
|
|
||||||
self.draw_labels(labels[i], boxes[i], color)
|
|
||||||
|
|
||||||
def draw_labels(self, label, box, color):
|
|
||||||
x0, y0, x1, y1 = box
|
|
||||||
text_pos = (x0, y0)
|
|
||||||
instance_area = (y1 - y0) * (x1 - x0)
|
|
||||||
small = _SMALL_OBJ * self.scale
|
|
||||||
if instance_area < small or y1 - y0 < 40 * self.scale:
|
|
||||||
if y1 >= self.height - 5:
|
|
||||||
text_pos = (x1, y0)
|
|
||||||
else:
|
|
||||||
text_pos = (x0, y1)
|
|
||||||
|
|
||||||
height_ratio = (y1 - y0) / np.sqrt(self.height * self.width)
|
|
||||||
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
|
||||||
font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
|
||||||
font_size *= 0.75 * self.font_size
|
|
||||||
|
|
||||||
self.draw_text(
|
|
||||||
text=label,
|
|
||||||
position=text_pos,
|
|
||||||
color=lighter_color,
|
|
||||||
)
|
|
||||||
|
|
||||||
def draw_text(
|
|
||||||
self,
|
|
||||||
text,
|
|
||||||
position,
|
|
||||||
color="g",
|
|
||||||
ha="left",
|
|
||||||
):
|
|
||||||
rotation = 0
|
|
||||||
font_size = self.font_size
|
|
||||||
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
|
||||||
color[np.argmax(color)] = max(0.8, np.max(color))
|
|
||||||
bbox = {
|
|
||||||
"facecolor": "black",
|
|
||||||
"alpha": self.alpha,
|
|
||||||
"pad": self.pad,
|
|
||||||
"edgecolor": "none",
|
|
||||||
}
|
|
||||||
x, y = position
|
|
||||||
self.ax.text(
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
text,
|
|
||||||
size=font_size * self.scale,
|
|
||||||
family="sans-serif",
|
|
||||||
bbox=bbox,
|
|
||||||
verticalalignment="top",
|
|
||||||
horizontalalignment=ha,
|
|
||||||
color=color,
|
|
||||||
zorder=10,
|
|
||||||
rotation=rotation,
|
|
||||||
)
|
|
||||||
|
|
||||||
def save(self, saveas=None):
|
|
||||||
if saveas is None:
|
|
||||||
saveas = self.saveas
|
|
||||||
if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"):
|
|
||||||
cv2.imwrite(
|
|
||||||
saveas,
|
|
||||||
self._get_buffer()[:, :, ::-1],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.fig.savefig(saveas)
|
|
||||||
|
|
||||||
def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores):
|
|
||||||
labels = [self.id2obj[i] for i in classes]
|
|
||||||
attr_labels = [self.id2attr[i] for i in attr_classes]
|
|
||||||
labels = [
|
|
||||||
f"{label} {score:.2f} {attr} {attr_score:.2f}"
|
|
||||||
for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores)
|
|
||||||
]
|
|
||||||
return labels
|
|
||||||
|
|
||||||
def _create_text_labels(self, classes, scores):
|
|
||||||
labels = [self.id2obj[i] for i in classes]
|
|
||||||
if scores is not None:
|
|
||||||
if labels is None:
|
|
||||||
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
|
||||||
else:
|
|
||||||
labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)]
|
|
||||||
return labels
|
|
||||||
|
|
||||||
def _random_color(self, maximum=255):
|
|
||||||
idx = np.random.randint(0, len(_COLORS))
|
|
||||||
ret = _COLORS[idx] * maximum
|
|
||||||
if not self.rgb:
|
|
||||||
ret = ret[::-1]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def _get_buffer(self):
|
|
||||||
if not self.pynb:
|
|
||||||
s, (width, height) = self.canvas.print_to_buffer()
|
|
||||||
if (width, height) != (self.width, self.height):
|
|
||||||
img = cv2.resize(self.img, (width, height))
|
|
||||||
else:
|
|
||||||
img = self.img
|
|
||||||
else:
|
|
||||||
buf = io.BytesIO() # works for cairo backend
|
|
||||||
self.canvas.print_rgba(buf)
|
|
||||||
width, height = self.width, self.height
|
|
||||||
s = buf.getvalue()
|
|
||||||
img = self.img
|
|
||||||
|
|
||||||
buffer = np.frombuffer(s, dtype="uint8")
|
|
||||||
img_rgba = buffer.reshape(height, width, 4)
|
|
||||||
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import numexpr as ne # fuse them with numexpr
|
|
||||||
|
|
||||||
visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
|
|
||||||
except ImportError:
|
|
||||||
alpha = alpha.astype("float32") / 255.0
|
|
||||||
visualized_image = img * (1 - alpha) + rgb * alpha
|
|
||||||
|
|
||||||
return visualized_image.astype("uint8")
|
|
||||||
|
|
||||||
def _change_color_brightness(self, color, brightness_factor):
|
|
||||||
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
|
||||||
color = mplc.to_rgb(color)
|
|
||||||
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
|
||||||
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
|
||||||
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
|
||||||
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
|
||||||
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
|
|
||||||
return modified_color
|
|
||||||
|
|
||||||
|
|
||||||
# Color map
|
|
||||||
_COLORS = (
|
|
||||||
np.array(
|
|
||||||
[
|
|
||||||
0.000,
|
|
||||||
0.447,
|
|
||||||
0.741,
|
|
||||||
0.850,
|
|
||||||
0.325,
|
|
||||||
0.098,
|
|
||||||
0.929,
|
|
||||||
0.694,
|
|
||||||
0.125,
|
|
||||||
0.494,
|
|
||||||
0.184,
|
|
||||||
0.556,
|
|
||||||
0.466,
|
|
||||||
0.674,
|
|
||||||
0.188,
|
|
||||||
0.301,
|
|
||||||
0.745,
|
|
||||||
0.933,
|
|
||||||
0.635,
|
|
||||||
0.078,
|
|
||||||
0.184,
|
|
||||||
0.300,
|
|
||||||
0.300,
|
|
||||||
0.300,
|
|
||||||
0.600,
|
|
||||||
0.600,
|
|
||||||
0.600,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
0.749,
|
|
||||||
0.749,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
0.333,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.333,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.500,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
0.500,
|
|
||||||
0.333,
|
|
||||||
0.333,
|
|
||||||
0.500,
|
|
||||||
0.333,
|
|
||||||
0.667,
|
|
||||||
0.500,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
0.500,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
0.500,
|
|
||||||
0.667,
|
|
||||||
0.333,
|
|
||||||
0.500,
|
|
||||||
0.667,
|
|
||||||
0.667,
|
|
||||||
0.500,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
0.500,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.500,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.500,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
0.500,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
0.667,
|
|
||||||
1.000,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.833,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.167,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.833,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.167,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.333,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.500,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.667,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.833,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
1.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.000,
|
|
||||||
0.143,
|
|
||||||
0.143,
|
|
||||||
0.143,
|
|
||||||
0.857,
|
|
||||||
0.857,
|
|
||||||
0.857,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
1.000,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
.astype(np.float32)
|
|
||||||
.reshape(-1, 3)
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user