Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support trait upcasting #796

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 136 additions & 50 deletions chalk-solve/src/clauses/builtin_traits/unsize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashSet;
use std::iter;
use std::ops::ControlFlow;

use crate::clauses::ClauseBuilder;
use crate::clauses::{super_traits::super_traits, ClauseBuilder};
use crate::rust_ir::AdtKind;
use crate::{Interner, RustIrDatabase, TraitRef, WellKnownTrait};
use chalk_ir::{
Expand Down Expand Up @@ -136,17 +136,27 @@ fn uses_outer_binder_params<I: Interner>(
matches!(flow, ControlFlow::Break(_))
}

fn principal_id<I: Interner>(
fn principal_trait_ref<I: Interner>(
db: &dyn RustIrDatabase<I>,
bounds: &Binders<QuantifiedWhereClauses<I>>,
) -> Option<TraitId<I>> {
let interner = db.interner();

) -> Option<Binders<Binders<TraitRef<I>>>> {
bounds
.skip_binders()
.iter(interner)
.filter_map(|b| b.trait_id())
.find(|&id| !db.trait_datum(id).is_auto_trait())
.map_ref(|b| b.iter(db.interner()))
.into_iter()
.find_map(|b| {
b.filter_map(|qwc| {
qwc.as_ref().filter_map(|wc| match wc {
WhereClause::Implemented(trait_ref) => {
if db.trait_datum(trait_ref.trait_id).is_auto_trait() {
None
} else {
Some(trait_ref.clone())
}
}
_ => None,
})
})
})
}

fn auto_trait_ids<'a, I: Interner>(
Expand Down Expand Up @@ -191,6 +201,7 @@ pub fn add_unsize_program_clauses<I: Interner>(

match (source_ty.kind(interner), target_ty.kind(interner)) {
// dyn Trait + AutoX + 'a -> dyn Trait + AutoY + 'b
// dyn TraitA + AutoX + 'a -> dyn TraitB + AutoY + 'b (upcasting)
(
TyKind::Dyn(DynTy {
bounds: bounds_a,
Expand All @@ -201,21 +212,30 @@ pub fn add_unsize_program_clauses<I: Interner>(
lifetime: lifetime_b,
}),
) => {
let principal_a = principal_id(db, bounds_a);
let principal_b = principal_id(db, bounds_b);
let principal_trait_ref_a = principal_trait_ref(db, bounds_a);
let principal_a = principal_trait_ref_a
.as_ref()
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);
let principal_b = principal_trait_ref(db, bounds_b)
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);

let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a).collect();
let auto_trait_ids_b: Vec<_> = auto_trait_ids(db, bounds_b).collect();

let may_apply = principal_a == principal_b
&& auto_trait_ids_b
.iter()
.all(|id_b| auto_trait_ids_a.iter().any(|id_a| id_a == id_b));

if !may_apply {
let auto_traits_compatible = auto_trait_ids_a
.iter()
.all(|id_b| auto_trait_ids_a.contains(&id_b));
if !auto_traits_compatible {
return;
}

// Check that source lifetime outlives target lifetime
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
a: lifetime_a.clone(),
b: lifetime_b.clone(),
})
.cast(interner);

// COMMENT FROM RUSTC:
// ------------------
// Require that the traits involved in this upcast are **equal**;
Expand All @@ -239,42 +259,108 @@ pub fn add_unsize_program_clauses<I: Interner>(
//
// In order for the coercion to be valid, this new type
// should be equal to target type.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).filter(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return true,
};

if auto_trait_ids_a.iter().all(|&id_a| id_a != trait_id) {
return true;
}
auto_trait_ids_b.iter().any(|&id_b| id_b == trait_id)
if principal_a == principal_b {
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).filter(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return true,
};

if !auto_trait_ids_a.contains(&trait_id) {
return true;
}
auto_trait_ids_b.contains(&trait_id)
}),
)
}),
lifetime: lifetime_b.clone(),
})
.intern(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
}
.cast(interner);

builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal]);
} else if let (Some(principal_a), Some(principal_b)) = (principal_a, principal_b) {
let principal_trait_ref_a = principal_trait_ref_a.unwrap();
let applicable_super_traits = super_traits(db, principal_a)
.map(|(super_trait_refs, _)| super_trait_refs)
.into_iter()
.filter(|trait_ref| {
trait_ref.skip_binders().skip_binders().trait_id == principal_b
});

for super_trait_ref in applicable_super_traits {
// `super_trait_ref` is, at this point, quantified over generic params of
// `principal_a` and relevant higher-ranked lifetimes that come from super
// trait elaboration (see comments on `super_traits()`).
//
// So if we have `trait Trait<'a, T>: for<'b> Super<'a, 'b, T> {}`,
// `super_trait_ref` can be something like
// `for<Self, 'a, T> for<'b> Self: Super<'a, 'b, T>`.
//
// We need to convert it into a bound for `DynTy`. We do this by substituting
// bound vars of `principal_trait_ref_a` and then fusing inner binders for
// higher-ranked lifetimes.
let rebound_super_trait_ref = principal_trait_ref_a.map_ref(|q_trait_ref_a| {
q_trait_ref_a
.map_ref(|trait_ref_a| {
super_trait_ref.substitute(interner, &trait_ref_a.substitution)
})
.fuse_binders(interner)
});

// Skip `for<Self>` binder. We'll rebind it immediately below.
let new_principal_trait_ref = rebound_super_trait_ref
.into_value_and_skipped_binders()
.0
.map(|it| it.cast(interner));

// Swap trait ref for `principal_a` with the new trait ref, drop the auto
// traits not included in the upcast target.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).cloned().filter_map(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return Some(bound),
};

if principal_a == trait_id {
Some(new_principal_trait_ref.clone())
} else {
auto_trait_ids_b.contains(&trait_id).then_some(bound)
}
}),
)
}),
)
}),
lifetime: lifetime_b.clone(),
})
.intern(interner);
lifetime: lifetime_b.clone(),
})
.intern(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
}
.cast(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
// We don't push goal for `principal_b`'s object safety because it's implied by
// `principal_a`'s object safety.
builder
.push_clause(trait_ref.clone(), [eq_goal, lifetime_outlives_goal.clone()]);
}
}
.cast(interner);

// Check that source lifetime outlives target lifetime
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
a: lifetime_a.clone(),
b: lifetime_b.clone(),
})
.cast(interner);

builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
}

// T -> dyn Trait + 'a
Expand Down
23 changes: 22 additions & 1 deletion chalk-solve/src/clauses/super_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,28 @@ pub(super) fn push_trait_super_clauses<I: Interner>(
}
}

fn super_traits<I: Interner>(
/// Returns super-`TraitRef`s and super-`Projection`s that are quantified over the parameters of
/// `trait_id` and relevant higher-ranked lifetimes. The outer `Binders` is for the former and the
/// inner `Binders` is for the latter.
///
/// For example, given the following trait definitions and `C` as `trait_id`,
///
/// ```
/// trait A<'a, T> {}
/// trait B<'b, U> where Self: for<'x> A<'x, U> {}
/// trait C<'c, V> where Self: B<'c, V> {}
/// ```
///
/// returns the following quantified `TraitRef`s.
///
/// ```notrust
/// for<Self, 'c, V> {
/// for<'x> { Self: A<'x, V> }
/// for<> { Self: B<'c, V> }
/// for<> { Self: C<'c, V> }
/// }
/// ```
pub(crate) fn super_traits<I: Interner>(
db: &dyn RustIrDatabase<I>,
trait_id: TraitId<I>,
) -> Binders<(
Expand Down
94 changes: 94 additions & 0 deletions tests/test/unsize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,100 @@ fn dyn_to_dyn_unsizing() {
}
}

#[test]
fn dyn_upcasting() {
test! {
program {
#[lang(unsize)]
trait Unsize<T> {}

#[object_safe]
trait Super {}
#[object_safe]
trait GenericSuper<T> {}
#[object_safe]
trait Sub
where
Self: Super,
Self: GenericSuper<i32>,
Self: GenericSuper<i64>,
{}
#[object_safe]
trait Principal where Self: Sub {}

#[auto]
#[object_safe]
trait Auto1 {}

#[auto]
#[object_safe]
trait Auto2 {}
}

goal {
forall<'a> {
dyn Principal + 'a: Unsize<dyn Sub + 'a>
}
} yields {
expect![[r#"Unique; lifetime constraints [InEnvironment { environment: Env([]), goal: '!1_0: '!1_0 }]"#]]
}

goal {
forall<'a> {
dyn Principal + Auto1 + 'a: Unsize<dyn Sub + Auto1 + 'a>
}
} yields {
expect![[r#"Unique; lifetime constraints [InEnvironment { environment: Env([]), goal: '!1_0: '!1_0 }]"#]]
}

// Different set of auto traits
goal {
forall<'a> {
dyn Principal + Auto1 + 'a: Unsize<dyn Sub + Auto2 + 'a>
}
} yields {
expect![[r#"No possible solution"#]]
}

// Dropping auto traits is allowed
goal {
forall<'a> {
dyn Principal + Auto1 + Auto2 + 'a: Unsize<dyn Sub + Auto1 + 'a>
}
} yields {
expect![[r#"Unique; lifetime constraints [InEnvironment { environment: Env([]), goal: '!1_0: '!1_0 }]"#]]
}

// Upcasting to indirect super traits
goal {
forall<'a> {
dyn Principal + 'a: Unsize<dyn Super + 'a>
}
} yields {
expect![[r#"Unique; lifetime constraints [InEnvironment { environment: Env([]), goal: '!1_0: '!1_0 }]"#]]
}

goal {
forall<'a> {
dyn Principal + 'a: Unsize<dyn GenericSuper<i32> + 'a>
}
} yields {
expect![[r#"Unique; lifetime constraints [InEnvironment { environment: Env([]), goal: '!1_0: '!1_0 }]"#]]
}

// Ambiguous if there are multiple super traits applicable
goal {
exists<T> {
forall<'a> {
dyn Principal + 'a: Unsize<dyn GenericSuper<T> + 'a>
}
}
} yields {
expect![[r#"Ambiguous; no inference guidance"#]]
}
}
}

#[test]
fn ty_to_dyn_unsizing() {
test! {
Expand Down