From 2cf42d7c5864f0bdd0d81b30e68e1dfabd61e0a0 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Mon, 29 Apr 2024 11:48:55 +0200 Subject: [PATCH] Cuda sign-of-life established! Single-threaded. Cuda execution logging happens in the main log file even if there are multiple Cuda devices. (It's not "multicore".) --- arrayjit/lib/cuda_backend.cudajit.ml | 49 +++++----------------------- arrayjit/lib/gccjit_backend.ml | 36 +++----------------- arrayjit/lib/utils.ml | 42 ++++++++++++++++++++++++ bin/micrograd_basic.ml | 8 +++-- 4 files changed, 60 insertions(+), 75 deletions(-) diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index 39f09fc5..b918e81f 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -299,8 +299,7 @@ let compile_main ~traced_store info ppf llc : unit = let run_ptr_debug = get_run_ptr_debug array in let run_ptr = get_run_ptr array in let offset = (idcs, array.dims) in - (* FIXME: does this work or should \n be \\n? *) - let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\n" in + let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\\n" in fprintf ppf "@ @[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(\"%%d: %s\", log_id);@ \ printf(@[\"%%d: %s[%%u] = %%f = %s\\n\"@], log_id,@ %a,@ %s[%a]%a);@ @]}" @@ -413,7 +412,7 @@ type code = { } [@@deriving sexp_of] -let%debug_sexp compile_func ~name idx_params (traced_store, llc) = +let%track_sexp compile_func ~name idx_params (traced_store, llc) = [%log "generating the .cu source"]; let info = { info_arrays = Map.empty (module Tn); used_tensors = Hash_set.create (module Tn) } in let b = Buffer.create 4096 in @@ -471,9 +470,8 @@ extern "C" __global__ void %{name}(%{String.concat ~sep:", " @@ log_id @ idx_par Stdio.Out_channel.close oc); [%log "compiling to PTX"]; let module Cu = Cudajit in - let ptx = - Cu.compile_to_ptx ~cu_src ~name ~options:[ "--use_fast_math" ] ~with_debug:Utils.settings.with_debug - in + let with_debug = Utils.settings.output_debug_files_in_run_directory || Utils.settings.with_debug in + let ptx = Cu.compile_to_ptx ~cu_src ~name ~options:[ "--use_fast_math" ] ~with_debug in if Utils.settings.output_debug_files_in_run_directory then ( let f_name = name ^ "-cudajit-debug" in let oc = Out_channel.open_text @@ f_name ^ ".ptx" in @@ -503,10 +501,6 @@ let%diagn_sexp link_func (old_context : context) ~name info ptx = [%log "compilation finished"]; (func, global_arrays, run_module) -let header_sep = - let open Re in - compile (seq [ str " "; opt any; str "="; str " " ]) - let compile ?name bindings ((_, llc) as compiled : compiled) = let name : string = Option.value_or_thunk name ~default:(fun () -> Low_level.extract_block_name [ llc ]) in let idx_params = Indexing.bound_symbols bindings in @@ -529,7 +523,7 @@ let link old_context code = let%diagn_sexp schedule () = let log_id = get_global_run_id () in let log_id_prefix = Int.to_string log_id ^ ": " in - [%log_result "Scheduling", code.name, old_context.label, (log_id : int)]; + [%log_result "Scheduling", code.name, context.label, (log_id : int)]; let module Cu = Cudajit in let log_arg = if Utils.settings.debug_log_from_routines then [ Cu.Int log_id ] else [] in let idx_args = @@ -571,38 +565,11 @@ let link old_context code = @@ log_arg @ idx_args @ args; [%log "kernel launched"]; if Utils.settings.debug_log_from_routines then - (* FIXME: move this to a shared location, like Assignments or Utils. *) - let rec loop = function - | [] -> [] - | line :: more when String.is_empty line -> loop more - | "COMMENT: end" :: more -> more - | comment :: more when String.is_prefix comment ~prefix:"COMMENT: " -> - let more = - [%log_entry - String.chop_prefix_exn ~prefix:"COMMENT: " comment; - loop more] - in - loop more - | source :: trace :: more when String.is_prefix source ~prefix:"# " -> - (let source = - String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source - in - match Utils.split_with_seps header_sep trace with - | [] | [ "" ] -> [%log source] - | header1 :: assign1 :: header2 :: body -> - let header = String.concat [ header1; assign1; header2 ] in - let body = String.concat body in - let _message = Sexp.(List [ Atom header; Atom source; Atom body ]) in - [%log (_message : Sexp.t)] - | _ -> [%log source, trace]); - loop more - | _line :: more -> - [%log _line]; - loop more - in let postprocess_logs ~output = let output = List.filter_map output ~f:(String.chop_prefix ~prefix:log_id_prefix) in - assert (List.is_empty @@ loop output) + [%log_entry + context.label; + Utils.log_trace_tree _debug_runtime output] in context.device.postprocess_queue <- (context, postprocess_logs) :: context.device.postprocess_queue in diff --git a/arrayjit/lib/gccjit_backend.ml b/arrayjit/lib/gccjit_backend.ml index 76b293bd..18da4022 100644 --- a/arrayjit/lib/gccjit_backend.ml +++ b/arrayjit/lib/gccjit_backend.ml @@ -263,7 +263,8 @@ let debug_log_index ctx log_functions = Block.eval block @@ RValue.call ctx ff [ lf ] | _ -> fun _block _i _index -> () -let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func initial_block (body : Low_level.t) = +let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func initial_block + (body : Low_level.t) = let open Gccjit in let c_int = Type.get ctx Type.Int in let c_index = c_int in @@ -394,7 +395,8 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini | Low_level.Seq (c1, c2) -> loop c1; loop c2 - | For_loop { index; from_; to_; body; trace_it = _ } -> loop_for_loop ~toplevel ~env index ~from_ ~to_ body + | For_loop { index; from_; to_; body; trace_it = _ } -> + loop_for_loop ~toplevel ~env index ~from_ ~to_ body | Set { llv = Binop (Arg2, Get (_, _), _); _ } -> assert false | Set { array; idcs; llv = Binop (op, Get (array2, idcs2), c2); debug } when Tn.equal array array2 && [%equal: Indexing.axis_index array] idcs idcs2 && is_builtin_op op -> @@ -712,35 +714,7 @@ let%track_sexp link_compiled (old_context : context) (code : routine) : context [%log_result name]; callback (); if Utils.settings.debug_log_from_routines then ( - let rec loop = function - | [] -> [] - | line :: more when String.is_empty line -> loop more - | "COMMENT: end" :: more -> more - | comment :: more when String.is_prefix comment ~prefix:"COMMENT: " -> - let more = - [%log_entry - String.chop_prefix_exn ~prefix:"COMMENT: " comment; - loop more] - in - loop more - | source :: trace :: more when String.is_prefix source ~prefix:"# " -> - (let source = - String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source - in - match Utils.split_with_seps header_sep trace with - | [] | [ "" ] -> [%log source] - | header1 :: assign1 :: header2 :: body -> - let header = String.concat [ header1; assign1; header2 ] in - let body = String.concat body in - let _message = Sexp.(List [ Atom header; Atom source; Atom body ]) in - [%log (_message : Sexp.t)] - | _ -> [%log source, trace]); - loop more - | _line :: more -> - [%log _line]; - loop more - in - assert (List.is_empty @@ loop (Stdio.In_channel.read_lines log_file_name)); + Utils.log_trace_tree _debug_runtime (Stdio.In_channel.read_lines log_file_name); Stdlib.Sys.remove log_file_name) in Tn.Work work diff --git a/arrayjit/lib/utils.ml b/arrayjit/lib/utils.ml index 941be3e5..e7aadeb8 100644 --- a/arrayjit/lib/utils.ml +++ b/arrayjit/lib/utils.ml @@ -356,3 +356,45 @@ let get_debug_formatter ~fname = else None exception User_error of string + +let header_sep = + let open Re in + compile (seq [ str " "; opt any; str "="; str " " ]) + +let%diagn_rt_sexp log_trace_tree logs = + let rec loop = function + | [] -> [] + | line :: more when String.is_empty line -> loop more + | "COMMENT: end" :: more -> more + | comment :: more when String.is_prefix comment ~prefix:"COMMENT: " -> + let more = + [%log_entry + String.chop_prefix_exn ~prefix:"COMMENT: " comment; + loop more] + in + loop more + | source :: trace :: more when String.is_prefix source ~prefix:"# " -> + (let source = + String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source + in + match split_with_seps header_sep trace with + | [] | [ "" ] -> [%log source] + | header1 :: assign1 :: header2 :: body -> + let header = String.concat [ header1; assign1; header2 ] in + let body = String.concat body in + let _message = Sexp.(List [ Atom header; Atom source; Atom body ]) in + [%log (_message : Sexp.t)] + | _ -> [%log source, trace]); + loop more + | _line :: more -> + [%log _line]; + loop more + in + let rec loop_logs logs = + let output = loop logs in + if not (List.is_empty output) then + [%log_entry + "TRAILING LOGS:"; + loop_logs output] + in + loop_logs logs diff --git a/bin/micrograd_basic.ml b/bin/micrograd_basic.ml index ead7cb3b..468a7974 100644 --- a/bin/micrograd_basic.ml +++ b/bin/micrograd_basic.ml @@ -9,19 +9,21 @@ module Rand = Arrayjit.Rand.Lib module Debug_runtime = Utils.Debug_runtime let%diagn_sexp () = - let module Backend = (val Train.fresh_backend ()) in + let module Backend = (val Train.fresh_backend ~backend_name:"cuda" ()) in let device = Backend.get_device ~ordinal:0 in let ctx = Backend.init device in Utils.settings.output_debug_files_in_run_directory <- true; + Utils.settings.debug_log_from_routines <- true; + Utils.settings.with_debug <- true; Rand.init 0; let%op c = "a" [ -4 ] + "b" [ 2 ] in let%op d = c + c + 1 in (* let%op c = c + 1 + c + ~-a in *) (* Uncomment just the first "fully on host" line to see which arrays can be virtual, and just the second line to see the intermediate computation values. *) - Train.every_non_literal_on_host c; + Train.every_non_literal_on_host d; (* List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff ]; *) - let update = Train.grad_update c in + let update = Train.grad_update d in let routine = Backend.(link ctx @@ compile IDX.empty update.fwd_bprop) in Train.sync_run (module Backend) routine d; Tensor.print_tree ~with_grad:true ~depth:9 d;