Skip to content

Commit

Permalink
[AutoPGLE] Explicitly disable command buffers when profiler is used.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709475833
  • Loading branch information
Google-ML-Automation committed Dec 25, 2024
1 parent 64511a1 commit b6aead6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 8 additions & 0 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ def get_compile_options(
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value

# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
if env_options_overrides is None:
env_options_overrides = {}
env_options_overrides['xla_gpu_enable_command_buffer'] = ''

if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [
Expand Down
8 changes: 2 additions & 6 deletions tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def testPGLEProfilerGetFDOProfileLarge(self):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
# TODO(b/37664749): Remove this flag once the bug is fixed.
'xla_gpu_enable_command_buffer': '',
},
)
def f(x):
Expand Down Expand Up @@ -133,8 +133,6 @@ def testAutoPgle(self):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},
Expand Down Expand Up @@ -217,8 +215,6 @@ def testAutoPgleWithPersistentCache(self):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},
Expand Down

0 comments on commit b6aead6

Please sign in to comment.