Skip to content

Commit

Permalink
Support if statements in typle_args and typle_bound macros
Browse files Browse the repository at this point in the history
These `if` statements are parsed differently to the `if` expressions in
`typle_for`, removing the need for some types to use `typle_ty!`.

Allow `typle!` to be used in place of `typle_args!`.
  • Loading branch information
jongiddy committed Oct 6, 2024
1 parent e7e601d commit e3f0f81
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 37 deletions.
153 changes: 116 additions & 37 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -919,35 +919,44 @@ impl<'a> TypleContext<'a> {
match &mut expr {
Expr::Macro(syn::ExprMacro { mac, .. }) => {
if let Some(ident) = mac.path.get_ident() {
if ident == "typle_args" {
if ident == "typle" || ident == "typle_args" {
let token_stream = std::mem::take(&mut mac.tokens);
let default_span = token_stream.span();
let mut tokens = token_stream.into_iter();
let (pattern, range) =
match self.parse_pattern_range(&mut tokens, default_span) {
Ok(t) => t,
Err(e) => return ListIterator4::Variant0(std::iter::once(Err(e))),
Err(e) => return ListIterator4::Variant1(std::iter::once(Err(e))),
};
if range.is_empty() {
return ListIterator4::Variant1(std::iter::empty());
return ListIterator4::Variant0(std::iter::empty());
}
let token_stream = tokens.collect::<TokenStream>();
let expr = match syn::parse2::<Expr>(token_stream) {
Ok(expr) => expr,
Err(e) => return ListIterator4::Variant0(std::iter::once(Err(e))),
};
let mut context = self.clone();
if let Some(ident) = pattern.clone() {
context.constants.insert(ident, 0);
}
return ListIterator4::Variant2(range.zip_clone(expr).map({
move |(index, mut expr)| {
return ListIterator4::Variant2(range.zip_clone(token_stream).flat_map({
move |(index, token_stream)| {
if let Some(ident) = &pattern {
*context.constants.get_mut(ident).unwrap() = index;
}
let token_stream = match context.evaluate_if(token_stream) {
Ok(Some(token_stream)) => token_stream,
Ok(None) => {
return None;
}
Err(e) => {
return Some(Err(e));
}
};
let mut expr = match syn::parse2::<Expr>(token_stream) {
Ok(expr) => expr,
Err(e) => return Some(Err(e)),
};
match context.replace_expr(&mut expr, &mut state) {
Ok(()) => Ok(expr),
Err(e) => Err(e),
Ok(()) => Some(Ok(expr)),
Err(e) => Some(Err(e)),
}
}
}));
Expand All @@ -960,12 +969,12 @@ impl<'a> TypleContext<'a> {
let mut iter = array.elems.iter_mut().fuse();
if let (Some(field), None) = (iter.next(), iter.next()) {
if let Err(e) = self.replace_expr(field, &mut state) {
return ListIterator4::Variant0(std::iter::once(Err(e)));
return ListIterator4::Variant1(std::iter::once(Err(e)));
}
if let Some((start, end)) = evaluate_range(field) {
let start = match start {
Bound::Included(Err(span)) | Bound::Excluded(Err(span)) => {
return ListIterator4::Variant0(std::iter::once(Err(
return ListIterator4::Variant1(std::iter::once(Err(
Error::new(span, "expected integer for start of range"),
)));
}
Expand All @@ -975,7 +984,7 @@ impl<'a> TypleContext<'a> {
};
let end = match end {
Bound::Included(Err(span)) | Bound::Excluded(Err(span)) => {
return ListIterator4::Variant0(std::iter::once(Err(
return ListIterator4::Variant1(std::iter::once(Err(
Error::new(span, "expected integer for end of range"),
)));
}
Expand All @@ -984,7 +993,7 @@ impl<'a> TypleContext<'a> {
Bound::Unbounded => match self.typle_len {
Some(end) => end,
None => {
return ListIterator4::Variant0(std::iter::once(Err(
return ListIterator4::Variant1(std::iter::once(Err(
Error::new(expr.span(), "need an explicit range end"),
)));
}
Expand All @@ -1011,8 +1020,8 @@ impl<'a> TypleContext<'a> {
_ => {}
}
match self.replace_expr(&mut expr, &mut state) {
Ok(()) => ListIterator4::Variant0(std::iter::once(Ok(expr))),
Err(e) => ListIterator4::Variant0(std::iter::once(Err(e))),
Ok(()) => ListIterator4::Variant1(std::iter::once(Ok(expr))),
Err(e) => ListIterator4::Variant1(std::iter::once(Err(e))),
}
}

Expand Down Expand Up @@ -1138,15 +1147,20 @@ impl<'a> TypleContext<'a> {
continue;
}
let token_stream = tokens.collect::<TokenStream>();
let r#type = syn::parse2::<Type>(token_stream)?;
let mut context = self.clone();
if let Some(ident) = pattern.clone() {
context.constants.insert(ident, 0);
}
for (index, mut bounded_ty) in range.zip_clone(r#type) {
for (index, token_stream) in range.zip_clone(token_stream) {
if let Some(ident) = &pattern {
*context.constants.get_mut(ident).unwrap() = index;
}
let Some(token_stream) =
context.evaluate_if(token_stream)?
else {
continue;
};
let mut bounded_ty = syn::parse2::<Type>(token_stream)?;
context.replace_type(&mut bounded_ty)?;
let bounds = predicate_type
.bounds
Expand Down Expand Up @@ -2931,39 +2945,104 @@ impl<'a> TypleContext<'a> {
Ok(())
}

fn evaluate_if(&self, ts: TokenStream) -> Result<Option<TokenStream>> {
let mut tokens = ts.into_iter();
match tokens.next() {
Some(TokenTree::Ident(ident)) if ident == "if" => {
let mut tokens = tokens.collect::<Vec<_>>();
match tokens.pop() {
Some(TokenTree::Group(group1)) => match tokens.last() {
Some(TokenTree::Ident(ident)) if ident == "else" => {
tokens.pop().unwrap();
match tokens.pop() {
Some(TokenTree::Group(group0)) => {
let mut cond =
syn::parse2::<Expr>(tokens.into_iter().collect())?;
let mut state = BlockState::default();
self.replace_expr(&mut cond, &mut state)?;
let b = evaluate_bool(&cond)?;
self.evaluate_if(if b {
group0.stream()
} else {
group1.stream()
})
}
Some(tt) => {
abort!(tt, "Expect body before `else`");
}
None => {
unreachable!("there is at least one token (the `if` token)");
}
}
}
Some(_) => {
let mut cond = syn::parse2::<Expr>(tokens.into_iter().collect())?;
let mut state = BlockState::default();
self.replace_expr(&mut cond, &mut state)?;
let b = evaluate_bool(&cond)?;
if b {
self.evaluate_if(group1.stream())
} else {
Ok(None)
}
}
None => unreachable!("there is at least one token (the `if` token)"),
},
Some(tt) => {
abort!(tt, "Expect body at end of `if`");
}
None => {
unreachable!("there is at least one token (the `if` token)");
}
}
}
Some(tt) => Ok(Some(std::iter::once(tt).chain(tokens).collect())),
None => Ok(None),
}
}

fn replace_type_in_list(&'a self, mut ty: Type) -> impl Iterator<Item = Result<Type>> + 'a {
match &mut ty {
Type::Macro(syn::TypeMacro { mac }) => {
if let Some(ident) = mac.path.get_ident() {
if ident == "typle_args" {
if ident == "typle" || ident == "typle_args" {
let token_stream = std::mem::take(&mut mac.tokens);
let default_span = token_stream.span();
let mut tokens = token_stream.into_iter();
let (pattern, range) =
match self.parse_pattern_range(&mut tokens, default_span) {
Ok(t) => t,
Err(e) => return ListIterator4::Variant0(std::iter::once(Err(e))),
Err(e) => return ListIterator4::Variant1(std::iter::once(Err(e))),
};
if range.is_empty() {
return ListIterator4::Variant3(std::iter::empty());
return ListIterator4::Variant0(std::iter::empty());
}
let token_stream = tokens.collect::<TokenStream>();
let ty = match syn::parse2::<Type>(token_stream) {
Ok(ty) => ty,
Err(e) => return ListIterator4::Variant0(std::iter::once(Err(e))),
};
let mut context = self.clone();
if let Some(ident) = pattern.clone() {
context.constants.insert(ident, 0);
}
return ListIterator4::Variant2(range.zip_clone(ty).map({
move |(index, mut ty)| {
return ListIterator4::Variant2(range.zip_clone(token_stream).flat_map({
move |(index, token_stream)| {
if let Some(ident) = &pattern {
*context.constants.get_mut(ident).unwrap() = index;
}
let token_stream = match context.evaluate_if(token_stream) {
Ok(Some(token_stream)) => token_stream,
Ok(None) => {
return None;
}
Err(e) => {
return Some(Err(e));
}
};
let mut ty = match syn::parse2::<Type>(token_stream) {
Ok(ty) => ty,
Err(e) => return Some(Err(e)),
};
match context.replace_type(&mut ty) {
Ok(()) => Ok(ty),
Err(e) => Err(e),
Ok(()) => Some(Ok(ty)),
Err(e) => Some(Err(e)),
}
}
}));
Expand All @@ -2987,13 +3066,13 @@ impl<'a> TypleContext<'a> {
let mut state = BlockState::default();
match self.replace_expr(expr, &mut state) {
Ok(_) => {}
Err(e) => return ListIterator4::Variant0(std::iter::once(Err(e))),
Err(e) => return ListIterator4::Variant1(std::iter::once(Err(e))),
}
if let Some((start, end)) = evaluate_range(expr) {
// T<{..}>
let start = match start {
Bound::Included(Err(span)) | Bound::Excluded(Err(span)) => {
return ListIterator4::Variant0(std::iter::once(Err(
return ListIterator4::Variant1(std::iter::once(Err(
Error::new(span, "expected integer for start of range"),
)));
}
Expand All @@ -3003,7 +3082,7 @@ impl<'a> TypleContext<'a> {
};
let end = match end {
Bound::Included(Err(span)) | Bound::Excluded(Err(span)) => {
return ListIterator4::Variant0(std::iter::once(Err(
return ListIterator4::Variant1(std::iter::once(Err(
Error::new(span, "expected integer for end of range"),
)));
}
Expand All @@ -3012,7 +3091,7 @@ impl<'a> TypleContext<'a> {
Bound::Unbounded => match self.typle_len {
Some(end) => end,
None => {
return ListIterator4::Variant0(std::iter::once(Err(
return ListIterator4::Variant1(std::iter::once(Err(
Error::new(
expr.span(),
"need an explicit range end",
Expand All @@ -3021,7 +3100,7 @@ impl<'a> TypleContext<'a> {
}
},
};
return ListIterator4::Variant1((start..end).map({
return ListIterator4::Variant3((start..end).map({
let span = path.span();
move |i| self.get_type(typle, i, span)
}));
Expand All @@ -3033,8 +3112,8 @@ impl<'a> TypleContext<'a> {
_ => {}
}
match self.replace_type(&mut ty) {
Ok(()) => ListIterator4::Variant0(std::iter::once(Ok(ty))),
Err(e) => ListIterator4::Variant0(std::iter::once(Err(e))),
Ok(()) => ListIterator4::Variant1(std::iter::once(Ok(ty))),
Err(e) => ListIterator4::Variant1(std::iter::once(Err(e))),
}
}

Expand Down
90 changes: 90 additions & 0 deletions tests/compile/mod.expanded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,96 @@ pub mod typle_args {
{
<(T, u32) as _typle_fn_append_double>::apply((t, a))
}
#[allow(non_camel_case_types)]
trait _typle_fn_append_even {
type Return;
fn apply(self) -> Self::Return;
}
impl _typle_fn_append_even for ((), u32) {
type Return = (u32,);
fn apply(self) -> Self::Return {
#[allow(unused_variables)]
let (t, a) = self;
{ (a,) }
}
}
impl<T0> _typle_fn_append_even for ((T0,), u32) {
type Return = (T0, u32);
fn apply(self) -> Self::Return {
let (t, a) = self;
{ (t.0, a) }
}
}
impl<T0, T1> _typle_fn_append_even for ((T0, T1), u32) {
type Return = (T0, u32);
fn apply(self) -> Self::Return {
let (t, a) = self;
{ (t.0, a) }
}
}
impl<T0, T1, T2> _typle_fn_append_even for ((T0, T1, T2), u32) {
type Return = (T0, T2, u32);
fn apply(self) -> Self::Return {
let (t, a) = self;
{ (t.0, t.2, a) }
}
}
fn append_even<T>(t: T, a: u32) -> <(T, u32) as _typle_fn_append_even>::Return
where
(T, u32): _typle_fn_append_even,
{
<(T, u32) as _typle_fn_append_even>::apply((t, a))
}
#[allow(non_camel_case_types)]
trait _typle_fn_even_string_odd {
type Return;
fn apply(self) -> Self::Return;
}
impl _typle_fn_even_string_odd for ((),) {
type Return = ();
fn apply(self) -> Self::Return {
#[allow(unused_variables)]
let (t,) = self;
{ #[allow(clippy::unused_unit)] () }
}
}
impl<T0> _typle_fn_even_string_odd for ((T0,),)
where
T0: ToString,
{
type Return = (String,);
fn apply(self) -> Self::Return {
let (t,) = self;
{ (t.0.to_string(),) }
}
}
impl<T0, T1> _typle_fn_even_string_odd for ((T0, T1),)
where
T0: ToString,
{
type Return = (String, T1);
fn apply(self) -> Self::Return {
let (t,) = self;
{ (t.0.to_string(), t.1) }
}
}
impl<T0, T1, T2> _typle_fn_even_string_odd for ((T0, T1, T2),)
where
T0: ToString,
T2: ToString,
{
type Return = (String, T1, String);
fn apply(self) -> Self::Return {
let (t,) = self;
{ (t.0.to_string(), t.1, t.2.to_string()) }
}
}
fn even_string_odd<T>(t: T) -> <(T,) as _typle_fn_even_string_odd>::Return
where
(T,): _typle_fn_even_string_odd,
{
<(T,) as _typle_fn_even_string_odd>::apply((t,))
}
struct World {}
trait ExclusiveSystemParam {}
struct ExclusiveSystemParamItem<F> {
Expand Down
Loading

0 comments on commit e3f0f81

Please sign in to comment.