diff --git a/run.py b/run.py
index 8b69e0544a..2ff7db533a 100644
--- a/run.py
+++ b/run.py
@@ -103,6 +103,8 @@ def printResultSummaryTime(result_summary, model, metrics_needed=[], flops_model
             flops = model.get_flops()
             tflops = flops / (cpu_walltime / 1.0e3) / 1.0e12
         print('{:<20} {:>20}'.format("GPU FLOPS:", "%.4f TFLOPs per second" % tflops, sep=''))
+    if 'ttfb' in metrics_needed:
+        print('{:<20} {:>20}'.format("Time to first batch:", "%.4f ms" % model.ttfb, sep=''))
     if model_flops is not None:
         tflops = model_flops / (cpu_walltime / 1.0e3) / 1.0e12
         print('{:<20} {:>20}'.format("Model Flops:", "%.4f TFLOPs per second" % tflops, sep=''))
@@ -356,8 +358,8 @@ def _validate_profile_options(profile_options: str):
     parser.add_argument(
         "--metrics",
         type=str,
-        default="cpu_peak_mem,gpu_peak_mem",
-        help="Specify metrics [cpu_peak_mem,gpu_peak_mem,flops,model_flops]to be collected. You can also set `none` to disable all metrics. The metrics are separated by comma such as cpu_peak_mem,gpu_peak_mem.",
+        default="cpu_peak_mem,gpu_peak_mem,ttfb",
+        help="Specify metrics [cpu_peak_mem,gpu_peak_mem,ttfb,flops,model_flops]to be collected. You can also set `none` to disable all metrics. The metrics are separated by comma such as cpu_peak_mem,gpu_peak_mem.",
     )
     parser.add_argument(
         "--metrics-gpu-backend",
diff --git a/torchbenchmark/util/experiment/metrics.py b/torchbenchmark/util/experiment/metrics.py
index b5e910a9e5..881e9c7ef2 100644
--- a/torchbenchmark/util/experiment/metrics.py
+++ b/torchbenchmark/util/experiment/metrics.py
@@ -20,6 +20,7 @@ class TorchBenchModelMetrics:
     throughputs: List[float]
     cpu_peak_mem: Optional[float]
     gpu_peak_mem: Optional[float]
+    ttfb: Optional[float] # time-to-first-batch
     pt2_compilation_time: Optional[float]
     pt2_graph_breaks: Optional[float]
     model_flops: Optional[float]
@@ -112,6 +113,7 @@ def get_model_test_metrics(model: Union[BenchmarkModel, ModelTask], metrics=[],
     throughputs = None
     cpu_peak_mem = None
     gpu_peak_mem = None
+    ttfb = None
     pt2_compilation_time = None
     pt2_graph_breaks = None
     model_flops = None
@@ -133,7 +135,10 @@ def get_model_test_metrics(model: Union[BenchmarkModel, ModelTask], metrics=[],
             if isinstance(model, ModelTask) else model.pt2_graph_breaks
     if 'model_flops' in metrics:
         model_flops = get_model_flops(model)
-    return TorchBenchModelMetrics(latencies, throughputs, cpu_peak_mem, gpu_peak_mem, pt2_compilation_time, pt2_graph_breaks, model_flops)
+    if 'ttfb' in metrics:
+        ttfb = model.get_model_attribute('ttfb') \
+            if isinstance(model, ModelTask) else model.ttfb
+    return TorchBenchModelMetrics(latencies, throughputs, cpu_peak_mem, gpu_peak_mem, ttfb, pt2_compilation_time, pt2_graph_breaks, model_flops)
 
 def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True) -> str:
     import copy
diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
index 0029737169..2f4a008610 100644
--- a/torchbenchmark/util/model.py
+++ b/torchbenchmark/util/model.py
@@ -3,6 +3,7 @@
 from contextlib import contextmanager, ExitStack
 import warnings
 import yaml
+import time
 from pathlib import Path
 from typing import ContextManager, Optional, List, Tuple, Generator
 from torchbenchmark import REPO_PATH
@@ -79,6 +80,7 @@ class BenchmarkModel(metaclass=PostInitProcessor):
     See [Adding Models](#../models/ADDING_MODELS.md)
     """
     def __init__(self, test: str, device: str, batch_size: Optional[int]=None, extra_args: List[str]=[]):
+        self._start_init_time = time.perf_counter_ns()  # monotonic clock: immune to wall-clock adjustments; only differences are meaningful
         self.metadata = self._load_metadata()
         self.test = test
         # sanity checks of the options
@@ -149,6 +151,7 @@ def __post__init__(self):
         # Need to clean up the cache because we run deep copy within correceness check
         if self.device == "cuda":
             torch.cuda.empty_cache()
+        self._end_init_time = time.perf_counter_ns()  # end mark paired with _start_init_time (same clock)
 
     def _skip_by_device_name(self):
         if not self.device == "cuda":
@@ -390,3 +393,8 @@ def pt2_graph_breaks(self):
         from torch._dynamo.utils import counters
         num_graph_breaks = len(counters["graph_break"].keys())
         return num_graph_breaks
+
+    @property
+    def ttfb(self):
+        """Return the time-to-first-batch in ms: elapsed time from the start of __init__ to the end of __post__init__."""
+        return (self._end_init_time - self._start_init_time) / 1_000_000