Skip to content

Commit

Permalink
gccjit_backend.ml: big renaming, avoid confusing array
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Apr 30, 2024
1 parent 713e2ec commit 6c3070b
Showing 1 changed file with 64 additions and 65 deletions.
129 changes: 64 additions & 65 deletions arrayjit/lib/gccjit_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ let init ~label =
Core.Gc.Expert.add_finalizer_exn result finalize;
result

type ndarray = {
nd : Tn.t; (** The original array. *)
type tn_info = {
tn : Tn.t; (** The original array. *)
mutable ptr : (Gccjit.rvalue[@sexp.opaque]) Lazy.t;
(** Pointer to the first value of the associated array.
- if [mem = Constant_from_host], the pointer to the first element of the hosted [Ndarray],
Expand All @@ -72,7 +72,7 @@ type info = {
func : (Gccjit.function_[@sexp.opaque]);
traced_store : (Low_level.traced_store[@sexp.opaque]);
init_block : (Gccjit.block[@sexp.opaque]);
arrays : (Tn.t, ndarray) Hashtbl.t;
nodes : (Tn.t, tn_info) Hashtbl.t;
}
[@@deriving sexp_of]

Expand Down Expand Up @@ -118,23 +118,23 @@ let jit_array_offset ctx ~idcs ~dims =
RValue.binary_op ctx Plus c_index idx
@@ RValue.binary_op ctx Mult c_index offset (RValue.int ctx c_index dim))

let zero_out ctx block arr =
let zero_out ctx block node =
let open Gccjit in
let c_index = Type.get ctx Type.Size_t in
let c_int = Type.get ctx Type.Int in
Block.eval block
@@ RValue.call ctx (Function.builtin ctx "memset")
[ Lazy.force arr.ptr; RValue.zero ctx c_int; RValue.int ctx c_index arr.size_in_bytes ]
[ Lazy.force node.ptr; RValue.zero ctx c_int; RValue.int ctx c_index node.size_in_bytes ]

(** [get_c_ptr ctx num_typ ba] takes the start address of the bigarray [ba] (via
    [Ctypes.bigarray_start] with the [Genarray] class) and wraps it as a gccjit rvalue in context
    [ctx] whose type is "pointer to [num_typ]".

    NOTE(review): nothing in this function checks that the element kind of [ba] agrees with
    [num_typ] -- presumably the caller guarantees it; confirm at the call sites. *)
let get_c_ptr ctx num_typ ba =
Gccjit.(RValue.ptr ctx (Type.pointer num_typ) @@ Ctypes.bigarray_start Ctypes_static.Genarray ba)

let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializations (key : Tn.t) =
let prepare_node ~debug_log_zero_out ctx nodes traced_store ctx_nodes initializations (key : Tn.t) =
let open Gccjit in
Hashtbl.update arrays key ~f:(function
Hashtbl.update nodes key ~f:(function
| Some old -> old
| None ->
let ta = Low_level.(get_node traced_store key) in
let traced = Low_level.(get_node traced_store key) in
let dims = Lazy.force key.dims in
let size_in_elems = Array.fold ~init:1 ~f:( * ) dims in
let size_in_bytes = size_in_elems * Ops.prec_in_bytes key.prec in
Expand All @@ -148,12 +148,12 @@ let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializati
let ptr_typ = Type.pointer num_typ in
let mem =
if not is_materialized then Local_only
else if is_constant && ta.read_only then Constant_from_host
else if is_constant && traced.read_only then Constant_from_host
else From_context
in
let name = Tn.name key in
let ptr =
match (mem, nodes) with
match (mem, ctx_nodes) with
| From_context, Ctx_arrays ctx_arrays -> (
match Map.find !ctx_arrays key with
| None ->
Expand Down Expand Up @@ -187,14 +187,14 @@ let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializati
(* The array is the pointer but the address of the array is the same pointer. *)
lazy (RValue.cast ctx (LValue.address @@ Option.value_exn !v) ptr_typ)
in
let result = { nd = key; ptr; mem; dims; size_in_bytes; num_typ; prec } in
let result = { tn = key; ptr; mem; dims; size_in_bytes; num_typ; prec } in
let backend_info = sexp_of_mem_properties mem in
let initialize init_block _func =
Block.comment init_block
[%string
"Array #%{key.id#Int} %{Tn.label key}: %{Sexp.to_string_hum @@ backend_info}; ptr: \
%{Sexp.to_string_hum @@ sexp_of_gccjit_rvalue @@ Lazy.force ptr}."];
if ta.zero_initialized then (
if traced.zero_initialized then (
debug_log_zero_out init_block result;
zero_out ctx init_block result)
in
Expand All @@ -212,7 +212,6 @@ let prec_to_kind prec =
| Single_prec _ -> Type.Float
| Double_prec _ -> Type.Double

let prec_is_double = function Ops.Double_prec _ -> true | _ -> false
(** [is_builtin_op op] is [true] for [Add], [Sub], [Mul] and [Div], and [false] for [ToPowOf],
    [Relu_gate], [Arg2] and [Arg1]. The cases are spelled out without a wildcard. The [true] cases
    are the ones [builtin_op] accepts; for the others [builtin_op] raises [Invalid_argument]. *)
let is_builtin_op = function Ops.Add | Sub | Mul | Div -> true | ToPowOf | Relu_gate | Arg2 | Arg1 -> false

let builtin_op = function
Expand All @@ -222,15 +221,15 @@ let builtin_op = function
| Div -> Gccjit.Divide
| ToPowOf | Relu_gate | Arg2 | Arg1 -> invalid_arg "Exec_as_gccjit.builtin_op: not a builtin"

let arr_debug_name array =
let node_debug_name node =
let memloc =
if Utils.settings.debug_memory_locations && Lazy.is_val array.ptr then
"@" ^ Gccjit.RValue.to_string (Lazy.force array.ptr)
if Utils.settings.debug_memory_locations && Lazy.is_val node.ptr then
"@" ^ Gccjit.RValue.to_string (Lazy.force node.ptr)
else ""
in
Tn.name array.nd ^ memloc
Tn.name node.tn ^ memloc

let debug_log_zero_out ctx log_functions block array =
let debug_log_zero_out ctx log_functions block node =
let open Gccjit in
let c_index = Type.get ctx Type.Int in
match Lazy.force log_functions with
Expand All @@ -240,9 +239,9 @@ let debug_log_zero_out ctx log_functions block array =
Block.eval block @@ RValue.call ctx pf
@@ lf
:: RValue.string_literal ctx
[%string {|memset_zero(%{arr_debug_name array}) where before first element = %g
[%string {|memset_zero(%{node_debug_name node}) where before first element = %g
|}]
:: [ to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force array.ptr) @@ RValue.zero ctx c_index ];
:: [ to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) @@ RValue.zero ctx c_index ];
Block.eval block @@ RValue.call ctx ff [ lf ]
| _ -> ()

Expand All @@ -257,7 +256,7 @@ let debug_log_index ctx log_functions =
Block.eval block @@ RValue.call ctx ff [ lf ]
| _ -> fun _block _i _index -> ()

let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func initial_block
let compile_main ~name ~log_functions ~env ({ ctx; nodes; _ } as info) func initial_block
(body : Low_level.t) =
let open Gccjit in
let c_int = Type.get ctx Type.Int in
Expand Down Expand Up @@ -321,9 +320,9 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
Block.eval !current_block @@ RValue.call ctx ff [ lf ]
| _ -> Block.comment !current_block c
in
let get_array = Hashtbl.find_exn arrays in
let rec debug_float ~env ~is_double value =
let loop = debug_float ~env ~is_double in
let get_node = Hashtbl.find_exn nodes in
let rec debug_float ~env prec value =
let loop = debug_float ~env prec in
match value with
| Low_level.Local_scope { id; _ } ->
(* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *)
Expand All @@ -344,12 +343,12 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
| Get_global (External_unsafe _, None) -> assert false
| Get_global (C_function _, Some _) -> failwith "gccjit_backend: FFI with parameters NOT IMPLEMENTED YET"
| Get (ptr, idcs) ->
let array = get_array ptr in
let array = get_node ptr in
let idcs = lookup env idcs in
let offset = jit_array_offset ctx ~idcs ~dims:array.dims in
(* FIXME(194): Convert according to array.typ ?= num_typ. *)
let v = to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force array.ptr) offset in
(arr_debug_name array ^ "[%d]{=%g}", [ offset; v ])
(node_debug_name array ^ "[%d]{=%g}", [ offset; v ])
| Constant c -> (Float.to_string c, [])
| Embed_index (Fixed_idx i) -> (Int.to_string i, [])
| Embed_index (Iterator s) -> (Indexing.symbol_ident s ^ "{=%d}", [ Map.find_exn env s ])
Expand All @@ -365,20 +364,20 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
let v, fillers = loop v in
(String.concat [ "("; v; " > 0.0 ? "; v; " : 0.0)" ], fillers @ fillers)
in
let debug_log_assignment ~env debug idcs array accum_op value v_code =
let debug_log_assignment ~env debug idcs node accum_op value v_code =
match log_functions with
| Some (lf, pf, ff) ->
let v_format, v_fillers = debug_float ~env ~is_double:array.is_double v_code in
let offset = jit_array_offset ctx ~idcs ~dims:array.dims in
let v_format, v_fillers = debug_float ~env node.prec v_code in
let offset = jit_array_offset ctx ~idcs ~dims:node.dims in
let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\n" in
Block.eval !current_block @@ RValue.call ctx pf @@ [ lf; RValue.string_literal ctx debug_line ];
Block.eval !current_block @@ RValue.call ctx pf
@@ lf
:: RValue.string_literal ctx
[%string
{|%{arr_debug_name array}[%d]{=%g} %{Ops.assign_op_C_syntax accum_op} %g = %{v_format}
{|%{node_debug_name node}[%d]{=%g} %{Ops.assign_op_C_syntax accum_op} %g = %{v_format}
|}]
:: (to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force array.ptr) offset)
:: (to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) offset)
:: offset :: to_d value :: v_fillers;
Block.eval !current_block @@ RValue.call ctx ff [ lf ]
| _ -> ()
Expand All @@ -394,40 +393,40 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
| For_loop { index; from_; to_; body; trace_it = _ } ->
loop_for_loop ~toplevel ~env index ~from_ ~to_ body
| Set { llv = Binop (Arg2, Get (_, _), _); _ } -> assert false
| Set { array; idcs; llv = Binop (op, Get (array2, idcs2), c2); debug }
when Tn.equal array array2 && [%equal: Indexing.axis_index array] idcs idcs2 && is_builtin_op op ->
| Set { tn; idcs; llv = Binop (op, Get (tn2, idcs2), c2); debug }
when Tn.equal tn tn2 && [%equal: Indexing.axis_index array] idcs idcs2 && is_builtin_op op ->
(* FIXME: maybe it's not worth it? *)
let array = get_array array in
let value = loop_float ~name ~env ~num_typ:array.num_typ ~is_double:array.is_double c2 in
let node = get_node tn in
let value = loop_float ~name ~env ~num_typ:node.num_typ node.prec c2 in
let idcs = lookup env idcs in
let offset = jit_array_offset ctx ~idcs ~dims:array.dims in
let lhs = LValue.access_array (Lazy.force array.ptr) offset in
debug_log_assignment ~env debug idcs array op value c2;
let offset = jit_array_offset ctx ~idcs ~dims:node.dims in
let lhs = LValue.access_array (Lazy.force node.ptr) offset in
debug_log_assignment ~env debug idcs node op value c2;
Block.assign_op !current_block lhs (builtin_op op) value
| Set { array; idcs; llv; debug } ->
let array = get_array array in
let value = loop_float ~name ~env ~num_typ:array.num_typ ~is_double:array.is_double llv in
| Set { tn; idcs; llv; debug } ->
let node = get_node tn in
let value = loop_float ~name ~env ~num_typ:node.num_typ node.prec llv in
let idcs = lookup env idcs in
let offset = jit_array_offset ctx ~idcs ~dims:array.dims in
let lhs = LValue.access_array (Lazy.force array.ptr) offset in
debug_log_assignment ~env debug idcs array Ops.Arg2 value llv;
let offset = jit_array_offset ctx ~idcs ~dims:node.dims in
let lhs = LValue.access_array (Lazy.force node.ptr) offset in
debug_log_assignment ~env debug idcs node Ops.Arg2 value llv;
Block.assign !current_block lhs value
| Zero_out array ->
if Hashtbl.mem info.arrays array then (
let array = Hashtbl.find_exn info.arrays array in
if Hashtbl.mem info.nodes array then (
let array = Hashtbl.find_exn info.nodes array in
debug_log_zero_out ctx (lazy log_functions) !current_block array;
zero_out ctx !current_block array)
else
let tn = Low_level.(get_node info.traced_store array) in
assert tn.zero_initialized (* The initialization is emitted by get_array. *)
assert tn.zero_initialized (* The initialization is emitted by prepare_nodes. *)
| Set_local (id, llv) ->
let lhs, num_typ = Map.find_exn !locals id in
let value = loop_float ~name ~env ~num_typ id.tn.prec llv in
Block.assign !current_block lhs value
| Comment c -> log_comment c
| Staged_compilation exp -> exp ()
and loop_float ~name ~env ~num_typ ~is_double v_code =
let loop = loop_float ~name ~env ~num_typ ~is_double in
and loop_float ~name ~env ~num_typ prec v_code =
let loop = loop_float ~name ~env ~num_typ prec in
match v_code with
| Local_scope { id = { scope_id = i; tn = {prec; _} } as id; body; orig_indices = _ } ->
let typ = Type.get ctx @@ prec_to_kind prec in
Expand Down Expand Up @@ -461,12 +460,12 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
RValue.lvalue @@ LValue.access_array ptr offset
| Get_global (External_unsafe _, None) -> assert false
| Get_global (C_function _, Some _) -> failwith "gccjit_backend: FFI with parameters NOT IMPLEMENTED YET"
| Get (array, idcs) ->
let array = get_array array in
| Get (tn, idcs) ->
let node = get_node tn in
let idcs = lookup env idcs in
let offset = jit_array_offset ctx ~idcs ~dims:array.dims in
let offset = jit_array_offset ctx ~idcs ~dims:node.dims in
(* FIXME(194): Convert according to array.typ ?= num_typ. *)
RValue.lvalue @@ LValue.access_array (Lazy.force array.ptr) offset
RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) offset
| Embed_index (Fixed_idx i) -> RValue.cast ctx (RValue.int ctx num_typ i) num_typ
| Embed_index (Iterator s) -> (
try RValue.cast ctx (Map.find_exn env s) num_typ
Expand Down Expand Up @@ -510,26 +509,26 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini
loop_proc ~toplevel:true ~name ~env body;
!current_block

let prepare_arrays ctx ~log_functions arrays traced_store nodes initializations (llc : Low_level.t) =
let prepare_nodes ctx ~log_functions nodes traced_store ctx_nodes initializations (llc : Low_level.t) =
let debug_log_zero_out = debug_log_zero_out ctx log_functions in
let prepare_array = prepare_array ctx ~debug_log_zero_out arrays traced_store nodes initializations in
let prepare_node = prepare_node ctx ~debug_log_zero_out nodes traced_store ctx_nodes initializations in
let rec loop llc =
match llc with
| Low_level.Noop | Low_level.Comment _ | Low_level.Staged_compilation _ -> ()
| Low_level.Seq (c1, c2) ->
loop c1;
loop c2
| Low_level.For_loop { body; _ } -> loop body
| Low_level.Zero_out arr -> prepare_array arr
| Low_level.Set { array; llv; _ } ->
prepare_array array;
| Low_level.Zero_out tn -> prepare_node tn
| Low_level.Set { tn; llv; _ } ->
prepare_node tn;
loop_float llv
| Low_level.Set_local (_, llv) -> loop_float llv
and loop_float llv =
match llv with
| Low_level.Local_scope { body; _ } -> loop body
| Low_level.Get_local _ | Low_level.Get_global (_, _) -> ()
| Low_level.Get (arr, _) -> prepare_array arr
| Low_level.Get (tn, _) -> prepare_node tn
| Low_level.Binop (_, v1, v2) ->
loop_float v1;
loop_float v2
Expand Down Expand Up @@ -557,16 +556,16 @@ let%track_sexp compile_func ~name ~opt_ctx_arrays ctx bindings (traced_store, pr
ref @@ Option.value opt_ctx_arrays ~default:(Map.empty (module Tn))
in
let params : (gccjit_param * param_source) list ref = ref (Option.to_list log_file_name @ static_indices) in
let nodes : ctx_nodes =
let ctx_nodes : ctx_nodes =
if Option.is_none opt_ctx_arrays then Param_ptrs params else Ctx_arrays ctx_arrays
in
let initializations = ref [] in
let arrays = Hashtbl.create (module Tn) in
let nodes = Hashtbl.create (module Tn) in
let log_functions_ref = ref None in
let log_functions = lazy !log_functions_ref in
prepare_arrays ~log_functions ctx arrays traced_store nodes initializations proc;
prepare_nodes ~log_functions ctx nodes traced_store ctx_nodes initializations proc;
let params : (gccjit_param * param_source) list =
match nodes with Param_ptrs ps -> !ps | Ctx_arrays _ -> !params
match ctx_nodes with Param_ptrs ps -> !ps | Ctx_arrays _ -> !params
in
let func = Function.create ctx fkind (Type.get ctx Void) name @@ List.map ~f:fst params in
let env =
Expand Down Expand Up @@ -597,7 +596,7 @@ let%track_sexp compile_func ~name ~opt_ctx_arrays ctx bindings (traced_store, pr
(* Do initializations in the order they were scheduled. *)
List.iter (List.rev !initializations) ~f:(fun init -> init init_block func);
let main_block = Block.create ~name func in
let ctx_info : info = { ctx; traced_store; init_block; func; arrays } in
let ctx_info : info = { ctx; traced_store; init_block; func; nodes } in
let after_proc = compile_main ~name ~log_functions ~env ctx_info func main_block proc in
(match log_functions with
| Some (lf, _, _) ->
Expand Down Expand Up @@ -756,7 +755,7 @@ let%track_sexp merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst ~accum
Low_level.(
Set
{
array = dst;
tn = dst;
idcs;
llv =
Binop
Expand Down

0 comments on commit 6c3070b

Please sign in to comment.