Skip to content

Commit

Permalink
Defensively traverse upper bounds when closing shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Mar 30, 2024
1 parent 34b35ed commit 0c7d3f3
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions lib/row.ml
Original file line number Diff line number Diff line change
Expand Up @@ -649,13 +649,26 @@ let solve_row_ineq ~(finish : bool) ~(cur : t) ~(subr : t) (env : environment) :
| _, { bcast = Broadcastable; _ } when r2_len <= r1_len -> (prefix_ineqs, env)
| { bcast = Row_var _ | Broadcastable; _ }, { bcast = Row_var _ | Broadcastable; _ } -> assert false

let rec get_upper_bounds env vars =
List.concat_map vars ~f:(fun v ->
match Map.find env.dim_env v with
| Some (Solved (Var v)) -> get_upper_bounds env [ v ]
| Some (Solved ub) -> [ ub ]
| None -> []
| Some (Bounds { lub = None; cur; _ }) -> get_upper_bounds env cur
| Some (Bounds { lub = Some lub; _ }) -> [ lub ])

let close_dim_terminal (env : environment) (dim : dim) : inequality list * environment =
match dim with
| Dim _ -> ([], env)
| Var v -> (
match Map.find env.dim_env v with
| Some (Solved _) -> assert false
| None | Some (Bounds { lub = None; _ }) -> (* needs more inference *) ([ Terminal_dim dim ], env)
| None -> (* needs more inference *) ([ Terminal_dim dim ], env)
| Some (Bounds { lub = None; cur; _ }) ->
let ubs = get_upper_bounds env cur in
if List.is_empty ubs then ([ Terminal_dim dim ], env)
else (List.map ~f:(fun ub -> Dim_eq { d1 = dim; d2 = ub }) ubs, env)
| Some (Bounds { lub = Some lub; _ }) -> ([ Dim_eq { d1 = dim; d2 = lub } ], env))

let close_row_terminal (env : environment) ({ dims; bcast; id } as _r : row) : inequality list * environment =
Expand All @@ -667,7 +680,9 @@ let close_row_terminal (env : environment) ({ dims; bcast; id } as _r : row) : i
match Map.find env.row_env v with
| Some (Bounds { lub = None; _ }) | None -> (* needs more inference *) (Terminal_row rem :: prefix, env)
| Some (Solved _) -> assert false
| Some (Bounds { lub = Some lub; _ }) -> (Row_eq { r1 = row_of_var v id; r2 = lub } :: prefix, env))
| Some (Bounds { lub = Some lub; _ }) ->
(* Solve and propagate the equality. *)
(Row_eq { r1 = rem; r2 = lub } :: Terminal_row rem :: prefix, env))

let finalize_dim (env : environment) (dim : dim) : inequality list =
match subst_dim env dim with Dim _ -> [] | Var _ -> [ Dim_eq { d1 = dim; d2 = get_dim ~d:1 () } ]
Expand Down

0 comments on commit 0c7d3f3

Please sign in to comment.