Skip to content

Commit

Permalink
Cuda sign-of-life established! Single-threaded.
Browse files Browse the repository at this point in the history
Cuda execution logging happens in the main log file even if there are multiple Cuda devices.
(It's not "multicore".)
  • Loading branch information
lukstafi committed Apr 29, 2024
1 parent 97417db commit 2cf42d7
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 75 deletions.
49 changes: 8 additions & 41 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,7 @@ let compile_main ~traced_store info ppf llc : unit =
let run_ptr_debug = get_run_ptr_debug array in
let run_ptr = get_run_ptr array in
let offset = (idcs, array.dims) in
(* FIXME: does this work or should \n be \\n? *)
let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\n" in
let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\\n" in
fprintf ppf
"@ @[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(\"%%d: %s\", log_id);@ \
printf(@[<h>\"%%d: %s[%%u] = %%f = %s\\n\"@], log_id,@ %a,@ %s[%a]%a);@ @]}"
Expand Down Expand Up @@ -413,7 +412,7 @@ type code = {
}
[@@deriving sexp_of]

let%debug_sexp compile_func ~name idx_params (traced_store, llc) =
let%track_sexp compile_func ~name idx_params (traced_store, llc) =
[%log "generating the .cu source"];
let info = { info_arrays = Map.empty (module Tn); used_tensors = Hash_set.create (module Tn) } in
let b = Buffer.create 4096 in
Expand Down Expand Up @@ -471,9 +470,8 @@ extern "C" __global__ void %{name}(%{String.concat ~sep:", " @@ log_id @ idx_par
Stdio.Out_channel.close oc);
[%log "compiling to PTX"];
let module Cu = Cudajit in
let ptx =
Cu.compile_to_ptx ~cu_src ~name ~options:[ "--use_fast_math" ] ~with_debug:Utils.settings.with_debug
in
let with_debug = Utils.settings.output_debug_files_in_run_directory || Utils.settings.with_debug in
let ptx = Cu.compile_to_ptx ~cu_src ~name ~options:[ "--use_fast_math" ] ~with_debug in
if Utils.settings.output_debug_files_in_run_directory then (
let f_name = name ^ "-cudajit-debug" in
let oc = Out_channel.open_text @@ f_name ^ ".ptx" in
Expand Down Expand Up @@ -503,10 +501,6 @@ let%diagn_sexp link_func (old_context : context) ~name info ptx =
[%log "compilation finished"];
(func, global_arrays, run_module)

let header_sep =
let open Re in
compile (seq [ str " "; opt any; str "="; str " " ])

let compile ?name bindings ((_, llc) as compiled : compiled) =
let name : string = Option.value_or_thunk name ~default:(fun () -> Low_level.extract_block_name [ llc ]) in
let idx_params = Indexing.bound_symbols bindings in
Expand All @@ -529,7 +523,7 @@ let link old_context code =
let%diagn_sexp schedule () =
let log_id = get_global_run_id () in
let log_id_prefix = Int.to_string log_id ^ ": " in
[%log_result "Scheduling", code.name, old_context.label, (log_id : int)];
[%log_result "Scheduling", code.name, context.label, (log_id : int)];
let module Cu = Cudajit in
let log_arg = if Utils.settings.debug_log_from_routines then [ Cu.Int log_id ] else [] in
let idx_args =
Expand Down Expand Up @@ -571,38 +565,11 @@ let link old_context code =
@@ log_arg @ idx_args @ args;
[%log "kernel launched"];
if Utils.settings.debug_log_from_routines then
(* FIXME: move this to a shared location, like Assignments or Utils. *)
let rec loop = function
| [] -> []
| line :: more when String.is_empty line -> loop more
| "COMMENT: end" :: more -> more
| comment :: more when String.is_prefix comment ~prefix:"COMMENT: " ->
let more =
[%log_entry
String.chop_prefix_exn ~prefix:"COMMENT: " comment;
loop more]
in
loop more
| source :: trace :: more when String.is_prefix source ~prefix:"# " ->
(let source =
String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source
in
match Utils.split_with_seps header_sep trace with
| [] | [ "" ] -> [%log source]
| header1 :: assign1 :: header2 :: body ->
let header = String.concat [ header1; assign1; header2 ] in
let body = String.concat body in
let _message = Sexp.(List [ Atom header; Atom source; Atom body ]) in
[%log (_message : Sexp.t)]
| _ -> [%log source, trace]);
loop more
| _line :: more ->
[%log _line];
loop more
in
let postprocess_logs ~output =
let output = List.filter_map output ~f:(String.chop_prefix ~prefix:log_id_prefix) in
assert (List.is_empty @@ loop output)
[%log_entry
context.label;
Utils.log_trace_tree _debug_runtime output]
in
context.device.postprocess_queue <- (context, postprocess_logs) :: context.device.postprocess_queue
in
Expand Down
36 changes: 5 additions & 31 deletions arrayjit/lib/gccjit_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ let debug_log_index ctx log_functions =
Block.eval block @@ RValue.call ctx ff [ lf ]
| _ -> fun _block _i _index -> ()

let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func initial_block (body : Low_level.t) =
let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func initial_block
(body : Low_level.t) =
let open Gccjit in
let c_int = Type.get ctx Type.Int in
let c_index = c_int in
Expand Down Expand Up @@ -394,7 +395,8 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
| Low_level.Seq (c1, c2) ->
loop c1;
loop c2
| For_loop { index; from_; to_; body; trace_it = _ } -> loop_for_loop ~toplevel ~env index ~from_ ~to_ body
| For_loop { index; from_; to_; body; trace_it = _ } ->
loop_for_loop ~toplevel ~env index ~from_ ~to_ body
| Set { llv = Binop (Arg2, Get (_, _), _); _ } -> assert false
| Set { array; idcs; llv = Binop (op, Get (array2, idcs2), c2); debug }
when Tn.equal array array2 && [%equal: Indexing.axis_index array] idcs idcs2 && is_builtin_op op ->
Expand Down Expand Up @@ -712,35 +714,7 @@ let%track_sexp link_compiled (old_context : context) (code : routine) : context
[%log_result name];
callback ();
if Utils.settings.debug_log_from_routines then (
let rec loop = function
| [] -> []
| line :: more when String.is_empty line -> loop more
| "COMMENT: end" :: more -> more
| comment :: more when String.is_prefix comment ~prefix:"COMMENT: " ->
let more =
[%log_entry
String.chop_prefix_exn ~prefix:"COMMENT: " comment;
loop more]
in
loop more
| source :: trace :: more when String.is_prefix source ~prefix:"# " ->
(let source =
String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source
in
match Utils.split_with_seps header_sep trace with
| [] | [ "" ] -> [%log source]
| header1 :: assign1 :: header2 :: body ->
let header = String.concat [ header1; assign1; header2 ] in
let body = String.concat body in
let _message = Sexp.(List [ Atom header; Atom source; Atom body ]) in
[%log (_message : Sexp.t)]
| _ -> [%log source, trace]);
loop more
| _line :: more ->
[%log _line];
loop more
in
assert (List.is_empty @@ loop (Stdio.In_channel.read_lines log_file_name));
Utils.log_trace_tree _debug_runtime (Stdio.In_channel.read_lines log_file_name);
Stdlib.Sys.remove log_file_name)
in
Tn.Work work
Expand Down
42 changes: 42 additions & 0 deletions arrayjit/lib/utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,45 @@ let get_debug_formatter ~fname =
else None

exception User_error of string

let header_sep =
let open Re in
compile (seq [ str " "; opt any; str "="; str " " ])

let%diagn_rt_sexp log_trace_tree logs =
let rec loop = function
| [] -> []
| line :: more when String.is_empty line -> loop more
| "COMMENT: end" :: more -> more
| comment :: more when String.is_prefix comment ~prefix:"COMMENT: " ->
let more =
[%log_entry
String.chop_prefix_exn ~prefix:"COMMENT: " comment;
loop more]
in
loop more
| source :: trace :: more when String.is_prefix source ~prefix:"# " ->
(let source =
String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source
in
match split_with_seps header_sep trace with
| [] | [ "" ] -> [%log source]
| header1 :: assign1 :: header2 :: body ->
let header = String.concat [ header1; assign1; header2 ] in
let body = String.concat body in
let _message = Sexp.(List [ Atom header; Atom source; Atom body ]) in
[%log (_message : Sexp.t)]
| _ -> [%log source, trace]);
loop more
| _line :: more ->
[%log _line];
loop more
in
let rec loop_logs logs =
let output = loop logs in
if not (List.is_empty output) then
[%log_entry
"TRAILING LOGS:";
loop_logs output]
in
loop_logs logs
8 changes: 5 additions & 3 deletions bin/micrograd_basic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@ module Rand = Arrayjit.Rand.Lib
module Debug_runtime = Utils.Debug_runtime

let%diagn_sexp () =
let module Backend = (val Train.fresh_backend ()) in
let module Backend = (val Train.fresh_backend ~backend_name:"cuda" ()) in
let device = Backend.get_device ~ordinal:0 in
let ctx = Backend.init device in
Utils.settings.output_debug_files_in_run_directory <- true;
Utils.settings.debug_log_from_routines <- true;
Utils.settings.with_debug <- true;
Rand.init 0;
let%op c = "a" [ -4 ] + "b" [ 2 ] in
let%op d = c + c + 1 in
(* let%op c = c + 1 + c + ~-a in *)
(* Uncomment just the first "fully on host" line to see which arrays can be virtual, and just the second
line to see the intermediate computation values. *)
Train.every_non_literal_on_host c;
Train.every_non_literal_on_host d;
(* List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff ]; *)
let update = Train.grad_update c in
let update = Train.grad_update d in
let routine = Backend.(link ctx @@ compile IDX.empty update.fwd_bprop) in
Train.sync_run (module Backend) routine d;
Tensor.print_tree ~with_grad:true ~depth:9 d;
Expand Down

0 comments on commit 2cf42d7

Please sign in to comment.