Skip to content

Commit

Permalink
fixed non-Metal compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisChourakiSonos committed Dec 9, 2024
1 parent c33b266 commit ccdd234
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 68 deletions.
3 changes: 1 addition & 2 deletions api/rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ impl ModelInterface for Model {
&PlanOptions::default(),
&inputs,
None,
true,
false
true
)?;
};
let export = tract_libcli::export::GraphPerfInfo::from(&self.0, &annotations);
Expand Down
29 changes: 18 additions & 11 deletions cli/src/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,29 @@ pub fn handle(
let model =
params.tract_model.downcast_ref::<TypedModel>().context("Can only bench TypedModel")?;
let inputs = tract_libcli::tensor::retrieve_or_make_inputs(model, &run_params)?.remove(0);
let mut plan = SimplePlan::new_with_options(model, &plan_options)?;
let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?;

limits.warmup(model, &inputs)?;

#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let session_handler = tract_metal::MetalSessionHandler::from_plan(
&plan,
&state.session_state.resolved_symbols,
let mut state = {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let mut plan = SimplePlan::new_with_options(model, &plan_options)?;
let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?;

let session_handler = tract_metal::MetalSessionHandler::from_plan(
&plan,
&state.session_state.resolved_symbols,
)?;

plan = plan.with_session_handler(session_handler);
}

let mut state = SimpleState::new(Arc::new(plan))?;
plan = plan.with_session_handler(session_handler);
SimpleState::new(Arc::new(plan))?
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
let plan = SimplePlan::new_with_options(model, &plan_options)?;
SimpleState::new(Arc::new(plan))?
}
};
let mut iters = 0;

let progress = probe.and_then(|m| m.get_i64("progress"));
Expand Down
37 changes: 27 additions & 10 deletions cli/src/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,33 @@ pub fn handle(
.downcast_ref::<TypedModel>()
.context("Can only profile typed models")?;
let inputs = retrieve_or_make_inputs(model, &run_params)?;
tract_libcli::profile::profile(
model,
bench_limits,
&mut annotations,
&plan_options,
&inputs[0],
None,
options.folded,
matches.is_present("metal")
)?;

if matches.is_present("metal") {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
tract_libcli::profile::profile_metal(
model,
bench_limits,
&mut annotations,
&plan_options,
&inputs[0],
)?;
}
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
bail!("Metal profiling called on non-Metal device");
}
} else {
tract_libcli::profile::profile(
model,
bench_limits,
&mut annotations,
&plan_options,
&inputs[0],
None,
options.folded,
)?;
}
}

if sub_matches.is_present("axes") || sub_matches.is_present("axes-names") {
Expand Down
109 changes: 64 additions & 45 deletions libcli/src/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ impl BenchLimits {
}
}

#[allow(clippy::too_many_arguments)]
pub fn profile(
model: &TypedModel,
bench_limits: &BenchLimits,
Expand All @@ -58,66 +57,86 @@ pub fn profile(
inputs: &TVec<TValue>,
custom_profiler: Option<HashMap<TypeId, Profiler>>,
folded: bool,
is_metal: bool,
) -> TractResult<()> {
info!("Running entire network");
let mut iters = 0usize;
let prefix = tvec!();

bench_limits.warmup(model, inputs)?;

let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?;
let plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
let mut state = TypedSimpleState::new(Arc::new(plan))?;

#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let session_handler = tract_metal::MetalSessionHandler::from_plan(
&plan,
&state.session_state.resolved_symbols,
)?;

plan = plan.with_session_handler(session_handler);
let start = crate::time::now();
let mut time_accounted_by_inner_nodes = Duration::default();
while iters < bench_limits.max_loops && start.elapsed() < bench_limits.max_time {
rec_profiler(
&mut state,
dg,
inputs,
custom_profiler.as_ref(),
&prefix,
None,
&mut time_accounted_by_inner_nodes,
folded,
)?;

iters += 1;
}

let mut state = TypedSimpleState::new(Arc::new(plan))?;
let entire = start.elapsed() - time_accounted_by_inner_nodes;

let entire = if !is_metal {
let start = crate::time::now();
let mut time_accounted_by_inner_nodes = Duration::default();
while iters < bench_limits.max_loops && start.elapsed() < bench_limits.max_time {
rec_profiler(
&mut state,
dg,
inputs,
custom_profiler.as_ref(),
&prefix,
None,
&mut time_accounted_by_inner_nodes,
folded,
)?;
info!("Running {} iterations max. for each node.", bench_limits.max_loops);
info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());

iters += 1;
let denum = (iters as f32).recip();
let entire = entire.mul_f32(denum);
for d in dg.tags.values_mut() {
if let Some(d) = d.profile.as_mut() {
*d = d.mul_f32(denum);
}

start.elapsed() - time_accounted_by_inner_nodes
} else {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let mut entire = Duration::default();
while iters < bench_limits.max_loops && entire < bench_limits.max_time {
println!("Running iter {iters}");
entire += rec_profiler_metal(&mut state, dg, inputs, &prefix)?.1;
if let Some(d) = d.accelerator_profile.as_mut() {
*d = d.mul_f32(denum);
}
}
let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
Ok(())
}

#[cfg(any(target_os = "macos", target_os = "ios"))]
pub fn profile_metal(
model: &TypedModel,
bench_limits: &BenchLimits,
dg: &mut Annotations,
plan_options: &PlanOptions,
inputs: &TVec<TValue>,
) -> TractResult<()> {
info!("Running entire network");
let mut iters = 0usize;
let prefix = tvec!();

bench_limits.warmup(model, inputs)?;

iters += 1;
}
let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?;

entire
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
bail!("Metal Profiling on non-Metal Device");
}
};
let session_handler =
tract_metal::MetalSessionHandler::from_plan(&plan, &state.session_state.resolved_symbols)?;

plan = plan.with_session_handler(session_handler);

let mut state = TypedSimpleState::new(Arc::new(plan))?;

let mut entire = Duration::default();
while iters < bench_limits.max_loops && entire < bench_limits.max_time {
entire += rec_profiler_metal(&mut state, dg, inputs, &prefix)?.1;

iters += 1;
}

info!("Running {} iterations max. for each node.", bench_limits.max_loops);
info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
Expand Down

0 comments on commit ccdd234

Please sign in to comment.