[Benchmarks] improve Example Plotter (#5245)
* improve plotting * better labels * fix time plot
This commit is contained in:
committed by
GitHub
parent
88d7f96e33
commit
79a82cc06a
@@ -1,7 +1,7 @@
|
|||||||
import csv
|
import csv
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -10,6 +10,10 @@ from matplotlib.ticker import ScalarFormatter
|
|||||||
from transformers import HfArgumentParser
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def list_field(default=None, metadata=None):
    """Build a dataclass field whose default value is produced by a factory.

    Dataclasses reject mutable defaults such as lists, so the default is
    supplied through ``default_factory`` instead.

    Args:
        default: Value the field takes when none is given. May be ``None``.
        metadata: Optional mapping forwarded unchanged to ``dataclasses.field``.

    Returns:
        The ``dataclasses.Field`` object created by ``dataclasses.field``.
    """
    import copy

    def _fresh_default():
        # Hand out a shallow copy on every call: returning `default` itself
        # would make all instances share (and be able to mutate) one object.
        return copy.copy(default)

    return field(default_factory=_fresh_default, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlotArguments:
|
class PlotArguments:
|
||||||
"""
|
"""
|
||||||
@@ -37,6 +41,25 @@ class PlotArguments:
|
|||||||
figure_png_file: Optional[str] = field(
|
figure_png_file: Optional[str] = field(
|
||||||
default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
|
default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
|
||||||
)
|
)
|
||||||
|
short_model_names: Optional[List[str]] = list_field(
|
||||||
|
default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def can_convert_to_int(string):
    """Return True if ``int(string)`` succeeds, False otherwise.

    Args:
        string: Value to probe, typically the text of a result cell.

    Returns:
        bool: True when the value can be parsed by ``int``.
    """
    try:
        int(string)
    except (TypeError, ValueError):
        # ValueError: text that is not an integer literal (e.g. "1.5", "").
        # TypeError: a non-string such as None, e.g. an absent cell.
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def can_convert_to_float(string):
    """Return True if ``float(string)`` succeeds, False otherwise.

    Args:
        string: Value to probe, typically the text of a result cell.

    Returns:
        bool: True when the value can be parsed by ``float``.
    """
    try:
        float(string)
    except (TypeError, ValueError):
        # ValueError: text that is not a float literal (e.g. "N/A", "").
        # TypeError: a non-string such as None, e.g. an absent cell.
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
class Plot:
|
class Plot:
|
||||||
@@ -50,9 +73,16 @@ class Plot:
|
|||||||
model_name = row["model"]
|
model_name = row["model"]
|
||||||
self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
|
self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
|
||||||
self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
|
self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
|
||||||
self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
|
if can_convert_to_int(row["result"]):
|
||||||
"result"
|
# value is not None
|
||||||
]
|
self.result_dict[model_name]["result"][
|
||||||
|
(int(row["batch_size"]), int(row["sequence_length"]))
|
||||||
|
] = int(row["result"])
|
||||||
|
elif can_convert_to_float(row["result"]):
|
||||||
|
# value is not None
|
||||||
|
self.result_dict[model_name]["result"][
|
||||||
|
(int(row["batch_size"]), int(row["sequence_length"]))
|
||||||
|
] = float(row["result"])
|
||||||
|
|
||||||
def plot(self):
|
def plot(self):
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
@@ -67,7 +97,7 @@ class Plot:
|
|||||||
for axis in [ax.xaxis, ax.yaxis]:
|
for axis in [ax.xaxis, ax.yaxis]:
|
||||||
axis.set_major_formatter(ScalarFormatter())
|
axis.set_major_formatter(ScalarFormatter())
|
||||||
|
|
||||||
for model_name in self.result_dict.keys():
|
for model_name_idx, model_name in enumerate(self.result_dict.keys()):
|
||||||
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
|
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
|
||||||
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
|
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
|
||||||
results = self.result_dict[model_name]["result"]
|
results = self.result_dict[model_name]["result"]
|
||||||
@@ -76,23 +106,33 @@ class Plot:
|
|||||||
(batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
|
(batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
label_model_name = (
|
||||||
|
model_name if self.args.short_model_names is None else self.args.short_model_names[model_name_idx]
|
||||||
|
)
|
||||||
|
|
||||||
for inner_loop_value in inner_loop_array:
|
for inner_loop_value in inner_loop_array:
|
||||||
if self.args.plot_along_batch:
|
if self.args.plot_along_batch:
|
||||||
y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int)
|
y_axis_array = np.asarray(
|
||||||
|
[results[(x, inner_loop_value)] for x in x_axis_array if (x, inner_loop_value) in results],
|
||||||
|
dtype=np.int,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32)
|
y_axis_array = np.asarray(
|
||||||
|
[results[(inner_loop_value, x)] for x in x_axis_array if (inner_loop_value, x) in results],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
(x_axis_label, inner_loop_label) = (
|
(x_axis_label, inner_loop_label) = (
|
||||||
("batch_size", "sequence_length in #tokens")
|
("batch_size", "len") if self.args.plot_along_batch else ("in #tokens", "bsz")
|
||||||
if self.args.plot_along_batch
|
|
||||||
else ("sequence_length in #tokens", "batch_size")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
x_axis_array = np.asarray(x_axis_array, np.int)
|
x_axis_array = np.asarray(x_axis_array, np.int)[: len(y_axis_array)]
|
||||||
plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
|
plt.scatter(
|
||||||
|
x_axis_array, y_axis_array, label=f"{label_model_name} - {inner_loop_label}: {inner_loop_value}"
|
||||||
|
)
|
||||||
plt.plot(x_axis_array, y_axis_array, "--")
|
plt.plot(x_axis_array, y_axis_array, "--")
|
||||||
|
|
||||||
title_str += f" {model_name} vs."
|
title_str += f" {label_model_name} vs."
|
||||||
|
|
||||||
title_str = title_str[:-4]
|
title_str = title_str[:-4]
|
||||||
y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
|
y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
|
||||||
|
|||||||
4
examples/benchmarking/time_xla_1.csv
Normal file
4
examples/benchmarking/time_xla_1.csv
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
model,batch_size,sequence_length,result
|
||||||
|
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,8,512,0.2032
|
||||||
|
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,64,512,1.5279
|
||||||
|
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,256,512,6.1837
|
||||||
|
@@ -74,12 +74,6 @@ class BenchmarkArguments:
|
|||||||
"help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
|
"help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
with_lm_head: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Use model with its language model head (MODEL_WITH_LM_HEAD_MAPPING instead of MODEL_MAPPING)"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
inference_time_csv_file: str = field(
|
inference_time_csv_file: str = field(
|
||||||
default=f"inference_time_{round(time())}.csv",
|
default=f"inference_time_{round(time())}.csv",
|
||||||
metadata={"help": "CSV filename used if saving time results to csv."},
|
metadata={"help": "CSV filename used if saving time results to csv."},
|
||||||
|
|||||||
Reference in New Issue
Block a user