diff --git a/api/rs/src/lib.rs b/api/rs/src/lib.rs index 35a955e9b9..df959ccd5e 100644 --- a/api/rs/src/lib.rs +++ b/api/rs/src/lib.rs @@ -277,8 +277,7 @@ impl ModelInterface for Model { &PlanOptions::default(), &inputs, None, - true, - false + true )?; }; let export = tract_libcli::export::GraphPerfInfo::from(&self.0, &annotations); diff --git a/cli/src/bench.rs b/cli/src/bench.rs index de9c63bafd..b6f055b660 100644 --- a/cli/src/bench.rs +++ b/cli/src/bench.rs @@ -37,22 +37,29 @@ pub fn handle( let model = params.tract_model.downcast_ref::().context("Can only bench TypedModel")?; let inputs = tract_libcli::tensor::retrieve_or_make_inputs(model, &run_params)?.remove(0); - let mut plan = SimplePlan::new_with_options(model, &plan_options)?; - let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?; limits.warmup(model, &inputs)?; - #[cfg(any(target_os = "macos", target_os = "ios"))] - { - let session_handler = tract_metal::MetalSessionHandler::from_plan( - &plan, - &state.session_state.resolved_symbols, + let mut state = { + #[cfg(any(target_os = "macos", target_os = "ios"))] + { + let mut plan = SimplePlan::new_with_options(model, &plan_options)?; + let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?; + + let session_handler = tract_metal::MetalSessionHandler::from_plan( + &plan, + &state.session_state.resolved_symbols, )?; - - plan = plan.with_session_handler(session_handler); - } - let mut state = SimpleState::new(Arc::new(plan))?; + plan = plan.with_session_handler(session_handler); + SimpleState::new(Arc::new(plan))? + } + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + { + let plan = SimplePlan::new_with_options(model, &plan_options)?; + SimpleState::new(Arc::new(plan))? + } + }; let mut iters = 0; let progress = probe.and_then(|m| m.get_i64("progress")); diff --git a/cli/src/dump.rs b/cli/src/dump.rs index 0da170f35c..750c6023ef 100644 --- a/cli/src/dump.rs +++ b/cli/src/dump.rs @@ -121,16 +121,33 @@ pub fn handle( .downcast_ref::() .context("Can only profile typed models")?; let inputs = retrieve_or_make_inputs(model, &run_params)?; - tract_libcli::profile::profile( - model, - bench_limits, - &mut annotations, - &plan_options, - &inputs[0], - None, - options.folded, - matches.is_present("metal") - )?; + + if matches.is_present("metal") { + #[cfg(any(target_os = "macos", target_os = "ios"))] + { + tract_libcli::profile::profile_metal( + model, + bench_limits, + &mut annotations, + &plan_options, + &inputs[0], + )?; + } + #[cfg(any(target_os = "macos", target_os = "ios"))] + { + bail!("Metal profiling called on non-Metal device"); + } + } else { + tract_libcli::profile::profile( + model, + bench_limits, + &mut annotations, + &plan_options, + &inputs[0], + None, + options.folded, + )?; + } } if sub_matches.is_present("axes") || sub_matches.is_present("axes-names") { diff --git a/libcli/src/profile.rs b/libcli/src/profile.rs index 8765686ba0..6d528f17fb 100644 --- a/libcli/src/profile.rs +++ b/libcli/src/profile.rs @@ -49,7 +49,6 @@ impl BenchLimits { } } -#[allow(clippy::too_many_arguments)] pub fn profile( model: &TypedModel, bench_limits: &BenchLimits, @@ -58,7 +57,6 @@ pub fn profile( inputs: &TVec, custom_profiler: Option>, folded: bool, - is_metal: bool, ) -> TractResult<()> { info!("Running entire network"); let mut iters = 0usize; @@ -66,58 +64,79 @@ pub fn profile( bench_limits.warmup(model, inputs)?; - let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?; - let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?; + let plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?; + let mut state = TypedSimpleState::new(Arc::new(plan))?; - #[cfg(any(target_os = "macos", target_os = "ios"))] - { - let session_handler = tract_metal::MetalSessionHandler::from_plan( - &plan, - &state.session_state.resolved_symbols, - )?; - - plan = plan.with_session_handler(session_handler); + let start = crate::time::now(); + let mut time_accounted_by_inner_nodes = Duration::default(); + while iters < bench_limits.max_loops && start.elapsed() < bench_limits.max_time { + rec_profiler( + &mut state, + dg, + inputs, + custom_profiler.as_ref(), + &prefix, + None, + &mut time_accounted_by_inner_nodes, + folded, + )?; + + iters += 1; } - let mut state = TypedSimpleState::new(Arc::new(plan))?; + let entire = start.elapsed() - time_accounted_by_inner_nodes; - let entire = if !is_metal { - let start = crate::time::now(); - let mut time_accounted_by_inner_nodes = Duration::default(); - while iters < bench_limits.max_loops && start.elapsed() < bench_limits.max_time { - rec_profiler( - &mut state, - dg, - inputs, - custom_profiler.as_ref(), - &prefix, - None, - &mut time_accounted_by_inner_nodes, - folded, - )?; + info!("Running {} iterations max. for each node.", bench_limits.max_loops); + info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis()); - iters += 1; + let denum = (iters as f32).recip(); + let entire = entire.mul_f32(denum); + for d in dg.tags.values_mut() { + if let Some(d) = d.profile.as_mut() { + *d = d.mul_f32(denum); } - start.elapsed() - time_accounted_by_inner_nodes - } else { - #[cfg(any(target_os = "macos", target_os = "ios"))] - { - let mut entire = Duration::default(); - while iters < bench_limits.max_loops && entire < bench_limits.max_time { - println!("Running iter {iters}"); - entire += rec_profiler_metal(&mut state, dg, inputs, &prefix)?.1; + if let Some(d) = d.accelerator_profile.as_mut() { + *d = d.mul_f32(denum); + } + } + let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap(); + let sum = dg.tags.values().filter_map(|t| t.profile).sum::(); + let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::(); + dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters }); + Ok(()) +} + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub fn profile_metal( + model: &TypedModel, + bench_limits: &BenchLimits, + dg: &mut Annotations, + plan_options: &PlanOptions, + inputs: &TVec, +) -> TractResult<()> { + info!("Running entire network"); + let mut iters = 0usize; + let prefix = tvec!(); + + bench_limits.warmup(model, inputs)?; - iters += 1; - } + let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?; + let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?; - entire - } - #[cfg(not(any(target_os = "macos", target_os = "ios")))] - { - bail!("Metal Profiling on non-Metal Device"); - } - }; + let session_handler = + tract_metal::MetalSessionHandler::from_plan(&plan, &state.session_state.resolved_symbols)?; + + plan = plan.with_session_handler(session_handler); + + let mut state = TypedSimpleState::new(Arc::new(plan))?; + + let mut entire = Duration::default(); + while iters < bench_limits.max_loops && entire < bench_limits.max_time { + entire += rec_profiler_metal(&mut state, dg, inputs, &prefix)?.1; + + iters += 1; + } info!("Running {} iterations max. for each node.", bench_limits.max_loops); info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());