diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index b918e81f..1f6c7797 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -18,7 +18,7 @@ type tn_info = { num_typ : string; (** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), [double] (precision [Double]). *) - is_double : bool; + prec : Ops.prec; zero_initialized : bool; } [@@deriving sexp_of] @@ -200,13 +200,6 @@ let get_run_ptr_debug array = | Some rv, _ -> "global_" ^ rv | None, None -> assert false -let prec_to_c_type = function - | Ops.Void_prec -> "void" - | Byte_prec _ -> "uint8" - | Half_prec _ -> (* FIXME: *) "uint16" - | Single_prec _ -> "float" - | Double_prec _ -> "double" - (* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim -> idx + (offset * dim)) *) @@ -220,12 +213,12 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key = let dims = Lazy.force key.dims in let size_in_elems = Array.fold ~init:1 ~f:( * ) dims in let hosted = Lazy.force key.array in - let size_in_bytes = size_in_elems * Ops.prec_in_bytes key.prec in + let prec = key.prec in + let size_in_bytes = size_in_elems * Ops.prec_in_bytes prec in let is_on_host = Tn.is_hosted_force key 31 in let is_materialized = Tn.is_hosted_force key 32 in assert (Bool.(Option.is_some hosted = is_on_host)); - let is_double = Ops.is_double_prec key.prec in - let num_typ = prec_to_c_type key.prec in + let num_typ = Ops.cuda_typ_of_prec prec in let mem = if not is_materialized then Local_only else Global in let global = if is_local_only mem then None else Some (Tn.name key) in let local = Option.some_if (is_local_only mem) @@ Tn.name key ^ "_local" in @@ -237,6 +230,8 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key = Tn.label key, "mem", (backend_info : Sexp.t), + "prec", + (prec : Ops.prec), "on-host", (is_on_host : bool), "is-global", @@ -245,7 +240,7 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key = key.backend_info <- Utils.sexp_append ~elem:backend_info key.backend_info; let zero_initialized = (Hashtbl.find_exn traced_store key).Low_level.zero_initialized in let data = - { hosted; local; mem; dims; size_in_bytes; size_in_elems; num_typ; is_double; global; zero_initialized } + { hosted; local; mem; dims; size_in_bytes; size_in_elems; num_typ; prec; global; zero_initialized } in info.info_arrays <- Map.add_exn info.info_arrays ~key ~data; data @@ -254,7 +249,6 @@ let%debug_sexp get_array ~(traced_store : Low_level.traced_store) info key = let compile_main ~traced_store info ppf llc : unit = let open Stdlib.Format in - let locals = ref @@ Map.empty (module Low_level.Scope_id) in let rec pp_ll ppf c : unit = match c with | Low_level.Noop -> () @@ -276,15 +270,13 @@ let compile_main ~traced_store info ppf llc : unit = let tn = Low_level.(get_node traced_store array) in assert tn.zero_initialized (* The initialization will be emitted by get_array. *) - | Set { array; idcs; llv; debug } -> - let array = get_array ~traced_store info array in - let old_locals = !locals in - let loop_f = pp_float ~num_typ:array.num_typ ~is_double:array.is_double in - let loop_debug_f = debug_float ~num_typ:array.num_typ ~is_double:array.is_double in + | Set { tn; idcs; llv; debug } -> + let tn = get_array ~traced_store info tn in + let loop_f = pp_float ~num_typ:tn.num_typ tn.prec in + let loop_debug_f = debug_float ~num_typ:tn.num_typ tn.prec in let num_closing_braces = pp_top_locals ppf llv in (* No idea why adding any cut hint at the end of the assign line breaks formatting! *) - fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " (get_run_ptr array) pp_array_offset (idcs, array.dims) loop_f - llv; + fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " (get_run_ptr tn) pp_array_offset (idcs, tn.dims) loop_f llv; (if Utils.settings.debug_log_from_routines then let v_code, v_idcs = loop_debug_f llv in let pp_args = @@ -296,9 +288,9 @@ let compile_main ~traced_store info ppf llc : unit = pp_comma ppf (); pp_print_string ppf v in - let run_ptr_debug = get_run_ptr_debug array in - let run_ptr = get_run_ptr array in - let offset = (idcs, array.dims) in + let run_ptr_debug = get_run_ptr_debug tn in + let run_ptr = get_run_ptr tn in + let offset = (idcs, tn.dims) in let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\\n" in fprintf ppf "@ @[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(\"%%d: %s\", log_id);@ \ @@ -307,8 +299,7 @@ let compile_main ~traced_store info ppf llc : unit = v_idcs); for _ = 1 to num_closing_braces do fprintf ppf "@]@ }@," - done; - locals := old_locals + done | Comment message -> if Utils.settings.debug_log_from_routines then fprintf ppf @@ -317,23 +308,20 @@ let compile_main ~traced_store info ppf llc : unit = (String.substr_replace_all ~pattern:"%" ~with_:"%%" message) else fprintf ppf "/* %s */@ " message | Staged_compilation callback -> callback () - | Set_local (({ scope_id; _ } as id), value) -> - let num_typ, is_double = Map.find_exn !locals id in - let old_locals = !locals in + | Set_local (Low_level.{ scope_id; tn = { prec; _ } }, value) -> + let num_typ = Ops.cuda_typ_of_prec prec in let num_closing_braces = pp_top_locals ppf value in - fprintf ppf "@[<2>v%d =@ %a;@]" scope_id (pp_float ~num_typ ~is_double) value; + fprintf ppf "@[<2>v%d =@ %a;@]" scope_id (pp_float ~num_typ prec) value; for _ = 1 to num_closing_braces do fprintf ppf "@]@ }@," - done; - locals := old_locals + done and pp_top_locals ppf (vcomp : Low_level.float_t) : int = match vcomp with - | Local_scope { id = { scope_id = i; _ } as id; prec; body; orig_indices = _ } -> - let typ = prec_to_c_type prec in + | Local_scope { id = { scope_id = i; tn = { prec; _ } }; body; orig_indices = _ } -> + let num_typ = Ops.cuda_typ_of_prec prec in (* Arrays are initialized to 0 by default. However, there is typically an explicit initialization for virtual nodes. *) - fprintf ppf "@[<2>{@ %s v%d = 0;@ " typ i; - locals := Map.update !locals id ~f:(fun _ -> (typ, Ops.is_double_prec prec)); + fprintf ppf "@[<2>{@ %s v%d = 0;@ " num_typ i; pp_ll ppf body; pp_print_space ppf (); 1 @@ -342,15 +330,15 @@ let compile_main ~traced_store info ppf llc : unit = | Binop (Arg2, _v1, v2) -> pp_top_locals ppf v2 | Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2 | Unop (_, v) -> pp_top_locals ppf v - and pp_float ~num_typ ~is_double ppf value = - let loop = pp_float ~num_typ ~is_double in + and pp_float ~num_typ prec ppf value = + let loop = pp_float ~num_typ prec in match value with | Local_scope { id; _ } -> (* Embedding of Local_scope is done by pp_top_locals. *) loop ppf @@ Get_local id | Get_local id -> - let typ, _local_is_double = Map.find_exn !locals id in - if not @@ String.equal num_typ typ then fprintf ppf "(%s)" num_typ; + let get_typ = Ops.cuda_typ_of_prec id.tn.prec in + if not @@ String.equal num_typ get_typ then fprintf ppf "(%s)" num_typ; fprintf ppf "v%d" id.scope_id | Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET" | Get (array, idcs) -> @@ -363,22 +351,22 @@ let compile_main ~traced_store info ppf llc : unit = | Binop (Arg1, v1, _v2) -> loop ppf v1 | Binop (Arg2, _v1, v2) -> loop ppf v2 | Binop (op, v1, v2) -> - let prefix, infix, postfix = Ops.binop_C_syntax ~is_double op in + let prefix, infix, postfix = Ops.binop_C_syntax prec op in fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix loop v1 infix loop v2 postfix | Unop (Identity, v) -> loop ppf v | Unop (Relu, v) -> (* FIXME: don't recompute v *) fprintf ppf "@[<1>(%a > 0.0 ?@ %a : 0.0@])" loop v loop v - and debug_float ~num_typ ~is_double (value : Low_level.float_t) : string * 'a list = - let loop = debug_float ~num_typ ~is_double in + and debug_float ~num_typ prec (value : Low_level.float_t) : string * 'a list = + let loop = debug_float ~num_typ prec in match value with | Local_scope { id; _ } -> (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *) loop @@ Get_local id | Get_local id -> - let typ, _local_is_double = Map.find_exn !locals id in + let get_typ = Ops.cuda_typ_of_prec id.tn.prec in let v = - (if not @@ String.equal num_typ typ then "(" ^ num_typ ^ ")" else "") + (if not @@ String.equal num_typ get_typ then "(" ^ num_typ ^ ")" else "") ^ "v" ^ Int.to_string id.scope_id in (v ^ "{=%f}", [ `Value v ]) @@ -393,7 +381,7 @@ let compile_main ~traced_store info ppf llc : unit = | Binop (Arg1, v1, _v2) -> loop v1 | Binop (Arg2, _v1, v2) -> loop v2 | Binop (op, v1, v2) -> - let prefix, infix, postfix = Ops.binop_C_syntax ~is_double op in + let prefix, infix, postfix = Ops.binop_C_syntax prec op in let v1, idcs1 = loop v1 in let v2, idcs2 = loop v2 in (String.concat [ prefix; v1; infix; " "; v2; postfix ], idcs1 @ idcs2) diff --git a/arrayjit/lib/gccjit_backend.ml b/arrayjit/lib/gccjit_backend.ml index 18da4022..e7ae36b9 100644 --- a/arrayjit/lib/gccjit_backend.ml +++ b/arrayjit/lib/gccjit_backend.ml @@ -63,7 +63,7 @@ type ndarray = { num_typ : (Gccjit.type_[@sexp.opaque]); (** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), [double] (precision [Double]). *) - is_double : bool; + prec : Ops.prec; } [@@deriving sexp_of] @@ -142,14 +142,8 @@ let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializati let is_materialized = Tn.is_materialized_force key 331 in let is_constant = Tn.is_hosted_force ~specifically:Constant key 332 in assert (Bool.(Option.is_some (Lazy.force key.array) = is_on_host)); - let c_typ, is_double = - match key.prec with - | Byte_prec _ -> (Type.Unsigned_char, false) - | Half_prec _ -> (* FIXME: *) (Type.Float, false) - | Single_prec _ -> (Type.Float, false) - | Double_prec _ -> (Type.Double, true) - | Void_prec -> assert false - in + let prec = key.prec in + let c_typ = Ops.gcc_typ_of_prec prec in let num_typ = Type.(get ctx c_typ) in let ptr_typ = Type.pointer num_typ in let mem = @@ -193,7 +187,7 @@ let prepare_array ~debug_log_zero_out ctx arrays traced_store nodes initializati (* The array is the pointer but the address of the array is the same pointer. *) lazy (RValue.cast ctx (LValue.address @@ Option.value_exn !v) ptr_typ) in - let result = { nd = key; ptr; mem; dims; size_in_bytes; num_typ; is_double } in + let result = { nd = key; ptr; mem; dims; size_in_bytes; num_typ; prec } in let backend_info = sexp_of_mem_properties mem in let initialize init_block _func = Block.comment init_block @@ -299,23 +293,25 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini let locals = ref @@ Map.empty (module Low_level.Scope_id) in let debug_locals = ref !locals in let current_block = ref initial_block in - let loop_binop op ~num_typ ~is_double ~v1 ~v2 = - match op with - | Ops.Add -> RValue.binary_op ctx Plus num_typ v1 v2 - | Sub -> RValue.binary_op ctx Minus num_typ v1 v2 - | Mul -> RValue.binary_op ctx Mult num_typ v1 v2 - | Div -> RValue.binary_op ctx Divide num_typ v1 v2 - | ToPowOf when is_double -> + let loop_binop op ~num_typ prec ~v1 ~v2 = + match (op, prec) with + | _, Ops.Void_prec -> raise @@ Utils.User_error "gccjit_backend: binary operation on Void_prec" + | Ops.Add, _ -> RValue.binary_op ctx Plus num_typ v1 v2 + | Sub, _ -> RValue.binary_op ctx Minus num_typ v1 v2 + | Mul, _ -> RValue.binary_op ctx Mult num_typ v1 v2 + | Div, _ -> RValue.binary_op ctx Divide num_typ v1 v2 + | ToPowOf, Double_prec _ -> RValue.cast ctx (RValue.call ctx (Function.builtin ctx "pow") [ to_d v1; to_d v2 ]) num_typ - | ToPowOf -> + | ToPowOf, (Single_prec _ | Half_prec _ (* FIXME: *)) -> let base = RValue.cast ctx v1 c_float in let expon = RValue.cast ctx v2 c_float in RValue.cast ctx (RValue.call ctx (Function.builtin ctx "powf") [ base; expon ]) num_typ - | Relu_gate -> + | ToPowOf, Byte_prec _ -> raise @@ Utils.User_error "gccjit_backend: Byte_prec does not support ToPowOf" + | Relu_gate, _ -> let cmp = cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) v1 in RValue.binary_op ctx Mult num_typ cmp @@ v2 - | Arg2 -> v2 - | Arg1 -> v1 + | Arg2, _ -> v2 + | Arg1, _ -> v1 in let log_comment c = match log_functions with @@ -333,7 +329,7 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *) loop @@ Get_local id | Get_local id -> - let lvalue, _typ, _local_is_double = Map.find_exn !debug_locals id in + let lvalue, _typ = Map.find_exn !debug_locals id in (* FIXME(194): Convert according to _typ ?= num_typ. *) (LValue.to_string lvalue ^ "{=%g}", [ to_d @@ RValue.lvalue lvalue ]) | Get_global (C_function f_name, None) -> ("", []) @@ -360,7 +356,7 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini | Binop (Arg1, v1, _v2) -> loop v1 | Binop (Arg2, _v1, v2) -> loop v2 | Binop (op, v1, v2) -> - let prefix, infix, postfix = Ops.binop_C_syntax ~is_double op in + let prefix, infix, postfix = Ops.binop_C_syntax prec op in let v1, fillers1 = loop v1 in let v2, fillers2 = loop v2 in (String.concat [ prefix; v1; infix; " "; v2; postfix ], fillers1 @ fillers2) @@ -425,15 +421,15 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini let tn = Low_level.(get_node info.traced_store array) in assert tn.zero_initialized (* The initialization is emitted by get_array. *) | Set_local (id, llv) -> - let lhs, num_typ, is_double = Map.find_exn !locals id in - let value = loop_float ~name ~env ~num_typ ~is_double llv in + let lhs, num_typ = Map.find_exn !locals id in + let value = loop_float ~name ~env ~num_typ id.tn.prec llv in Block.assign !current_block lhs value | Comment c -> log_comment c | Staged_compilation exp -> exp () and loop_float ~name ~env ~num_typ ~is_double v_code = let loop = loop_float ~name ~env ~num_typ ~is_double in match v_code with - | Local_scope { id = { scope_id = i; _ } as id; prec; body; orig_indices = _ } -> + | Local_scope { id = { scope_id = i; tn = {prec; _} } as id; body; orig_indices = _ } -> let typ = Type.get ctx @@ prec_to_kind prec in (* Scope ids can be non-unique due to inlining. *) let v_name = Int.("v" ^ to_string i ^ "_" ^ get_uid ()) in @@ -442,13 +438,14 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini virtual nodes. *) Block.assign !current_block lvalue @@ RValue.zero ctx typ; let old_locals = !locals in - locals := Map.update !locals id ~f:(fun _ -> (lvalue, typ, prec_is_double prec)); + locals := Map.update !locals id ~f:(fun _ -> (lvalue, typ)); loop_proc ~toplevel:false ~env ~name:(name ^ "_at_" ^ v_name) body; - debug_locals := Map.update !debug_locals id ~f:(fun _ -> (lvalue, typ, prec_is_double prec)); + debug_locals := Map.update !debug_locals id ~f:(fun _ -> (lvalue, typ)); locals := old_locals; RValue.lvalue lvalue | Get_local id -> - let lvalue, _typ, _local_is_double = Map.find_exn !locals id in + let lvalue, _typ = Map.find_exn !locals id in + (* FIXME(194): Convert according to _typ ?= num_typ. *) RValue.lvalue lvalue | Get_global (C_function f_name, None) -> @@ -482,7 +479,7 @@ let compile_main ~name ~log_functions ~env ({ ctx; arrays; _ } as info) func ini raise e) | Binop (Arg2, _, c2) -> loop c2 | Binop (Arg1, c1, _) -> loop c1 - | Binop (op, c1, c2) -> loop_binop op ~num_typ ~is_double ~v1:(loop c1) ~v2:(loop c2) + | Binop (op, c1, c2) -> loop_binop op ~num_typ prec ~v1:(loop c1) ~v2:(loop c2) | Unop (Identity, c) -> loop c | Unop (Relu, c) -> (* FIXME: don't recompute c *) diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml index e946d9fb..b1c7fc27 100644 --- a/arrayjit/lib/low_level.ml +++ b/arrayjit/lib/low_level.ml @@ -815,7 +815,7 @@ let fprint_hum ?(ident_style = `Heuristic_ocannl) ?name ?static_indices () ppf l | Binop (Arg1, v1, _v2) -> pp_float prec ppf v1 | Binop (Arg2, _v1, v2) -> pp_float prec ppf v2 | Binop (op, v1, v2) -> - let prefix, infix, postfix = Ops.binop_C_syntax op in + let prefix, infix, postfix = Ops.binop_C_syntax prec op in fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix (pp_float prec) v1 infix (pp_float prec) v2 postfix | Unop (Identity, v) -> (pp_float prec) ppf v | Unop (Relu, v) -> fprintf ppf "@[<1>relu(%a@])" (pp_float prec) v diff --git a/arrayjit/lib/ndarray.ml b/arrayjit/lib/ndarray.ml index 1bfc2275..565aa26a 100644 --- a/arrayjit/lib/ndarray.ml +++ b/arrayjit/lib/ndarray.ml @@ -47,7 +47,6 @@ let precision_string = function | Double_nd _ -> "double" let default_kind = Ops.Single -let is_double_prec_t = function Double_nd _ -> true | _ -> false let get_prec = function | Byte_nd _ -> Ops.byte diff --git a/arrayjit/lib/ops.ml b/arrayjit/lib/ops.ml index 1ecb03ef..7d5e7f4a 100644 --- a/arrayjit/lib/ops.ml +++ b/arrayjit/lib/ops.ml @@ -88,11 +88,6 @@ let promote_prec p1 p2 = | _, Byte_prec _ -> p2 | Void_prec, Void_prec -> Void_prec -let is_double (type ocaml elt_t) (prec : (ocaml, elt_t) precision) = - match prec with Double -> true | _ -> false - -let is_double_prec = function Double_prec _ -> true | _ -> false - let pack_prec (type ocaml elt_t) (prec : (ocaml, elt_t) precision) = match prec with Byte -> byte | Half -> half | Single -> single | Double -> double @@ -106,7 +101,7 @@ let map_prec ?default { f } = function | Double_prec Double -> f Double | _ -> . -let c_typ_of_prec = +let gcc_typ_of_prec = let open Gccjit in function | Byte_prec _ -> Type.Unsigned_char @@ -115,6 +110,14 @@ let c_typ_of_prec = | Double_prec _ -> Type.Double | Void_prec -> Type.Void +let cuda_typ_of_prec = function + | Byte_prec _ -> "unsigned char" + (* TODO: or should it be uint8, or uint8_t? *) + | Half_prec _ -> (* FIXME: *) "float" + | Single_prec _ -> "float" + | Double_prec _ -> "double" + | Void_prec -> "void" + (** {2 *** Operations ***} *) (** Initializes or resets a array by filling in the corresponding numbers, at the appropriate precision. *) @@ -157,16 +160,22 @@ let interpret_unop op v = let open Float in match op with Identity -> v | Relu when v >= 0. -> v | Relu -> 0. -let binop_C_syntax ~is_double = function - | Arg1 -> invalid_arg "Ops.binop_C_syntax: Arg1 is not a C operator" - | Arg2 -> invalid_arg "Ops.binop_C_syntax: Arg2 is not a C operator" - | Add -> ("(", " +", ")") - | Sub -> ("(", " -", ")") - | Mul -> ("(", " *", ")") - | Div -> ("(", " /", ")") - | ToPowOf when is_double -> ("pow(", ",", ")") - | ToPowOf -> ("powf(", ",", ")") - | Relu_gate -> ("(", " > 0.0 ?", " : 0.0)") +let binop_C_syntax prec v = + match (v, prec) with + | Arg1, _ -> invalid_arg "Ops.binop_C_syntax: Arg1 is not a C operator" + | Arg2, _ -> invalid_arg "Ops.binop_C_syntax: Arg2 is not a C operator" + | _, Void_prec -> invalid_arg "Ops.binop_C_syntax: Void precision" + | Add, _ -> ("(", " +", ")") + | Sub, _ -> ("(", " -", ")") + | Mul, _ -> ("(", " *", ")") + | Div, _ -> ("(", " /", ")") + | ToPowOf, Double_prec _ -> ("pow(", ",", ")") + | ToPowOf, Single_prec _ -> ("powf(", ",", ")") + | ToPowOf, Half_prec _ -> ("powf(", ",", ")") + | ToPowOf, Byte_prec _ -> + invalid_arg "Ops.binop_C_syntax: ToPowOf not supported for byte/integer precisions" + | Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)") + | Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)") (* "((int)(", "> 0.0) *", ")" *) let binop_cd_syntax = function @@ -218,7 +227,7 @@ let equal_voidptr : voidptr -> voidptr -> bool = phys_equal let ptr_to_string ptr prec = let open Gccjit in let ctx = Context.create () in - let result = RValue.to_string @@ RValue.ptr ctx Type.(pointer @@ get ctx @@ c_typ_of_prec prec) ptr in + let result = RValue.to_string @@ RValue.ptr ctx Type.(pointer @@ get ctx @@ gcc_typ_of_prec prec) ptr in Context.release ctx; result