Skip to content

Commit

Permalink
Fixing generics bug
Browse files Browse the repository at this point in the history
  • Loading branch information
kengorab committed Sep 30, 2023
1 parent c4af7bd commit f8581d5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
31 changes: 23 additions & 8 deletions abra_core/src/typechecker/typechecker2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2039,13 +2039,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
let ty_with_generics = self.project.get_type_by_id(type_id_containing_generics);

match (ty_with_generics, hint_ty) {
(Type::Generic(_, _), Type::Generic(_, _)) => {}
(Type::Generic(_, _), _) => {
if let Some(_) = substitutions.get(type_id_containing_generics) {
// If we already have a substitution for this generic, don't overwrite. If the known value does not align with the hint
// type, it should be reported by this function's caller.
return;
}
substitutions.insert(*type_id_containing_generics, *hint_type_id);
}
(Type::GenericInstance(s_id1, g_ids1), Type::GenericInstance(s_id2, g_ids2)) if s_id1 == s_id2 => {
Expand Down Expand Up @@ -2078,9 +2072,30 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {

match ty.clone() {
Type::Generic(_, _) => {
substitutions.get(&type_id)
let mut resolved_type_id = substitutions.get(&type_id)
.map(|substituted_type_id| *substituted_type_id)
.unwrap_or(*type_id)
.unwrap_or(*type_id);
// The `substitutions` map is not flattened; it's possible that, in order to resolve
// a generic's substitution, it may point to another generic which also needs to have
// substitution applied (and so on). For example, consider this structure:
// type A<AT> { x: AT }
// type B<BT> { a: A<BT> }
// val b = B(a: A(x: 12))
// When resolving ^^^^^^^^, the typechecker has the following as its known substitutions:
// { AT -> BT, BT -> Int }
// In order to learn that AT -> Int, we need to flatten the map. In doing so, we bubble
// up the proper type, so the instantiation of B correctly typechecks to `B<Int>`.
// TODO: This could maybe be improved by flattening as it's being built (in `extract_values_for_generics`).
while self.type_contains_generics(&resolved_type_id) {
if let Some(type_id) = substitutions.get(&resolved_type_id) {
if resolved_type_id == *type_id { break; }
resolved_type_id = *type_id;
} else {
break;
}
}

resolved_type_id
}
Type::GenericInstance(struct_id, generic_ids) => {
let substituted_generic_ids = generic_ids.iter().map(|generic_type_id| self.substitute_generics_with_known(generic_type_id, substitutions)).collect();
Expand Down
28 changes: 16 additions & 12 deletions abra_core/src/typechecker/typechecker2_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2744,6 +2744,10 @@ fn typecheck_function_declaration() {
// Misc other tests
assert!(test_typecheck("func foo(x: Bool[] = []) {}").is_ok());
assert!(test_typecheck("func foo(x = 12): Int = x").is_ok());
assert!(test_typecheck(r#"
type Foo<T> { t: T }
func f<T1>(t1: T1): Foo<T1> = Foo(t: t1)
"#).is_ok());
}

#[test]
Expand Down Expand Up @@ -3187,13 +3191,13 @@ fn typecheck_invocation() {
is_parameter: false,
};
assert_eq!(&expected, foo_var);
let var_invocation = &module.code[2];
let var_invocation = &module.code[3];
assert_eq!(PRELUDE_INT_TYPE_ID, *var_invocation.type_id());
let accessor_invocation = &module.code[3];
let accessor_invocation = &module.code[4];
assert_eq!(PRELUDE_INT_TYPE_ID, *accessor_invocation.type_id());
let accessor_invocation_arg_label = &module.code[4];
let accessor_invocation_arg_label = &module.code[5];
assert_eq!(PRELUDE_INT_TYPE_ID, *accessor_invocation_arg_label.type_id());
let accessor_invocation_arg_labels = &module.code[5];
let accessor_invocation_arg_labels = &module.code[6];
assert_eq!(PRELUDE_INT_TYPE_ID, *accessor_invocation_arg_labels.type_id());

// Invoking field of type
Expand All @@ -3220,7 +3224,7 @@ fn typecheck_invocation() {
is_parameter: false,
};
assert_eq!(&expected, foo_var);
let accessor_invocation = &module.code[2];
let accessor_invocation = &module.code[3];
assert_eq!(PRELUDE_INT_TYPE_ID, *accessor_invocation.type_id());

// Invoking method of enum variant
Expand Down Expand Up @@ -3253,7 +3257,7 @@ fn typecheck_invocation() {
is_parameter: false,
};
assert_eq!(&expected, foo_var);
let var_invocation = &module.code[2];
let var_invocation = &module.code[3];
assert_eq!(PRELUDE_INT_TYPE_ID, *var_invocation.type_id());
let accessor_invocation = &module.code[3];
assert_eq!(PRELUDE_INT_TYPE_ID, *accessor_invocation.type_id());
Expand Down Expand Up @@ -3305,7 +3309,7 @@ fn typecheck_invocation() {
is_parameter: false,
};
assert_eq!(&expected, foo_var);
let f_invocation = &module.code[2];
let f_invocation = &module.code[3];
assert_eq!(project.find_type_id(&ScopeId(ModuleId(1), 2), &Type::GenericEnumInstance(enum_id, vec![], Some(0))).unwrap(), *f_invocation.type_id());

// Invoking variadic functions
Expand Down Expand Up @@ -3630,7 +3634,7 @@ fn typecheck_invocation_instantiation() {
],
type_id: struct_.self_type_id,
};
assert_eq!(expected, module.code[0]);
assert_eq!(expected, module.code[1]);

// Test generics
let project = test_typecheck("\
Expand Down Expand Up @@ -3694,7 +3698,7 @@ fn typecheck_invocation_instantiation() {
],
type_id: struct_.self_type_id,
};
assert_eq!(expected, module.code[0]);
assert_eq!(expected, module.code[1]);
}

#[test]
Expand Down Expand Up @@ -3805,7 +3809,7 @@ fn typecheck_accessor() {
member_span: Range { start: Position::new(3, 3), end: Position::new(3, 3) },
type_id: PRELUDE_INT_TYPE_ID,
};
assert_eq!(expected, module.code[1]);
assert_eq!(expected, module.code[2]);

// Accessing method
let project = test_typecheck("\
Expand Down Expand Up @@ -3834,7 +3838,7 @@ fn typecheck_accessor() {
project.find_type_id(&ScopeId(ModuleId(1), 0), &project.function_type(vec![PRELUDE_INT_TYPE_ID], 1, false, int_array_type_id)).unwrap()
},
};
assert_eq!(expected, module.code[1]);
assert_eq!(expected, module.code[2]);

// Option-chaining accessor
assert_typecheck_ok(r#"
Expand Down Expand Up @@ -4169,7 +4173,7 @@ fn typecheck_assignment() {
").unwrap();
let foo_struct = project.find_struct_by_name(&ModuleId(1), &"Foo".to_string()).unwrap();
let foo_type_id = project.find_type_id(&ScopeId(ModuleId(1), 0), &Type::GenericInstance(foo_struct.id, vec![])).unwrap();
let node = &project.modules[1].code[1];
let node = &project.modules[1].code[2];
let expected = TypedNode::Assignment {
span: Range { start: Position::new(3, 1), end: Position::new(3, 10) },
kind: AssignmentKind::Accessor {
Expand Down

0 comments on commit f8581d5

Please sign in to comment.