From ca4c620068593f8cff6a8e1091211b09890e7569 Mon Sep 17 00:00:00 2001 From: urso Date: Sun, 18 Feb 2024 02:20:22 +0100 Subject: [PATCH] Derive macro: Printable --- Cargo.toml | 3 +- pliron-derive/Cargo.toml | 1 + pliron-derive/src/derive_attr.rs | 11 +- pliron-derive/src/derive_op.rs | 11 +- pliron-derive/src/derive_printable.rs | 278 +++++++++++++ pliron-derive/src/derive_type.rs | 11 +- pliron-derive/src/irfmt/eval.rs | 136 +++++++ pliron-derive/src/irfmt/mod.rs | 554 ++++++++++++++++++++++++++ pliron-derive/src/irfmt/parser.rs | 137 +++++++ pliron-derive/src/lib.rs | 16 + pliron-derive/src/macro_attr.rs | 135 +++++++ 11 files changed, 1289 insertions(+), 4 deletions(-) create mode 100644 pliron-derive/src/derive_printable.rs create mode 100644 pliron-derive/src/irfmt/eval.rs create mode 100644 pliron-derive/src/irfmt/mod.rs create mode 100644 pliron-derive/src/irfmt/parser.rs create mode 100644 pliron-derive/src/macro_attr.rs diff --git a/Cargo.toml b/Cargo.toml index 13ea988..5cb6213 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ intertrait = "0.2.2" linkme = "0.2" paste = "1.0" inventory = "0.3" -combine = "4.6.6" +combine.workspace = true regex = "1.10.2" dyn-clone = "1.0.16" @@ -52,3 +52,4 @@ proc-macro2 = "1.0.72" quote = "1.0.33" prettyplease = "0.2.16" syn = { version = "2.0.43", features = ["derive"] } +combine = "4.6.6" diff --git a/pliron-derive/Cargo.toml b/pliron-derive/Cargo.toml index dc0e849..f18b66a 100644 --- a/pliron-derive/Cargo.toml +++ b/pliron-derive/Cargo.toml @@ -17,6 +17,7 @@ proc-macro = true proc-macro2.workspace = true quote.workspace = true syn.workspace = true +combine.workspace = true [dev-dependencies] prettyplease.workspace = true diff --git a/pliron-derive/src/derive_attr.rs b/pliron-derive/src/derive_attr.rs index d8ae4bd..f0fbd04 100644 --- a/pliron-derive/src/derive_attr.rs +++ b/pliron-derive/src/derive_attr.rs @@ -50,11 +50,18 @@ impl DefAttribute { )); } - let attrs: Vec<_> = input + let mut attrs: Vec<_> = input .attrs .into_iter() .filter(|attr| !attr.path().is_ident(PROC_MACRO_NAME)) .collect(); + attrs.push(syn::parse_quote! { + #[derive(::pliron_derive::DeriveAttribAcceptor)] + }); + attrs.push(syn::parse_quote! { + #[ir_kind = "attribute"] + }); + let input = DeriveInput { attrs, ..input }; let verifiers = VerifiersRegister { @@ -157,6 +164,8 @@ mod tests { expect![[r##" #[derive(PartialEq, Eq, Debug, Clone)] + #[derive(::pliron_derive::DeriveAttribAcceptor)] + #[ir_kind = "attribute"] pub struct UnitAttr(); #[allow(non_camel_case_types)] pub struct AttrInterfaceVerifier_UnitAttr( diff --git a/pliron-derive/src/derive_op.rs b/pliron-derive/src/derive_op.rs index e5d2e4f..a67c091 100644 --- a/pliron-derive/src/derive_op.rs +++ b/pliron-derive/src/derive_op.rs @@ -52,11 +52,18 @@ impl DefOp { )); } - let attrs: Vec<_> = input + let mut attrs: Vec<_> = input .attrs .into_iter() .filter(|attr| !attr.path().is_ident(PROC_MACRO_NAME)) .collect(); + attrs.push(syn::parse_quote! { + #[derive(::pliron_derive::DeriveAttribAcceptor)] + }); + attrs.push(syn::parse_quote! { + #[ir_kind = "op"] + }); + let input = DeriveInput { attrs, ..input }; let verifiers = VerifiersRegister { @@ -171,6 +178,8 @@ mod tests { expect![[r##" #[derive(Clone, Copy)] + #[derive(::pliron_derive::DeriveAttribAcceptor)] + #[ir_kind = "op"] struct TestOp { op: ::pliron::context::Ptr<::pliron::operation::Operation>, } diff --git a/pliron-derive/src/derive_printable.rs b/pliron-derive/src/derive_printable.rs new file mode 100644 index 0000000..9c91fb8 --- /dev/null +++ b/pliron-derive/src/derive_printable.rs @@ -0,0 +1,278 @@ +use std::collections::BTreeSet; + +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::Result; + +use crate::{ + irfmt::{ + AttribTypeFmtEvaler, Directive, Elem, FieldIdent, FmtData, Format, IRFmtInput, Lit, + Optional, UnnamedVar, Var, + }, + macro_attr::IRKind, +}; + +pub(crate) fn derive(input: impl Into) -> Result { + let input = syn::parse2::(input.into())?; + let p = DerivedPrinter::try_from(input)?; + Ok(p.into_token_stream()) +} + +enum DerivedPrinter { + AttribType(DerivedAttribTypePrinter), + Op(DerivedOpPrinter), +} + +impl TryFrom for DerivedPrinter { + type Error = syn::Error; + + fn try_from(input: IRFmtInput) -> Result { + match input.kind { + IRKind::Type | IRKind::Attribute => { + DerivedAttribTypePrinter::try_from(input).map(Self::AttribType) + } + IRKind::Op => Ok(Self::Op(DerivedOpPrinter::from(input))), + } + } +} + +impl ToTokens for DerivedPrinter { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::AttribType(p) => p.to_tokens(tokens), + Self::Op(p) => p.to_tokens(tokens), + } + } +} + +struct DerivedAttribTypePrinter { + ident: syn::Ident, + format: Format, + fields: Vec, +} + +impl TryFrom for DerivedAttribTypePrinter { + type Error = syn::Error; + + fn try_from(input: IRFmtInput) -> Result { + let fields = match input.data { + FmtData::Struct(s) => s.fields, + }; + let evaler = AttribTypeFmtEvaler::new(input.ident.span(), &fields); + let format = evaler.eval(input.format)?; + Ok(Self { + ident: input.ident, + format, + fields, + }) + } +} + +impl ToTokens for DerivedAttribTypePrinter { + fn to_tokens(&self, tokens: &mut TokenStream) { + let builder = AttribTypePrinterBuilder::new(&self.fields); + tokens.extend(builder.build(&self.ident, &self.format)); + } +} + +struct AttribTypePrinterBuilder<'a> { + fields: &'a [FieldIdent], +} + +impl<'a> AttribTypePrinterBuilder<'a> { + fn new(fields: &'a [FieldIdent]) -> Self { + Self { fields } + } +} + +impl<'a> PrinterBuilder for AttribTypePrinterBuilder<'a> { + fn build_directive(&self, d: &Directive, _toplevel: bool) -> TokenStream { + let printer_args = quote! { (ctx, state, fmt) }; + let field_names = self.fields; + let args = d.args.iter().map(|e| self.build_elem(e, false)); + let directive = format_ident!("at_{}_directive", d.name); + quote! { + ::pliron::irfmt::printers::#directive!(self, #printer_args, (#(#args),*), (#(#field_names)*)); + } + } +} + +struct DerivedOpPrinter { + ident: syn::Ident, + format: Format, + fields: Vec, +} + +impl From for DerivedOpPrinter { + fn from(input: IRFmtInput) -> Self { + let fields = match input.data { + FmtData::Struct(s) => s.fields, + }; + Self { + ident: input.ident, + format: input.format, + fields, + } + } +} + +impl ToTokens for DerivedOpPrinter { + fn to_tokens(&self, tokens: &mut TokenStream) { + let builder = OpPrinterBuilder::new(&self.fields); + tokens.extend(builder.build(&self.ident, &self.format)); + } +} + +struct OpPrinterBuilder { + fields: BTreeSet, +} + +impl OpPrinterBuilder { + fn new(fields: &[FieldIdent]) -> Self { + Self { + fields: BTreeSet::from_iter(fields.iter().cloned()), + } + } +} + +impl PrinterBuilder for OpPrinterBuilder { + fn build_var(&self, name: &str, toplevel: bool) -> TokenStream { + let ident = FieldIdent::Named(name.into()); + if self.fields.contains(&ident) { + return self.build_field_use(ident, toplevel); + } + + make_print_if( + toplevel, + quote! { + ::pliron::irfmt::printers::get_attr!(self, #name) + }, + ) + } + + fn build_directive(&self, d: &Directive, toplevel: bool) -> TokenStream { + let directive = format_ident!("op_{}_directive", d.name); + let args = d.args.iter().map(|e| self.build_elem(e, false)); + let printer = quote! { + ::pliron::irfmt::printers::#directive!(ctx, self #(, #args)*) + }; + make_print_if(toplevel, printer) + } +} + +trait PrinterBuilder { + fn build_directive(&self, d: &Directive, toplevel: bool) -> TokenStream; + + fn build(&self, name: &syn::Ident, attr: &Format) -> TokenStream { + let body = self.build_body(attr); + quote! { + impl ::pliron::printable::Printable for #name { + fn fmt( + &self, + ctx: & ::pliron::context::Context, + state: & ::pliron::printable::State, + fmt: &mut ::std::fmt::Formatter<'_>, + ) -> ::std::fmt::Result { + #body + Ok(()) + } + } + } + } + + fn build_body(&self, attr: &Format) -> TokenStream { + self.build_format(attr, true) + } + + fn build_lit(&self, lit: &str, toplevel: bool) -> TokenStream { + if toplevel { + make_print(quote! { + ::pliron::irfmt::printers::literal(#lit) + }) + } else { + quote! { #lit } + } + } + + fn build_var(&self, name: &str, toplevel: bool) -> TokenStream { + self.build_field_use(format_ident!("{}", name), toplevel) + } + + fn build_unnamed_var(&self, index: usize, toplevel: bool) -> TokenStream { + self.build_field_use(syn::Index::from(index), toplevel) + } + + fn build_field_use(&self, ident: T, toplevel: bool) -> TokenStream + where + T: quote::ToTokens, + { + if toplevel { + make_print(quote! { + ::pliron::irfmt::printers::print_var!(&self.#ident) + }) + } else { + quote! { + #ident + } + } + } + + fn build_check(&self, check: &Elem) -> TokenStream { + let check = Directive::new_with_args("check", vec![check.clone()]); + self.build_directive(&check, false) + } + + fn build_format(&self, format: &Format, toplevel: bool) -> TokenStream { + TokenStream::from_iter( + format + .elems + .iter() + .map(|elem| self.build_elem(elem, toplevel)), + ) + } + + fn build_elem(&self, elem: &Elem, toplevel: bool) -> TokenStream { + match elem { + Elem::Lit(Lit { lit, .. }) => self.build_lit(lit, toplevel), + Elem::Var(Var { name, .. }) => self.build_var(name, toplevel), + Elem::UnnamedVar(UnnamedVar { index, .. }) => self.build_unnamed_var(*index, toplevel), + Elem::Directive(ref d) => self.build_directive(d, toplevel), + Elem::Optional(ref opt) => self.build_optional(opt, toplevel), + } + } + + fn build_optional(&self, d: &Optional, toplevel: bool) -> TokenStream { + let check = self.build_check(&d.check); + let then_block = self.build_format(&d.then_format, toplevel); + if let Some(else_format) = &d.else_format { + let else_block = self.build_format(else_format, toplevel); + quote! { + if #check { + #then_block + } else { + #else_block + } + } + } else { + quote! { + if #check { + #then_block + } + } + } + } +} + +fn make_print(stmt: TokenStream) -> TokenStream { + quote! { + #stmt.fmt(ctx, state, fmt)?; + } +} + +fn make_print_if(cond: bool, stmt: TokenStream) -> TokenStream { + if cond { + make_print(stmt) + } else { + stmt + } +} diff --git a/pliron-derive/src/derive_type.rs b/pliron-derive/src/derive_type.rs index 3d7d4b0..c927c73 100644 --- a/pliron-derive/src/derive_type.rs +++ b/pliron-derive/src/derive_type.rs @@ -46,11 +46,18 @@ impl DefType { )); } - let attrs: Vec<_> = input + let mut attrs: Vec<_> = input .attrs .into_iter() .filter(|attr| !attr.path().is_ident(PROC_MACRO_NAME)) .collect(); + attrs.push(syn::parse_quote! { + #[derive(::pliron_derive::DeriveAttribAcceptor)] + }); + attrs.push(syn::parse_quote! { + #[ir_kind = "type"] + }); + let input = DeriveInput { attrs, ..input }; let impl_type = ImplType { @@ -131,6 +138,8 @@ mod tests { expect![[r##" #[derive(Hash, PartialEq, Eq, Debug)] + #[derive(::pliron_derive::DeriveAttribAcceptor)] + #[ir_kind = "type"] pub struct SimpleType {} impl ::pliron::r#type::Type for SimpleType { fn hash_type(&self) -> ::pliron::storage_uniquer::TypeValueHash { diff --git a/pliron-derive/src/irfmt/eval.rs b/pliron-derive/src/irfmt/eval.rs new file mode 100644 index 0000000..5bca14b --- /dev/null +++ b/pliron-derive/src/irfmt/eval.rs @@ -0,0 +1,136 @@ +use proc_macro2::Span; +use syn; + +use super::{Directive, Elem, FieldIdent, FmtValue, Format, Optional}; + +pub struct AttribTypeFmtEvaler<'a> { + span: Span, + fields: &'a [FieldIdent], +} + +impl<'a> AttribTypeFmtEvaler<'a> { + pub fn new(span: Span, fields: &'a [FieldIdent]) -> Self { + Self { span, fields } + } + + fn span(&self) -> Span { + self.span + } + + pub fn eval(&self, f: Format) -> syn::Result { + Ok(self.eval_format(f, true)?.into()) + } + + fn eval_format(&self, f: Format, toplevel: bool) -> syn::Result { + let elems = self.eval_elems(f.elems, toplevel)?; + Ok(elems.into()) + } + + fn eval_elems(&self, elem: Vec, toplevel: bool) -> syn::Result { + let results = elem.into_iter().map(|e| self.eval_elem(e, toplevel)); + let mut elems = vec![]; + for r in results { + r?.flatten_into(&mut elems); + } + Ok(FmtValue(elems)) + } + + fn eval_elem(&self, elem: Elem, toplevel: bool) -> syn::Result { + match elem { + Elem::Lit(_) | Elem::Var(_) | Elem::UnnamedVar(_) => Ok(elem.into()), + Elem::Directive(d) => self.eval_directive(d, toplevel), + Elem::Optional(opt) => self.eval_optional(opt, toplevel), + } + } + + fn eval_directive(&self, d: Directive, toplevel: bool) -> syn::Result { + match d.name.as_str() { + "params" => { + require_no_args(self.span, "params", &d.args)?; + if toplevel { + Ok(FmtValue::from(d)) + } else { + Ok(FmtValue::from( + self.fields.iter().map(|f| f.into()).collect::>(), + )) + } + } + "struct" => { + require_toplevel(self.span, &d.name, toplevel)?; + require_args(self.span, "struct", &d.args)?; + let args = self.eval_args(d.args)?; + Ok(FmtValue::from(Directive { args, ..d })) + } + _ => { + require_toplevel(self.span, &d.name, toplevel)?; + let args = self.eval_args(d.args)?; + Ok(FmtValue::from(Directive { args, ..d })) + } + } + } + + fn eval_args(&self, args: Vec) -> syn::Result> { + let values = self.eval_elems(args, false)?; + Ok(values.into()) + } + + fn eval_optional(&self, opt: Optional, toplevel: bool) -> syn::Result { + require_toplevel(self.span(), "optional", toplevel).unwrap(); + + let mut check_tmp = self.eval_elem(*opt.check, false)?.flatten(); + let Some(check) = check_tmp.pop() else { + return Err(syn::Error::new( + self.span(), + "`check` argument of `optional` has no value", + )); + }; + if !check_tmp.is_empty() { + return Err(syn::Error::new( + self.span(), + "`check` argument of `optional` directive must be a single value", + )); + } + + let then_format = self.eval_format(opt.then_format, toplevel)?; + let else_format = opt + .else_format + .map(|f| self.eval_format(f, toplevel)) + .transpose()?; + + Ok(FmtValue::from(Optional { + check: Box::new(check), + then_format, + else_format, + })) + } +} + +fn require_toplevel(span: Span, directive: &str, toplevel: bool) -> syn::Result<()> { + if !toplevel { + return Err(syn::Error::new( + span, + format!("`{}` directive is only allowed at the top-level", directive), + )); + } + Ok(()) +} + +fn require_no_args(span: Span, directive: &str, args: &[Elem]) -> syn::Result<()> { + if !args.is_empty() { + return Err(syn::Error::new( + span, + format!("`{}` directive does not take any arguments", directive), + )); + } + Ok(()) +} + +fn require_args(span: Span, directive: &str, args: &[Elem]) -> syn::Result<()> { + if args.is_empty() { + return Err(syn::Error::new( + span, + format!("`{}` directive requires arguments", directive), + )); + } + Ok(()) +} diff --git a/pliron-derive/src/irfmt/mod.rs b/pliron-derive/src/irfmt/mod.rs new file mode 100644 index 0000000..75c19f1 --- /dev/null +++ b/pliron-derive/src/irfmt/mod.rs @@ -0,0 +1,554 @@ +use proc_macro2::TokenStream; +use quote::format_ident; +use syn::parse::{Parse, ParseStream}; +use syn::Data; +use syn::{self, DataStruct, DeriveInput}; + +mod eval; +mod parser; + +pub use eval::AttribTypeFmtEvaler; + +use crate::macro_attr::{require_once, IRFormat, IRKind}; + +pub(crate) struct IRFmtInput { + pub ident: syn::Ident, + pub kind: IRKind, + pub format: Format, + pub data: FmtData, +} + +pub(crate) enum FmtData { + Struct(Struct), +} + +impl Parse for IRFmtInput { + fn parse(input: ParseStream) -> syn::Result { + let input = DeriveInput::parse(input)?; + Self::try_from(input) + } +} + +impl TryFrom for IRFmtInput { + type Error = syn::Error; + + fn try_from(input: DeriveInput) -> syn::Result { + let mut kind = None; + let mut format = None; + + for attr in &input.attrs { + if attr.path().is_ident(IRFormat::ATTR_NAME) { + require_once(IRFormat::ATTR_NAME, &format, attr)?; + format = Some(IRFormat::from_syn(attr)?); + } + if attr.path().is_ident(IRKind::ATTR_NAME) { + require_once(IRKind::ATTR_NAME, &kind, attr)?; + kind = Some(IRKind::from_syn(attr)?); + } + } + + let Some(kind) = kind else { + return Err(syn::Error::new_spanned( + input, + "unknown IR object type. Use #[ir_kind=...] or one of the supported derive clauses Type, Attrib, ...", + )); + }; + + let data = match input.data { + Data::Struct(ref data) => Struct::from_syn(data).map(FmtData::Struct), + Data::Enum(_) => Err(syn::Error::new_spanned( + &input, + "Type can only be derived for structs", + )), + Data::Union(_) => Err(syn::Error::new_spanned( + &input, + "Type can only be derived for structs", + )), + }?; + + let format = match format { + Some(f) => f, + None => { + let mut format = match kind { + IRKind::Op => generic_op_format(), + IRKind::Type | IRKind::Attribute => try_format_from_input(&input)?, + }; + if !format.is_empty() && kind != IRKind::Op { + format.enclose(Elem::Lit("<".into()), Elem::Lit(">".into())); + } + format.into() + } + }; + + let mut format: Format = format.into(); + if kind == IRKind::Op { + format.prepend(Optional::new( + Elem::new_directive("results"), + Format::from(vec![Elem::new_directive("results"), Elem::new_lit(" = ")]), + )); + } + + Ok(Self { + ident: input.ident, + kind, + format, + data, + }) + } +} + +pub(crate) struct Struct { + pub fields: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) enum FieldIdent { + Named(String), + Unnamed(usize), +} + +impl From for Elem { + fn from(value: FieldIdent) -> Self { + match value { + FieldIdent::Named(name) => Elem::new_var(name), + FieldIdent::Unnamed(index) => Elem::new_unnamed_var(index), + } + } +} + +impl From<&FieldIdent> for Elem { + fn from(value: &FieldIdent) -> Self { + match value { + FieldIdent::Named(name) => Elem::new_var(name), + FieldIdent::Unnamed(index) => Elem::new_unnamed_var(*index), + } + } +} + +impl From<&str> for FieldIdent { + fn from(s: &str) -> Self { + Self::Named(s.to_string()) + } +} + +impl From for FieldIdent { + fn from(s: String) -> Self { + Self::Named(s) + } +} + +impl From for FieldIdent { + fn from(i: usize) -> Self { + Self::Unnamed(i) + } +} + +impl quote::ToTokens for FieldIdent { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::Named(name) => { + let ident = format_ident!("{}", name); + ident.to_tokens(tokens); + } + Self::Unnamed(index) => { + let ident = syn::Index::from(*index); + ident.to_tokens(tokens); + } + } + } +} + +impl Struct { + fn from_syn(data: &DataStruct) -> syn::Result { + let fields = data + .fields + .iter() + .enumerate() + .map(|(i, f)| match f.ident { + Some(ref ident) => FieldIdent::Named(ident.to_string()), + None => FieldIdent::Unnamed(i), + }) + .collect(); + + Ok(Self { fields }) + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub(crate) struct Format { + pub elems: Vec, +} + +impl From> for Format { + fn from(elems: Vec) -> Self { + Self { elems } + } +} + +impl Format { + pub fn is_empty(&self) -> bool { + self.elems.is_empty() + } + + pub fn prepend(&mut self, elem: impl Into) { + self.elems.insert(0, elem.into()); + } + + pub fn append(&mut self, elem: impl Into) { + self.elems.push(elem.into()); + } + + pub fn enclose(&mut self, open: impl Into, close: impl Into) { + self.prepend(open); + self.append(close); + } +} + +impl Format { + pub fn parse(input: &str) -> Result { + parser::parse(input) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Elem { + // Literal is a custom string enclosed in backticks. For example `lit` or `(`. + Lit(Lit), + + // varialbes are custom identifiers starting with a dollar sign. For example $var or $0. + Var(Var), + + // Unnamed variables are custom identifiers starting with a dollar sign and a number. + UnnamedVar(UnnamedVar), + + // Directives are builtin identifiers. Some directives may have optional arguments enclosed + // in parens. For example `attr-dict` or `directive($arg1, other-directive)`. + Directive(Directive), + + Optional(Optional), +} + +impl Default for Elem { + fn default() -> Self { + Self::Lit(Lit::new("")) + } +} + +impl Elem { + pub fn new_lit(s: impl Into) -> Self { + Self::Lit(Lit::new(s)) + } + + pub fn new_lit_at(pos: usize, s: impl Into) -> Self { + Self::Lit(Lit::new_at(pos, s)) + } + + pub fn new_var(s: impl Into) -> Self { + Self::Var(Var::new(s)) + } + + pub fn new_var_at(pos: usize, s: impl Into) -> Self { + Self::Var(Var::new_at(pos, s.into())) + } + + pub fn new_unnamed_var(index: usize) -> Self { + Self::UnnamedVar(UnnamedVar::new(index)) + } + + pub fn new_unnamed_var_at(pos: usize, index: usize) -> Self { + Self::UnnamedVar(UnnamedVar::new_at(pos, index)) + } + + pub fn new_directive(name: impl Into) -> Self { + Self::Directive(Directive::new(name)) + } + + #[allow(dead_code)] // used in tests. + pub fn new_directive_at(pos: usize, name: impl Into) -> Self { + Self::Directive(Directive::new_at(pos, name)) + } + + #[allow(dead_code)] + pub fn new_directive_with_args(name: impl Into, args: Vec) -> Self { + Self::Directive(Directive::new_with_args(name, args)) + } + + pub fn new_directive_with_args_at( + pos: usize, + name: impl Into, + args: Vec, + ) -> Self { + Self::Directive(Directive::new_with_args_at(pos, name, args)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Lit { + pub pos: Option, + pub lit: String, +} + +impl From for Elem { + fn from(lit: Lit) -> Self { + Self::Lit(lit) + } +} + +impl From<&str> for Lit { + fn from(s: &str) -> Self { + Self::new(s) + } +} + +impl Lit { + pub fn new(s: impl Into) -> Self { + Self { + pos: None, + lit: s.into(), + } + } + + pub fn new_at(pos: usize, s: impl Into) -> Self { + Self { + pos: Some(pos), + lit: s.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Var { + pub pos: Option, + pub name: String, +} + +impl Var { + pub fn new(s: impl Into) -> Self { + Self { + pos: None, + name: s.into(), + } + } + + pub fn new_at(pos: usize, s: impl Into) -> Self { + Self { + pos: Some(pos), + name: s.into(), + } + } +} + +impl From for Elem { + fn from(lit: Var) -> Self { + Self::Var(lit) + } +} + +impl From<&str> for Var { + fn from(s: &str) -> Self { + Self::new(s) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UnnamedVar { + pub pos: Option, + pub index: usize, +} + +impl From for Elem { + fn from(var: UnnamedVar) -> Self { + Self::UnnamedVar(var) + } +} + +impl From for UnnamedVar { + fn from(index: usize) -> Self { + Self::new(index) + } +} + +impl UnnamedVar { + pub fn new(index: usize) -> Self { + Self { pos: None, index } + } + + pub fn new_at(pos: usize, index: usize) -> Self { + Self { + pos: Some(pos), + index, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Directive { + pub pos: Option, + pub name: String, + pub args: Vec, +} + +impl Directive { + pub fn new(name: impl Into) -> Self { + Self { + pos: None, + name: name.into(), + args: vec![], + } + } + + pub fn new_at(pos: usize, name: impl Into) -> Self { + Self { + pos: Some(pos), + name: name.into(), + args: vec![], + } + } + + pub fn new_with_args(name: impl Into, args: Vec) -> Self { + Self { + pos: None, + name: name.into(), + args, + } + } + + pub fn new_with_args_at(pos: usize, name: impl Into, args: Vec) -> Self { + Self { + pos: Some(pos), + name: name.into(), + args, + } + } +} + +impl From for Elem { + fn from(directive: Directive) -> Self { + Self::Directive(directive) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Optional { + pub check: Box, + pub then_format: Format, + pub else_format: Option, +} + +impl From for Elem { + fn from(optional: Optional) -> Self { + Self::Optional(optional) + } +} + +impl Optional { + pub fn new(check: Elem, then_format: Format) -> Self { + Self { + check: Box::new(check), + then_format, + else_format: None, + } + } + + #[allow(dead_code)] + pub fn new_with_else(check: Elem, then_format: Format, else_format: Format) -> Self { + Self { + check: Box::new(check), + then_format, + else_format: Some(else_format), + } + } +} + +type Result = std::result::Result; + +type Error = Box; + +struct FmtValue(Vec); + +impl From for FmtValue { + fn from(elem: Elem) -> Self { + Self(vec![elem]) + } +} + +impl From> for FmtValue { + fn from(elems: Vec) -> Self { + Self(elems) + } +} + +impl From for FmtValue { + fn from(d: Directive) -> Self { + Self(vec![Elem::Directive(d)]) + } +} + +impl From for FmtValue { + fn from(opt: Optional) -> Self { + Self(vec![Elem::Optional(opt)]) + } +} + +impl From for Vec { + fn from(value: FmtValue) -> Self { + value.0 + } +} + +impl From for Format { + fn from(value: FmtValue) -> Self { + Self { elems: value.0 } + } +} + +impl FmtValue { + // flattens a FmtValue such that it contains no nested Values. + fn flatten(self) -> Vec { + self.0 + } + + fn flatten_into(self, values: &mut Vec) { + values.extend(self.0); + } +} + +pub(crate) fn generic_op_format() -> Format { + Format { + elems: vec![Directive::new("operation_generic_format").into()], + } +} + +pub(crate) fn try_format_from_input(input: &syn::DeriveInput) -> syn::Result { + // TODO: add support for per field attributes? + + let data = match input.data { + Data::Struct(ref data) => data, + _ => { + return Err(syn::Error::new_spanned( + input, + "Type can only be derived for structs", + )) + } + }; + + let elems = match data.fields { + syn::Fields::Named(ref fields) => { + let mut elems = vec![]; + for (i, field) in fields.named.iter().enumerate() { + let ident = field.ident.as_ref().unwrap(); + if i > 0 { + elems.push(Elem::new_lit(", ")); + } + elems.push(Elem::new_lit(format!("{}=", ident))); + elems.push(Elem::new_var(ident.to_string())); + } + elems + } + syn::Fields::Unnamed(ref fields) => (0..(fields.unnamed.len())) + .map(|i| Elem::new_unnamed_var(i as usize)) + .collect::>(), + syn::Fields::Unit => vec![], + }; + Ok(Format { elems }) +} diff --git a/pliron-derive/src/irfmt/parser.rs b/pliron-derive/src/irfmt/parser.rs new file mode 100644 index 0000000..a77a8fa --- /dev/null +++ b/pliron-derive/src/irfmt/parser.rs @@ -0,0 +1,137 @@ +use combine::{ + between, choice, many, one_of, optional, + parser::{ + char::spaces, + range::{recognize, take_while1}, + repeat::escaped, + }, + position, sep_by, + stream::position::IndexPositioner, + token, Parser, +}; + +use super::{Elem, Format}; + +type Stream<'a> = combine::stream::position::Stream<&'a str, IndexPositioner>; + +pub(crate) fn parse(input: &str) -> super::Result { + let input = Stream::with_positioner(input, IndexPositioner::new()); + let (elems, _rest) = match parse_fmt_elems().parse(input) { + Ok(elems) => elems, + Err(err) => { + let msg = format!("{}", err); + return Err(msg.into()); + } + }; + Ok(Format { elems }) +} + +fn parse_fmt_elems<'a>() -> impl Parser, Output = Vec> { + many(parse_fmt_elem()) +} + +combine::parser! { + fn parse_fmt_elem['a]()(Stream<'a>) -> Elem where [] { + parse_fmt_elem_() + } +} + +fn parse_fmt_elem_<'a>() -> impl Parser, Output = Elem> { + spaces() + .with(choice((parse_lit(), parse_var(), parse_directive()))) + .skip(spaces()) +} + +fn parse_lit<'a>() -> impl Parser, Output = Elem> { + let body = recognize(escaped( + take_while1(|c| c != '`' && c != '\\'), + '\\', + one_of(r#"`nrt\"#.chars()), + )); + let lit = between(token('`'), token('`'), body); + (position(), lit).map(|(pos, s)| Elem::new_lit_at(pos, s)) +} + +fn parse_var<'a>() -> impl Parser, Output = Elem> { + let tok = token('$').with(take_while1(|c: char| c.is_alphanumeric() || c == '_')); + (position(), tok).map(|(pos, s): (_, &str)| match s.parse::() { + Ok(n) => Elem::new_unnamed_var_at(pos, n.into()), + Err(_) => Elem::new_var_at(pos, s), + }) +} + +fn parse_directive<'a>() -> impl Parser, Output = Elem> { + let name = take_while1(|c: char| c.is_alphanumeric() || c == '-' || c == '_').skip(spaces()); + let args = between(token('('), token(')'), sep_by(parse_fmt_elem(), token(','))); + (position(), name, optional(args)).map(|(pos, name, args)| { + Elem::new_directive_with_args_at(pos, name, args.unwrap_or_default()) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simple_literal() { + let input = "`lit`"; + let want = vec![Elem::new_lit_at(0, "lit")]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } + + #[test] + fn literal_with_escaped_chars() { + let input = r#"`hello\n \`world\``"#; + let want = vec![Elem::new_lit_at(0, r#"hello\n \`world\`"#)]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } + + #[test] + fn simple_variable() { + let input = "$var"; + let want = vec![Elem::new_var_at(0, "var")]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } + + #[test] + fn simple_unnamed_variable() { + let input = "$1"; + let want = vec![Elem::new_unnamed_var_at(0, 1)]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } + + #[test] + fn simple_directive() { + let input = "directive"; + let want = vec![Elem::new_directive_at(0, "directive")]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } + + #[test] + fn directive_with_empty_args() { + let input = "directive()"; + let want = vec![Elem::new_directive_at(0, "directive")]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } + + #[test] + fn directive_with_args() { + let input = "directive($arg1, other-directive)"; + let want = vec![Elem::new_directive_with_args_at( + 0, + "directive", + vec![ + Elem::new_var_at(10, "arg1"), + Elem::new_directive_at(17, "other-directive"), + ], + )]; + let got = parse(input).unwrap(); + assert_eq!(got.elems, want); + } +} diff --git a/pliron-derive/src/lib.rs b/pliron-derive/src/lib.rs index 78194ed..45d9ae8 100644 --- a/pliron-derive/src/lib.rs +++ b/pliron-derive/src/lib.rs @@ -1,5 +1,9 @@ +mod irfmt; +mod macro_attr; + mod derive_attr; mod derive_op; +mod derive_printable; mod derive_shared; mod derive_type; @@ -90,6 +94,11 @@ pub fn def_op(args: TokenStream, input: TokenStream) -> TokenStream { to_token_stream(derive_op::def_op(args, input)) } +#[proc_macro_derive(Printable, attributes(dialect, ir_kind, ir_format))] +pub fn derive_printable(input: TokenStream) -> TokenStream { + to_token_stream(derive_printable::derive(input)) +} + pub(crate) fn to_token_stream(res: syn::Result) -> TokenStream { let tokens = match res { Ok(tokens) => tokens, @@ -102,3 +111,10 @@ pub(crate) fn to_token_stream(res: syn::Result) -> Tok }; TokenStream::from(tokens) } + +// Helper derive macro to accept internal attributes that we pass to Printable, Parsable and other +// derive macros. The helper ensures that the injected attributes do not cause a compilation error if no other derive macro is used. +#[proc_macro_derive(DeriveAttribAcceptor, attributes(ir_kind))] +pub fn derive_attrib_dummy(_input: TokenStream) -> TokenStream { + TokenStream::new() +} diff --git a/pliron-derive/src/macro_attr.rs b/pliron-derive/src/macro_attr.rs new file mode 100644 index 0000000..7f827d2 --- /dev/null +++ b/pliron-derive/src/macro_attr.rs @@ -0,0 +1,135 @@ +use std::fmt; + +use crate::irfmt::Format; +use syn::{Expr, ExprLit, Lit, Result}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum IRKind { + Type, + Attribute, + Op, +} + +impl IRKind { + pub fn as_str(&self) -> &'static str { + match self { + Self::Type => "type", + Self::Attribute => "attribute", + Self::Op => "op", + } + } +} + +impl IRKind { + pub const ATTR_NAME: &'static str = "ir_kind"; + + pub fn from_syn(attr: &syn::Attribute) -> Result { + attr.meta.require_name_value()?; + let value = attrib_lit_value(attr)?.value(); + Self::try_from(value).map_err(|e| syn::Error::new_spanned(attr, e)) + } +} + +impl TryFrom for IRKind { + type Error = String; + + fn try_from(value: String) -> std::result::Result { + Self::try_from(value.as_str()) + } +} + +impl<'a> TryFrom<&'a str> for IRKind { + type Error = String; + + fn try_from(value: &'a str) -> std::result::Result { + match value { + "type" => Ok(Self::Type), + "attribute" => Ok(Self::Attribute), + "op" => Ok(Self::Op), + _ => Err(format!("unknown IR kind: {}", value)), + } + } +} + +impl fmt::Display for IRKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl quote::ToTokens for IRKind { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let value = self.as_str(); + quote::quote!( + #[ir_kind = #value] + ) + .to_tokens(tokens) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct IRFormat(pub Format); + +impl IRFormat { + pub const ATTR_NAME: &'static str = "ir_format"; + + pub fn from_syn(attr: &syn::Attribute) -> Result { + attr.meta.require_name_value()?; + let value = attrib_lit_value(attr)?.value(); + Self::try_from(value).map_err(|e| syn::Error::new_spanned(attr, e)) + } +} + +impl From for Format { + fn from(f: IRFormat) -> Self { + f.0 + } +} + +impl From for IRFormat { + fn from(f: Format) -> Self { + Self(f) + } +} + +impl TryFrom for IRFormat { + type Error = Box; + + fn try_from(value: String) -> std::result::Result { + let f = Format::parse(&value)?; + Ok(Self(f)) + } +} + +pub(crate) fn require_once( + attr_name: &str, + value: &Option, + attr: &syn::Attribute, +) -> syn::Result<()> { + if value.is_some() { + Err(syn::Error::new_spanned( + attr, + format!("{} attribute can only be applied once", attr_name), + )) + } else { + Ok(()) + } +} + +pub(crate) fn attrib_lit_value(attr: &syn::Attribute) -> syn::Result<&syn::LitStr> { + let nv = attr.meta.require_name_value()?; + let Expr::Lit(ExprLit { + lit: Lit::Str(ref lit), + .. + }) = nv.value + else { + return Err(syn::Error::new_spanned( + nv, + format!( + "expected {} attribute to be a string literal", + attr.path().get_ident().unwrap() + ), + )); + }; + Ok(lit) +}