Skip to content

Commit

Permalink
Backends: Get rid of is_double, use prec instead
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Apr 30, 2024
1 parent 674807f commit 713e2ec
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 94 deletions.
78 changes: 33 additions & 45 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type tn_info = {
num_typ : string;
(** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), [double]
(precision [Double]). *)
is_double : bool;
prec : Ops.prec;
zero_initialized : bool;
}
[@@deriving sexp_of]
Expand Down Expand Up @@ -200,13 +200,6 @@ let get_run_ptr_debug array =
| Some rv, _ -> "global_" ^ rv
| None, None -> assert false

let prec_to_c_type = function
| Ops.Void_prec -> "void"
| Byte_prec _ -> "uint8"
| Half_prec _ -> (* FIXME: *) "uint16"
| Single_prec _ -> "float"
| Double_prec _ -> "double"

(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim -> idx +
(offset * dim)) *)

Expand All @@ -220,12 +213,12 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key =
let dims = Lazy.force key.dims in
let size_in_elems = Array.fold ~init:1 ~f:( * ) dims in
let hosted = Lazy.force key.array in
let size_in_bytes = size_in_elems * Ops.prec_in_bytes key.prec in
let prec = key.prec in
let size_in_bytes = size_in_elems * Ops.prec_in_bytes prec in
let is_on_host = Tn.is_hosted_force key 31 in
let is_materialized = Tn.is_hosted_force key 32 in
assert (Bool.(Option.is_some hosted = is_on_host));
let is_double = Ops.is_double_prec key.prec in
let num_typ = prec_to_c_type key.prec in
let num_typ = Ops.cuda_typ_of_prec prec in
let mem = if not is_materialized then Local_only else Global in
let global = if is_local_only mem then None else Some (Tn.name key) in
let local = Option.some_if (is_local_only mem) @@ Tn.name key ^ "_local" in
Expand All @@ -237,6 +230,8 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key =
Tn.label key,
"mem",
(backend_info : Sexp.t),
"prec",
(prec : Ops.prec),
"on-host",
(is_on_host : bool),
"is-global",
Expand All @@ -245,7 +240,7 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key =
key.backend_info <- Utils.sexp_append ~elem:backend_info key.backend_info;
let zero_initialized = (Hashtbl.find_exn traced_store key).Low_level.zero_initialized in
let data =
{ hosted; local; mem; dims; size_in_bytes; size_in_elems; num_typ; is_double; global; zero_initialized }
{ hosted; local; mem; dims; size_in_bytes; size_in_elems; num_typ; prec; global; zero_initialized }
in
info.info_arrays <- Map.add_exn info.info_arrays ~key ~data;
data
Expand All @@ -254,7 +249,6 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key =

let compile_main ~traced_store info ppf llc : unit =
let open Stdlib.Format in
let locals = ref @@ Map.empty (module Low_level.Scope_id) in
let rec pp_ll ppf c : unit =
match c with
| Low_level.Noop -> ()
Expand All @@ -276,15 +270,13 @@ let compile_main ~traced_store info ppf llc : unit =
let tn = Low_level.(get_node traced_store array) in
assert tn.zero_initialized
(* The initialization will be emitted by get_array. *)
| Set { array; idcs; llv; debug } ->
let array = get_array ~traced_store info array in
let old_locals = !locals in
let loop_f = pp_float ~num_typ:array.num_typ ~is_double:array.is_double in
let loop_debug_f = debug_float ~num_typ:array.num_typ ~is_double:array.is_double in
| Set { tn; idcs; llv; debug } ->
let tn = get_array ~traced_store info tn in
let loop_f = pp_float ~num_typ:tn.num_typ tn.prec in
let loop_debug_f = debug_float ~num_typ:tn.num_typ tn.prec in
let num_closing_braces = pp_top_locals ppf llv in
(* No idea why adding any cut hint at the end of the assign line breaks formatting! *)
fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " (get_run_ptr array) pp_array_offset (idcs, array.dims) loop_f
llv;
fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " (get_run_ptr tn) pp_array_offset (idcs, tn.dims) loop_f llv;
(if Utils.settings.debug_log_from_routines then
let v_code, v_idcs = loop_debug_f llv in
let pp_args =
Expand All @@ -296,9 +288,9 @@ let compile_main ~traced_store info ppf llc : unit =
pp_comma ppf ();
pp_print_string ppf v
in
let run_ptr_debug = get_run_ptr_debug array in
let run_ptr = get_run_ptr array in
let offset = (idcs, array.dims) in
let run_ptr_debug = get_run_ptr_debug tn in
let run_ptr = get_run_ptr tn in
let offset = (idcs, tn.dims) in
let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\\n" in
fprintf ppf
"@ @[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(\"%%d: %s\", log_id);@ \
Expand All @@ -307,8 +299,7 @@ let compile_main ~traced_store info ppf llc : unit =
v_idcs);
for _ = 1 to num_closing_braces do
fprintf ppf "@]@ }@,"
done;
locals := old_locals
done
| Comment message ->
if Utils.settings.debug_log_from_routines then
fprintf ppf
Expand All @@ -317,23 +308,20 @@ let compile_main ~traced_store info ppf llc : unit =
(String.substr_replace_all ~pattern:"%" ~with_:"%%" message)
else fprintf ppf "/* %s */@ " message
| Staged_compilation callback -> callback ()
| Set_local (({ scope_id; _ } as id), value) ->
let num_typ, is_double = Map.find_exn !locals id in
let old_locals = !locals in
| Set_local (Low_level.{ scope_id; tn = { prec; _ } }, value) ->
let num_typ = Ops.cuda_typ_of_prec prec in
let num_closing_braces = pp_top_locals ppf value in
fprintf ppf "@[<2>v%d =@ %a;@]" scope_id (pp_float ~num_typ ~is_double) value;
fprintf ppf "@[<2>v%d =@ %a;@]" scope_id (pp_float ~num_typ prec) value;
for _ = 1 to num_closing_braces do
fprintf ppf "@]@ }@,"
done;
locals := old_locals
done
and pp_top_locals ppf (vcomp : Low_level.float_t) : int =
match vcomp with
| Local_scope { id = { scope_id = i; _ } as id; prec; body; orig_indices = _ } ->
let typ = prec_to_c_type prec in
| Local_scope { id = { scope_id = i; tn = { prec; _ } }; body; orig_indices = _ } ->
let num_typ = Ops.cuda_typ_of_prec prec in
(* Arrays are initialized to 0 by default. However, there is typically an explicit initialization for
virtual nodes. *)
fprintf ppf "@[<2>{@ %s v%d = 0;@ " typ i;
locals := Map.update !locals id ~f:(fun _ -> (typ, Ops.is_double_prec prec));
fprintf ppf "@[<2>{@ %s v%d = 0;@ " num_typ i;
pp_ll ppf body;
pp_print_space ppf ();
1
Expand All @@ -342,15 +330,15 @@ let compile_main ~traced_store info ppf llc : unit =
| Binop (Arg2, _v1, v2) -> pp_top_locals ppf v2
| Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
| Unop (_, v) -> pp_top_locals ppf v
and pp_float ~num_typ ~is_double ppf value =
let loop = pp_float ~num_typ ~is_double in
and pp_float ~num_typ prec ppf value =
let loop = pp_float ~num_typ prec in
match value with
| Local_scope { id; _ } ->
(* Embedding of Local_scope is done by pp_top_locals. *)
loop ppf @@ Get_local id
| Get_local id ->
let typ, _local_is_double = Map.find_exn !locals id in
if not @@ String.equal num_typ typ then fprintf ppf "(%s)" num_typ;
let get_typ = Ops.cuda_typ_of_prec id.tn.prec in
if not @@ String.equal num_typ get_typ then fprintf ppf "(%s)" num_typ;
fprintf ppf "v%d" id.scope_id
| Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET"
| Get (array, idcs) ->
Expand All @@ -363,22 +351,22 @@ let compile_main ~traced_store info ppf llc : unit =
| Binop (Arg1, v1, _v2) -> loop ppf v1
| Binop (Arg2, _v1, v2) -> loop ppf v2
| Binop (op, v1, v2) ->
let prefix, infix, postfix = Ops.binop_C_syntax ~is_double op in
let prefix, infix, postfix = Ops.binop_C_syntax prec op in
fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix loop v1 infix loop v2 postfix
| Unop (Identity, v) -> loop ppf v
| Unop (Relu, v) ->
(* FIXME: don't recompute v *)
fprintf ppf "@[<1>(%a > 0.0 ?@ %a : 0.0@])" loop v loop v
and debug_float ~num_typ ~is_double (value : Low_level.float_t) : string * 'a list =
let loop = debug_float ~num_typ ~is_double in
and debug_float ~num_typ prec (value : Low_level.float_t) : string * 'a list =
let loop = debug_float ~num_typ prec in
match value with
| Local_scope { id; _ } ->
(* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *)
loop @@ Get_local id
| Get_local id ->
let typ, _local_is_double = Map.find_exn !locals id in
let get_typ = Ops.cuda_typ_of_prec id.tn.prec in
let v =
(if not @@ String.equal num_typ typ then "(" ^ num_typ ^ ")" else "")
(if not @@ String.equal num_typ get_typ then "(" ^ num_typ ^ ")" else "")
^ "v" ^ Int.to_string id.scope_id
in
(v ^ "{=%f}", [ `Value v ])
Expand All @@ -393,7 +381,7 @@ let compile_main ~traced_store info ppf llc : unit =
| Binop (Arg1, v1, _v2) -> loop v1
| Binop (Arg2, _v1, v2) -> loop v2
| Binop (op, v1, v2) ->
let prefix, infix, postfix = Ops.binop_C_syntax ~is_double op in
let prefix, infix, postfix = Ops.binop_C_syntax prec op in
let v1, idcs1 = loop v1 in
let v2, idcs2 = loop v2 in
(String.concat [ prefix; v1; infix; " "; v2; postfix ], idcs1 @ idcs2)
Expand Down
57 changes: 27 additions & 30 deletions arrayjit/lib/gccjit_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type ndarray = {
num_typ : (Gccjit.type_[@sexp.opaque]);
(** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), [double]
(precision [Double]). *)
is_double : bool;
prec : Ops.prec;
}
[@@deriving sexp_of]

Expand Down Expand Up @@ -142,14 +142,8 @@ let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializati
let is_materialized = Tn.is_materialized_force key 331 in
let is_constant = Tn.is_hosted_force ~specifically:Constant key 332 in
assert (Bool.(Option.is_some (Lazy.force key.array) = is_on_host));
let c_typ, is_double =
match key.prec with
| Byte_prec _ -> (Type.Unsigned_char, false)
| Half_prec _ -> (* FIXME: *) (Type.Float, false)
| Single_prec _ -> (Type.Float, false)
| Double_prec _ -> (Type.Double, true)
| Void_prec -> assert false
in
let prec = key.prec in
let c_typ = Ops.gcc_typ_of_prec prec in
let num_typ = Type.(get ctx c_typ) in
let ptr_typ = Type.pointer num_typ in
let mem =
Expand Down Expand Up @@ -193,7 +187,7 @@ let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializati
(* The array is the pointer but the address of the array is the same pointer. *)
lazy (RValue.cast ctx (LValue.address @@ Option.value_exn !v) ptr_typ)
in
let result = { nd = key; ptr; mem; dims; size_in_bytes; num_typ; is_double } in
let result = { nd = key; ptr; mem; dims; size_in_bytes; num_typ; prec } in
let backend_info = sexp_of_mem_properties mem in
let initialize init_block _func =
Block.comment init_block
Expand Down Expand Up @@ -299,23 +293,25 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
let locals = ref @@ Map.empty (module Low_level.Scope_id) in
let debug_locals = ref !locals in
let current_block = ref initial_block in
let loop_binop op ~num_typ ~is_double ~v1 ~v2 =
match op with
| Ops.Add -> RValue.binary_op ctx Plus num_typ v1 v2
| Sub -> RValue.binary_op ctx Minus num_typ v1 v2
| Mul -> RValue.binary_op ctx Mult num_typ v1 v2
| Div -> RValue.binary_op ctx Divide num_typ v1 v2
| ToPowOf when is_double ->
let loop_binop op ~num_typ prec ~v1 ~v2 =
match (op, prec) with
| _, Ops.Void_prec -> raise @@ Utils.User_error "gccjit_backend: binary operation on Void_prec"
| Ops.Add, _ -> RValue.binary_op ctx Plus num_typ v1 v2
| Sub, _ -> RValue.binary_op ctx Minus num_typ v1 v2
| Mul, _ -> RValue.binary_op ctx Mult num_typ v1 v2
| Div, _ -> RValue.binary_op ctx Divide num_typ v1 v2
| ToPowOf, Double_prec _ ->
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "pow") [ to_d v1; to_d v2 ]) num_typ
| ToPowOf ->
| ToPowOf, (Single_prec _ | Half_prec _ (* FIXME: *)) ->
let base = RValue.cast ctx v1 c_float in
let expon = RValue.cast ctx v2 c_float in
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "powf") [ base; expon ]) num_typ
| Relu_gate ->
| ToPowOf, Byte_prec _ -> raise @@ Utils.User_error "gccjit_backend: Byte_prec does not support ToPowOf"
| Relu_gate, _ ->
let cmp = cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) v1 in
RValue.binary_op ctx Mult num_typ cmp @@ v2
| Arg2 -> v2
| Arg1 -> v1
| Arg2, _ -> v2
| Arg1, _ -> v1
in
let log_comment c =
match log_functions with
Expand All @@ -333,7 +329,7 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
(* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *)
loop @@ Get_local id
| Get_local id ->
let lvalue, _typ, _local_is_double = Map.find_exn !debug_locals id in
let lvalue, _typ = Map.find_exn !debug_locals id in
(* FIXME(194): Convert according to _typ ?= num_typ. *)
(LValue.to_string lvalue ^ "{=%g}", [ to_d @@ RValue.lvalue lvalue ])
| Get_global (C_function f_name, None) -> ("<calls " ^ f_name ^ ">", [])
Expand All @@ -360,7 +356,7 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
| Binop (Arg1, v1, _v2) -> loop v1
| Binop (Arg2, _v1, v2) -> loop v2
| Binop (op, v1, v2) ->
let prefix, infix, postfix = Ops.binop_C_syntax ~is_double op in
let prefix, infix, postfix = Ops.binop_C_syntax prec op in
let v1, fillers1 = loop v1 in
let v2, fillers2 = loop v2 in
(String.concat [ prefix; v1; infix; " "; v2; postfix ], fillers1 @ fillers2)
Expand Down Expand Up @@ -425,15 +421,15 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
let tn = Low_level.(get_node info.traced_store array) in
assert tn.zero_initialized (* The initialization is emitted by get_array. *)
| Set_local (id, llv) ->
let lhs, num_typ, is_double = Map.find_exn !locals id in
let value = loop_float ~name ~env ~num_typ ~is_double llv in
let lhs, num_typ = Map.find_exn !locals id in
let value = loop_float ~name ~env ~num_typ id.tn.prec llv in
Block.assign !current_block lhs value
| Comment c -> log_comment c
| Staged_compilation exp -> exp ()
and loop_float ~name ~env ~num_typ ~is_double v_code =
let loop = loop_float ~name ~env ~num_typ ~is_double in
match v_code with
| Local_scope { id = { scope_id = i; _ } as id; prec; body; orig_indices = _ } ->
| Local_scope { id = { scope_id = i; tn = {prec; _} } as id; body; orig_indices = _ } ->
let typ = Type.get ctx @@ prec_to_kind prec in
(* Scope ids can be non-unique due to inlining. *)
let v_name = Int.("v" ^ to_string i ^ "_" ^ get_uid ()) in
Expand All @@ -442,13 +438,14 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
virtual nodes. *)
Block.assign !current_block lvalue @@ RValue.zero ctx typ;
let old_locals = !locals in
locals := Map.update !locals id ~f:(fun _ -> (lvalue, typ, prec_is_double prec));
locals := Map.update !locals id ~f:(fun _ -> (lvalue, typ));
loop_proc ~toplevel:false ~env ~name:(name ^ "_at_" ^ v_name) body;
debug_locals := Map.update !debug_locals id ~f:(fun _ -> (lvalue, typ, prec_is_double prec));
debug_locals := Map.update !debug_locals id ~f:(fun _ -> (lvalue, typ));
locals := old_locals;
RValue.lvalue lvalue
| Get_local id ->
let lvalue, _typ, _local_is_double = Map.find_exn !locals id in
let lvalue, _typ = Map.find_exn !locals id in

(* FIXME(194): Convert according to _typ ?= num_typ. *)
RValue.lvalue lvalue
| Get_global (C_function f_name, None) ->
Expand Down Expand Up @@ -482,7 +479,7 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
raise e)
| Binop (Arg2, _, c2) -> loop c2
| Binop (Arg1, c1, _) -> loop c1
| Binop (op, c1, c2) -> loop_binop op ~num_typ ~is_double ~v1:(loop c1) ~v2:(loop c2)
| Binop (op, c1, c2) -> loop_binop op ~num_typ prec ~v1:(loop c1) ~v2:(loop c2)
| Unop (Identity, c) -> loop c
| Unop (Relu, c) ->
(* FIXME: don't recompute c *)
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/low_level.ml
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ let fprint_hum ?(ident_style = `Heuristic_ocannl) ?name ?static_indices () ppf l
| Binop (Arg1, v1, _v2) -> pp_float prec ppf v1
| Binop (Arg2, _v1, v2) -> pp_float prec ppf v2
| Binop (op, v1, v2) ->
let prefix, infix, postfix = Ops.binop_C_syntax op in
let prefix, infix, postfix = Ops.binop_C_syntax prec op in
fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix (pp_float prec) v1 infix (pp_float prec) v2 postfix
| Unop (Identity, v) -> (pp_float prec) ppf v
| Unop (Relu, v) -> fprintf ppf "@[<1>relu(%a@])" (pp_float prec) v
Expand Down
1 change: 0 additions & 1 deletion arrayjit/lib/ndarray.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ let precision_string = function
| Double_nd _ -> "double"

let default_kind = Ops.Single
let is_double_prec_t = function Double_nd _ -> true | _ -> false

let get_prec = function
| Byte_nd _ -> Ops.byte
Expand Down
Loading

0 comments on commit 713e2ec

Please sign in to comment.