fix PR (#4810)
This commit is contained in:
committed by
GitHub
parent
e817747941
commit
c0554776de
@@ -18,8 +18,8 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import timeit
|
import timeit
|
||||||
|
|
||||||
from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
|
from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
|
||||||
@@ -52,66 +52,34 @@ class PyTorchBenchmark(Benchmark):
|
|||||||
model.to(self.args.device)
|
model.to(self.args.device)
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
# encoder-decoder has vocab size saved differently
|
||||||
|
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
||||||
input_ids = torch.randint(
|
input_ids = torch.randint(
|
||||||
model.config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
|
vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_loss_and_backprob():
|
def compute_loss_and_backprob_encoder():
|
||||||
# TODO: Not all models call labels argument labels => this hack using the function signature should be corrected once all models have a common name for labels
|
|
||||||
function_argument_names = inspect.getfullargspec(model.forward).args
|
|
||||||
if "labels" in function_argument_names:
|
|
||||||
loss = model(input_ids, labels=input_ids)[0]
|
loss = model(input_ids, labels=input_ids)[0]
|
||||||
elif "lm_labels" in function_argument_names:
|
|
||||||
loss = model(input_ids, lm_labels=input_ids)[0]
|
|
||||||
elif "masked_lm_labels" in function_argument_names:
|
|
||||||
loss = model(input_ids, masked_lm_labels=input_ids)[0]
|
|
||||||
else:
|
|
||||||
NotImplementedError(f"{model_name} does not seem to allow training with labels")
|
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
|
|
||||||
if trace_memory is True:
|
def compute_loss_and_backprob_encoder_decoder():
|
||||||
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
|
loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
|
||||||
trace = start_memory_tracing("transformers")
|
loss.backward()
|
||||||
else:
|
model.zero_grad()
|
||||||
# clear cuda cache
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
|
|
||||||
# calculate loss and do backpropagation
|
_train = (
|
||||||
compute_loss_and_backprob()
|
compute_loss_and_backprob_encoder_decoder
|
||||||
|
if config.is_encoder_decoder
|
||||||
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
|
else compute_loss_and_backprob_encoder
|
||||||
summary = stop_memory_tracing(trace)
|
|
||||||
memory = summary.total
|
|
||||||
else:
|
|
||||||
memory = Memory(torch.cuda.max_memory_reserved())
|
|
||||||
|
|
||||||
return memory
|
|
||||||
else:
|
|
||||||
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
|
||||||
runtimes = timeit.repeat(lambda: compute_loss_and_backprob(), repeat=self.args.repeat, number=10,)
|
|
||||||
return min(runtimes) / 10.0
|
|
||||||
except RuntimeError as e:
|
|
||||||
self.print_fn("Doesn't fit on GPU. {}".format(e))
|
|
||||||
return "N/A"
|
|
||||||
|
|
||||||
def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
|
|
||||||
try:
|
|
||||||
config = self.config_dict[model_name]
|
|
||||||
model = MODEL_MAPPING[config.__class__](config)
|
|
||||||
model.to(self.args.device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
input_ids = torch.randint(
|
|
||||||
config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if trace_memory is True:
|
if trace_memory is True:
|
||||||
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
|
if self.args.trace_memory_line_by_line:
|
||||||
trace = start_memory_tracing("transformers")
|
trace = start_memory_tracing("transformers")
|
||||||
else:
|
|
||||||
# clear cuda cache
|
if self.args.n_gpu > 0:
|
||||||
|
# clear gpu cache
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
if hasattr(torch.cuda, "max_memory_reserved"):
|
if hasattr(torch.cuda, "max_memory_reserved"):
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
@@ -121,12 +89,16 @@ class PyTorchBenchmark(Benchmark):
|
|||||||
)
|
)
|
||||||
torch.cuda.reset_max_memory_cached()
|
torch.cuda.reset_max_memory_cached()
|
||||||
|
|
||||||
model(input_ids)
|
# calculate loss and do backpropagation
|
||||||
|
_train()
|
||||||
|
|
||||||
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
|
if self.args.trace_memory_line_by_line:
|
||||||
summary = stop_memory_tracing(trace)
|
summary = stop_memory_tracing(trace)
|
||||||
memory = summary.total
|
|
||||||
else:
|
else:
|
||||||
|
summary = None
|
||||||
|
|
||||||
|
if self.args.n_gpu > 0:
|
||||||
|
# gpu
|
||||||
if hasattr(torch.cuda, "max_memory_reserved"):
|
if hasattr(torch.cuda, "max_memory_reserved"):
|
||||||
memory = Memory(torch.cuda.max_memory_reserved())
|
memory = Memory(torch.cuda.max_memory_reserved())
|
||||||
else:
|
else:
|
||||||
@@ -134,11 +106,106 @@ class PyTorchBenchmark(Benchmark):
|
|||||||
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
|
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
|
||||||
)
|
)
|
||||||
memory = Memory(torch.cuda.max_memory_cached())
|
memory = Memory(torch.cuda.max_memory_cached())
|
||||||
|
memory = Memory(torch.cuda.max_memory_reserved())
|
||||||
|
else:
|
||||||
|
# cpu
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
except (ImportError):
|
||||||
|
logger.warning(
|
||||||
|
"Psutil not installed, we won't log CPU memory usage. "
|
||||||
|
"Install psutil (pip install psutil) to use CPU memory tracing."
|
||||||
|
)
|
||||||
|
memory = "N/A"
|
||||||
|
else:
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
memory = Memory(process.memory_info().rss)
|
||||||
|
|
||||||
return memory
|
return memory, summary
|
||||||
else:
|
else:
|
||||||
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
||||||
runtimes = timeit.repeat(lambda: model(input_ids), repeat=self.args.repeat, number=10,)
|
runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,)
|
||||||
|
return min(runtimes) / 10.0
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.print_fn("Doesn't fit on GPU. {}".format(e))
|
||||||
|
return "N/A"
|
||||||
|
|
||||||
|
def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
|
||||||
|
try:
|
||||||
|
config = self.config_dict[model_name]
|
||||||
|
|
||||||
|
if self.args.with_lm_head:
|
||||||
|
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
|
||||||
|
else:
|
||||||
|
model = MODEL_MAPPING[config.__class__](config)
|
||||||
|
|
||||||
|
model.to(self.args.device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# encoder-decoder has vocab size saved differently
|
||||||
|
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
||||||
|
|
||||||
|
input_ids = torch.randint(
|
||||||
|
vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def encoder_decoder_forward():
|
||||||
|
model(input_ids, decoder_input_ids=input_ids)
|
||||||
|
|
||||||
|
def encoder_forward():
|
||||||
|
model(input_ids)
|
||||||
|
|
||||||
|
_forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
|
||||||
|
|
||||||
|
if trace_memory is True:
|
||||||
|
if self.args.trace_memory_line_by_line:
|
||||||
|
trace = start_memory_tracing("transformers")
|
||||||
|
|
||||||
|
if self.args.n_gpu > 0:
|
||||||
|
# clear gpu cache
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if hasattr(torch.cuda, "max_memory_reserved"):
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
|
||||||
|
)
|
||||||
|
torch.cuda.reset_max_memory_cached()
|
||||||
|
|
||||||
|
_forward()
|
||||||
|
|
||||||
|
if self.args.trace_memory_line_by_line:
|
||||||
|
summary = stop_memory_tracing(trace)
|
||||||
|
else:
|
||||||
|
summary = None
|
||||||
|
|
||||||
|
if self.args.n_gpu > 0:
|
||||||
|
# gpu
|
||||||
|
if hasattr(torch.cuda, "max_memory_reserved"):
|
||||||
|
memory = Memory(torch.cuda.max_memory_reserved())
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
|
||||||
|
)
|
||||||
|
memory = Memory(torch.cuda.max_memory_cached())
|
||||||
|
else:
|
||||||
|
# cpu
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
except (ImportError):
|
||||||
|
logger.warning(
|
||||||
|
"Psutil not installed, we won't log CPU memory usage. "
|
||||||
|
"Install psutil (pip install psutil) to use CPU memory tracing."
|
||||||
|
)
|
||||||
|
memory = "N/A"
|
||||||
|
else:
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
memory = Memory(process.memory_info().rss)
|
||||||
|
|
||||||
|
return memory, summary
|
||||||
|
else:
|
||||||
|
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
||||||
|
runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,)
|
||||||
return min(runtimes) / 10.0
|
return min(runtimes) / 10.0
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|||||||
@@ -61,6 +61,12 @@ class BenchmarkArguments:
|
|||||||
save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
|
save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
|
||||||
log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
|
log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
|
||||||
no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
|
no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
|
||||||
|
with_lm_head: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use model with its language model head (MODEL_WITH_LM_HEAD_MAPPING instead of MODEL_MAPPING)"
|
||||||
|
},
|
||||||
|
)
|
||||||
inference_time_csv_file: str = field(
|
inference_time_csv_file: str = field(
|
||||||
default=f"inference_time_{round(time())}.csv",
|
default=f"inference_time_{round(time())}.csv",
|
||||||
metadata={"help": "CSV filename used if saving time results to csv."},
|
metadata={"help": "CSV filename used if saving time results to csv."},
|
||||||
|
|||||||
@@ -36,7 +36,15 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|||||||
_is_memory_tracing_enabled = False
|
_is_memory_tracing_enabled = False
|
||||||
|
|
||||||
BenchmarkOutput = namedtuple(
|
BenchmarkOutput = namedtuple(
|
||||||
"BenchmarkOutput", ["time_inference_result", "memory_inference_result", "time_train_result", "memory_train_result"]
|
"BenchmarkOutput",
|
||||||
|
[
|
||||||
|
"time_inference_result",
|
||||||
|
"memory_inference_result",
|
||||||
|
"time_train_result",
|
||||||
|
"memory_train_result",
|
||||||
|
"inference_summary",
|
||||||
|
"train_summary",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -401,15 +409,10 @@ class Benchmark(ABC):
|
|||||||
def print_fn(self):
|
def print_fn(self):
|
||||||
if self._print_fn is None:
|
if self._print_fn is None:
|
||||||
if self.args.log_print:
|
if self.args.log_print:
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.DEBUG,
|
|
||||||
filename=self.args.log_filename,
|
|
||||||
filemode="a+",
|
|
||||||
format="%(asctime)-15s %(levelname)-8s %(message)s",
|
|
||||||
)
|
|
||||||
|
|
||||||
def print_and_log(*args):
|
def print_and_log(*args):
|
||||||
logging.info(*args)
|
with open(self.args.log_filename, "a") as log_file:
|
||||||
|
log_file.write(str(*args) + "\n")
|
||||||
print(*args)
|
print(*args)
|
||||||
|
|
||||||
self._print_fn = print_and_log
|
self._print_fn = print_and_log
|
||||||
@@ -454,11 +457,15 @@ class Benchmark(ABC):
|
|||||||
train_result_time[model_name] = copy.deepcopy(model_dict)
|
train_result_time[model_name] = copy.deepcopy(model_dict)
|
||||||
train_result_memory[model_name] = copy.deepcopy(model_dict)
|
train_result_memory[model_name] = copy.deepcopy(model_dict)
|
||||||
|
|
||||||
|
inference_summary = train_summary = None
|
||||||
|
|
||||||
for batch_size in self.args.batch_sizes:
|
for batch_size in self.args.batch_sizes:
|
||||||
for sequence_length in self.args.sequence_lengths:
|
for sequence_length in self.args.sequence_lengths:
|
||||||
if not self.args.no_inference:
|
if not self.args.no_inference:
|
||||||
if not self.args.no_memory:
|
if not self.args.no_memory:
|
||||||
memory = self.inference(model_name, batch_size, sequence_length, trace_memory=True)
|
memory, inference_summary = self.inference(
|
||||||
|
model_name, batch_size, sequence_length, trace_memory=True
|
||||||
|
)
|
||||||
inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
|
inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
|
||||||
if not self.args.no_speed:
|
if not self.args.no_speed:
|
||||||
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
|
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
|
||||||
@@ -466,7 +473,9 @@ class Benchmark(ABC):
|
|||||||
|
|
||||||
if self.args.training:
|
if self.args.training:
|
||||||
if not self.args.no_memory:
|
if not self.args.no_memory:
|
||||||
memory = self.train(model_name, batch_size, sequence_length, trace_memory=True)
|
memory, train_summary = self.train(
|
||||||
|
model_name, batch_size, sequence_length, trace_memory=True
|
||||||
|
)
|
||||||
train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
|
train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
|
||||||
if not self.args.no_speed:
|
if not self.args.no_speed:
|
||||||
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
|
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
|
||||||
@@ -483,6 +492,10 @@ class Benchmark(ABC):
|
|||||||
self.print_results(inference_result_memory)
|
self.print_results(inference_result_memory)
|
||||||
self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)
|
self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)
|
||||||
|
|
||||||
|
if self.args.trace_memory_line_by_line:
|
||||||
|
self.print_fn("======= INFERENCE - MEMORY LINE BY LINE TRACE - SUMMARY =======")
|
||||||
|
self.print_memory_trace_statistics(inference_summary)
|
||||||
|
|
||||||
if self.args.training:
|
if self.args.training:
|
||||||
if not self.args.no_speed:
|
if not self.args.no_speed:
|
||||||
self.print_fn("======= TRAIN - SPEED - RESULT =======")
|
self.print_fn("======= TRAIN - SPEED - RESULT =======")
|
||||||
@@ -494,6 +507,10 @@ class Benchmark(ABC):
|
|||||||
self.print_results(train_result_memory)
|
self.print_results(train_result_memory)
|
||||||
self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)
|
self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)
|
||||||
|
|
||||||
|
if self.args.trace_memory_line_by_line:
|
||||||
|
self.print_fn("======= TRAIN - MEMORY LINE BY LINE TRACE - SUMMARY =======")
|
||||||
|
self.print_memory_trace_statistics(train_summary)
|
||||||
|
|
||||||
if not self.args.no_env_print:
|
if not self.args.no_env_print:
|
||||||
self.print_fn("\n======== ENVIRONMENT - INFORMATION ========")
|
self.print_fn("\n======== ENVIRONMENT - INFORMATION ========")
|
||||||
self.print_fn(
|
self.print_fn(
|
||||||
@@ -506,7 +523,14 @@ class Benchmark(ABC):
|
|||||||
for key, value in self.environment_info.items():
|
for key, value in self.environment_info.items():
|
||||||
writer.writerow([key, value])
|
writer.writerow([key, value])
|
||||||
|
|
||||||
return BenchmarkOutput(inference_result_time, inference_result_memory, train_result_time, train_result_memory)
|
return BenchmarkOutput(
|
||||||
|
inference_result_time,
|
||||||
|
inference_result_memory,
|
||||||
|
train_result_time,
|
||||||
|
train_result_memory,
|
||||||
|
inference_summary,
|
||||||
|
train_summary,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def environment_info(self):
|
def environment_info(self):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers import GPT2Config, is_torch_available
|
from transformers import AutoConfig, is_torch_available
|
||||||
|
|
||||||
from .utils import require_torch
|
from .utils import require_torch
|
||||||
|
|
||||||
@@ -45,7 +45,18 @@ class BenchmarkTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_inference_with_configs(self):
|
def test_inference_with_configs(self):
|
||||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||||
config = GPT2Config.from_pretrained(MODEL_ID)
|
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||||
|
benchmark_args = PyTorchBenchmarkArguments(
|
||||||
|
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
|
||||||
|
)
|
||||||
|
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
|
||||||
|
results = benchmark.run()
|
||||||
|
self.check_results_dict_not_empty(results.time_inference_result)
|
||||||
|
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||||
|
|
||||||
|
def test_inference_encoder_decoder_with_configs(self):
|
||||||
|
MODEL_ID = "sshleifer/tinier_bart"
|
||||||
|
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||||
benchmark_args = PyTorchBenchmarkArguments(
|
benchmark_args = PyTorchBenchmarkArguments(
|
||||||
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
|
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
|
||||||
)
|
)
|
||||||
@@ -56,7 +67,18 @@ class BenchmarkTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_train_with_configs(self):
|
def test_train_with_configs(self):
|
||||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||||
config = GPT2Config.from_pretrained(MODEL_ID)
|
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||||
|
benchmark_args = PyTorchBenchmarkArguments(
|
||||||
|
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
|
||||||
|
)
|
||||||
|
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
|
||||||
|
results = benchmark.run()
|
||||||
|
self.check_results_dict_not_empty(results.time_train_result)
|
||||||
|
self.check_results_dict_not_empty(results.memory_train_result)
|
||||||
|
|
||||||
|
def test_train_encoder_decoder_with_configs(self):
|
||||||
|
MODEL_ID = "sshleifer/tinier_bart"
|
||||||
|
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||||
benchmark_args = PyTorchBenchmarkArguments(
|
benchmark_args = PyTorchBenchmarkArguments(
|
||||||
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
|
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
|
||||||
)
|
)
|
||||||
@@ -88,3 +110,29 @@ class BenchmarkTest(unittest.TestCase):
|
|||||||
self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
|
self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
|
||||||
self.assertTrue(Path(os.path.join(tmp_dir, "train_mem.csv")).exists())
|
self.assertTrue(Path(os.path.join(tmp_dir, "train_mem.csv")).exists())
|
||||||
self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists())
|
self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists())
|
||||||
|
|
||||||
|
def test_trace_memory(self):
|
||||||
|
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||||
|
|
||||||
|
def _check_summary_is_not_empty(summary):
|
||||||
|
self.assertTrue(hasattr(summary, "sequential"))
|
||||||
|
self.assertTrue(hasattr(summary, "cumulative"))
|
||||||
|
self.assertTrue(hasattr(summary, "current"))
|
||||||
|
self.assertTrue(hasattr(summary, "total"))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
benchmark_args = PyTorchBenchmarkArguments(
|
||||||
|
models=[MODEL_ID],
|
||||||
|
training=True,
|
||||||
|
no_inference=False,
|
||||||
|
sequence_lengths=[8],
|
||||||
|
batch_sizes=[1],
|
||||||
|
log_filename=os.path.join(tmp_dir, "log.txt"),
|
||||||
|
log_print=True,
|
||||||
|
trace_memory_line_by_line=True,
|
||||||
|
)
|
||||||
|
benchmark = PyTorchBenchmark(benchmark_args)
|
||||||
|
result = benchmark.run()
|
||||||
|
_check_summary_is_not_empty(result.inference_summary)
|
||||||
|
_check_summary_is_not_empty(result.train_summary)
|
||||||
|
self.assertTrue(Path(os.path.join(tmp_dir, "log.txt")).exists())
|
||||||
|
|||||||
Reference in New Issue
Block a user