diff --git a/fathom/src/core.rs b/fathom/src/core.rs index 7b0f9b0bf..16d7ff627 100644 --- a/fathom/src/core.rs +++ b/fathom/src/core.rs @@ -388,6 +388,20 @@ impl<'arena> Term<'arena> { ), } } + + pub fn is_trivial(&self) -> bool { + match self { + Term::ItemVar(_, _) + | Term::LocalVar(_, _) + | Term::MetaVar(_, _) + | Term::InsertedMeta(_, _, _) + | Term::Universe(_) + | Term::Prim(_, _) + | Term::ConstLit(_, _) => true, + Term::RecordProj(_, head, _) => head.is_trivial(), + _ => false, + } + } } /// Simple patterns that have had some initial elaboration performed on them @@ -431,6 +445,16 @@ impl<'arena> CheckedPattern<'arena> { CheckedPattern::ConstLit(_, _) | CheckedPattern::RecordLit(_, _, _) => false, } } + + pub fn is_trivial(&self) -> bool { + match self { + CheckedPattern::ReportedError(_) + | CheckedPattern::Placeholder(_) + | CheckedPattern::Binder(_, _) + | CheckedPattern::ConstLit(_, _) => true, + CheckedPattern::RecordLit(_, _, _) => false, + } + } } macro_rules! def_prims { diff --git a/fathom/src/surface/distillation.rs b/fathom/src/surface/distillation.rs index 6afa11215..d6dcec6a0 100644 --- a/fathom/src/surface/distillation.rs +++ b/fathom/src/surface/distillation.rs @@ -393,11 +393,11 @@ impl<'interner, 'arena, 'env> Context<'interner, 'arena, 'env> { match core_term { core::Term::ItemVar(_span, var) => match self.get_item_name(*var) { Some(name) => Term::Name((), name), - None => todo!("misbound variable"), // TODO: error? + None => panic!("misbound item variable: {var:?}"), }, core::Term::LocalVar(_span, var) => match self.get_local_name(*var) { Some(name) => Term::Name((), name), - None => todo!("misbound variable"), // TODO: error? + None => panic!("Unbound local variable: {var:?}"), }, core::Term::MetaVar(_span, var) => match self.get_hole_name(*var) { Some(name) => Term::Hole((), name), diff --git a/fathom/src/surface/elaboration.rs b/fathom/src/surface/elaboration.rs index cc73001ed..2063d1579 100644 --- a/fathom/src/surface/elaboration.rs +++ b/fathom/src/surface/elaboration.rs @@ -804,21 +804,52 @@ impl<'interner, 'arena> Context<'interner, 'arena> { match (surface_term, expected_type.as_ref()) { (Term::Let(range, def_pattern, def_type, def_expr, body_expr), _) => { let (def_pattern, def_type_value) = self.synth_ann_pattern(def_pattern, *def_type); - let scrut = self.check_scrutinee(def_expr, def_type_value.clone()); + let mut scrut = self.check_scrutinee(def_expr, def_type_value.clone()); let value = self.eval_env().eval(scrut.expr); + + // Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times, + // and may be evaluated multiple times by the pattern match compiler + let extra_def = match (scrut.expr.is_trivial(), def_pattern.is_trivial()) { + (false, false) => { + let def_name = None; // TODO: generate a fresh name + let def_type = self.quote_env().quote(self.scope, &scrut.r#type); + let def_expr = scrut.expr.clone(); + + let var = core::Term::LocalVar(def_expr.span(), env::Index::last()); + scrut.expr = self.scope.to_scope(var); + (self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone()); + Some((def_name, def_type, def_expr)) + } + _ => None, + }; + let initial_len = self.local_env.len(); self.push_local_def(&def_pattern, value, scrut.r#type.clone()); let body_expr = self.check(body_expr, &expected_type); self.local_env.truncate(initial_len); let matrix = PatMatrix::singleton(scrut, def_pattern); - self.elab_match( + let expr = self.elab_match( matrix, &[body_expr], *range, def_expr.range(), PatternMode::Let, - ) + ); + let expr = match extra_def { + None => expr, + Some((def_name, def_type, def_expr)) => { + self.local_env.pop(); + core::Term::Let( + range.into(), + def_name, + self.scope.to_scope(def_type), + self.scope.to_scope(def_expr), + self.scope.to_scope(expr), + ) + } + }; + expr } (Term::If(range, cond_expr, then_expr, else_expr), _) => { let cond_expr = self.check(cond_expr, &self.bool_type.clone()); @@ -1110,9 +1141,25 @@ impl<'interner, 'arena> Context<'interner, 'arena> { } Term::Let(range, def_pattern, def_type, def_expr, body_expr) => { let (def_pattern, def_type_value) = self.synth_ann_pattern(def_pattern, *def_type); - let scrut = self.check_scrutinee(def_expr, def_type_value.clone()); + let mut scrut = self.check_scrutinee(def_expr, def_type_value.clone()); let value = self.eval_env().eval(scrut.expr); + // Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times, + // and may be evaluated multiple times by the pattern match compiler + let extra_def = match (scrut.expr.is_trivial(), def_pattern.is_trivial()) { + (false, false) => { + let def_name = None; // TODO: generate a fresh name + let def_type = self.quote_env().quote(self.scope, &scrut.r#type); + let def_expr = scrut.expr.clone(); + + let var = core::Term::LocalVar(def_expr.span(), env::Index::last()); + scrut.expr = self.scope.to_scope(var); + (self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone()); + Some((def_name, def_type, def_expr)) + } + _ => None, + }; + let initial_len = self.local_env.len(); self.push_local_def(&def_pattern, value, scrut.r#type.clone()); let (body_expr, body_type) = self.synth(body_expr); @@ -1126,6 +1173,19 @@ impl<'interner, 'arena> Context<'interner, 'arena> { def_expr.range(), PatternMode::Let, ); + let expr = match extra_def { + None => expr, + Some((def_name, def_type, def_expr)) => { + self.local_env.pop(); + core::Term::Let( + range.into(), + def_name, + self.scope.to_scope(def_type), + self.scope.to_scope(def_expr), + self.scope.to_scope(expr), + ) + } + }; (expr, body_type) } Term::If(range, cond_expr, then_expr, else_expr) => { @@ -1817,15 +1877,37 @@ impl<'interner, 'arena> Context<'interner, 'arena> { expected_type: &ArcValue<'arena>, ) -> core::Term<'arena> { let expected_type = self.elim_env().force(expected_type); - let scrut = self.synth_scrutinee(scrutinee_expr); + let mut scrut = self.synth_scrutinee(scrutinee_expr); let value = self.eval_env().eval(scrut.expr); + let patterns: Vec<_> = equations + .iter() + .map(|(pat, _)| self.check_pattern(pat, &scrut.r#type)) + .collect(); + + // Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times, + // and may be evaluated multiple times by the pattern match compiler + let extra_def = match ( + scrut.expr.is_trivial(), + patterns.iter().all(|pat| pat.is_trivial()), + ) { + (false, false) => { + let def_name = None; // TODO: generate a fresh name + let def_type = self.quote_env().quote(self.scope, &scrut.r#type); + let def_expr = scrut.expr.clone(); + + let var = core::Term::LocalVar(def_expr.span(), env::Index::last()); + scrut.expr = self.scope.to_scope(var); + (self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone()); + Some((def_name, def_type, def_expr)) + } + _ => None, + }; + let mut rows = Vec::with_capacity(equations.len()); let mut exprs = Vec::with_capacity(equations.len()); - for (pat, expr) in equations { - let pattern = self.check_pattern(pat, &scrut.r#type); - + for (pattern, (_, expr)) in patterns.into_iter().zip(equations) { let initial_len = self.local_env.len(); self.push_pattern( &pattern, @@ -1841,7 +1923,21 @@ impl<'interner, 'arena> Context<'interner, 'arena> { } let matrix = patterns::PatMatrix::new(rows); - self.elab_match(matrix, &exprs, range, scrut.range, PatternMode::Match) + let expr = self.elab_match(matrix, &exprs, range, scrut.range, PatternMode::Match); + let expr = match extra_def { + None => expr, + Some((def_name, def_type, def_expr)) => { + self.local_env.pop(); + core::Term::Let( + range.into(), + def_name, + self.scope.to_scope(def_type), + self.scope.to_scope(def_expr), + self.scope.to_scope(expr), + ) + } + }; + expr } fn synth_scrutinee(&mut self, scrutinee_expr: &Term<'_, ByteRange>) -> Scrutinee<'arena> { diff --git a/tests/succeed/record-patterns/let-check.snap b/tests/succeed/record-patterns/let-check.snap index 066a5d7de..160928a14 100644 --- a/tests/succeed/record-patterns/let-check.snap +++ b/tests/succeed/record-patterns/let-check.snap @@ -1,9 +1,14 @@ stdout = ''' let _ : () = (); -let x : Bool = (false, true)._0; -let y : Bool = (false, true)._1; -let a : Bool = (false, true)._0; -let b : Bool = (false, true)._1; +let _ : () = (); +let _ : (Bool, Bool) = (false, true); +let x : Bool = _._0; +let y : Bool = _._1; +let _ : (Bool, Bool) = (false, true); +let a : Bool = _._0; +let _ : (Bool, Bool) = (false, true); +let b : Bool = _._1; +let _ : (Bool, Bool) = (false, true); let _ : (Bool, Bool) = (false, true); () : () ''' diff --git a/tests/succeed/record-patterns/let-synth.snap b/tests/succeed/record-patterns/let-synth.snap index 066a5d7de..160928a14 100644 --- a/tests/succeed/record-patterns/let-synth.snap +++ b/tests/succeed/record-patterns/let-synth.snap @@ -1,9 +1,14 @@ stdout = ''' let _ : () = (); -let x : Bool = (false, true)._0; -let y : Bool = (false, true)._1; -let a : Bool = (false, true)._0; -let b : Bool = (false, true)._1; +let _ : () = (); +let _ : (Bool, Bool) = (false, true); +let x : Bool = _._0; +let y : Bool = _._1; +let _ : (Bool, Bool) = (false, true); +let a : Bool = _._0; +let _ : (Bool, Bool) = (false, true); +let b : Bool = _._1; +let _ : (Bool, Bool) = (false, true); let _ : (Bool, Bool) = (false, true); () : () ''' diff --git a/tests/succeed/record-patterns/match-bool-pairs.snap b/tests/succeed/record-patterns/match-bool-pairs.snap index d791fac19..81010e964 100644 --- a/tests/succeed/record-patterns/match-bool-pairs.snap +++ b/tests/succeed/record-patterns/match-bool-pairs.snap @@ -1,38 +1,22 @@ stdout = ''' -let and1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then true else false - else false; -let and2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then true else false - else false; -let and3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then true else false - else if (x, y)._1 then false - else false; -let or1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then true - else if (x, y)._1 then true - else false; -let or2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then true - else if (x, y)._1 then true - else false; -let or3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then true else true - else if (x, y)._1 then true - else false; -let xor1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then false else true - else if (x, y)._1 then true - else false; -let xor2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then false else true - else if (x, y)._1 then true - else false; -let xor3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0 - then if (x, y)._1 then false else true - else if (x, y)._1 then true - else false; +let and1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then true else false else false; +let and2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then true else false else false; +let and3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then true else false else if _._1 then false else false; +let or1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then true else if _._1 then true else false; +let or2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then true else if _._1 then true else false; +let or3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then true else true else if _._1 then true else false; +let xor1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then false else true else if _._1 then true else false; +let xor2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then false else true else if _._1 then true else false; +let xor3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y); +if _._0 then if _._1 then false else true else if _._1 then true else false; () : () ''' stderr = ''