Skip to content

Commit

Permalink
Format all type names (fixes #2324) (#2436)
Browse files Browse the repository at this point in the history
* Format all type names (fixes #2324)

* Fix references
  • Loading branch information
samolego authored Oct 30, 2024
1 parent bb9f5b1 commit ff371de
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
4 changes: 2 additions & 2 deletions crates/burn-import/src/burn/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
input_names.iter().for_each(|input| {
self.graph_input_types.push(
inputs
.get(&TensorType::format_name(input))
.get(&Type::format_name(input))
.unwrap_or_else(|| panic!("Input type not found for {input}"))
.clone(),
);
Expand All @@ -562,7 +562,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
output_names.iter().for_each(|output| {
self.graph_output_types.push(
outputs
.get(&TensorType::format_name(output))
.get(&Type::format_name(output))
.unwrap_or_else(|| panic!("Output type not found for {output}"))
.clone(),
);
Expand Down
33 changes: 18 additions & 15 deletions crates/burn-import/src/burn/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ pub enum Type {
}

impl Type {
// This is used, because types might have number literal name, which cannot be
// used as a variable name.
pub fn format_name(name: &str) -> String {
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
if name_is_number {
format!("_{}", name)
} else {
name.to_string()
}
}
pub fn name(&self) -> &Ident {
match self {
Type::Tensor(tensor) => &tensor.name,
Expand Down Expand Up @@ -107,8 +117,10 @@ impl ScalarType {
if name.as_ref().is_empty() {
panic!("Scalar of Type {:?} was passed with empty name", kind);
}

let formatted_name = Type::format_name(name.as_ref());
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
name: Ident::new(&formatted_name, Span::call_site()),
kind,
}
}
Expand Down Expand Up @@ -150,8 +162,9 @@ impl ShapeType {
if name.as_ref().is_empty() {
panic!("Shape was passed with empty name");
}
let formatted_name = Type::format_name(name.as_ref());
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
name: Ident::new(&formatted_name, Span::call_site()),
dim,
}
}
Expand All @@ -173,17 +186,6 @@ impl ShapeType {
}

impl TensorType {
// This is used, because Tensors might have number literal name, which cannot be
// used as a variable name.
pub fn format_name(name: &str) -> String {
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
if name_is_number {
format!("_{}", name)
} else {
name.to_string()
}
}

pub fn new<S: AsRef<str>>(
name: S,
dim: usize,
Expand All @@ -196,7 +198,7 @@ impl TensorType {
kind, shape
);
}
let formatted_name = Self::format_name(name.as_ref());
let formatted_name = Type::format_name(name.as_ref());
assert_ne!(
dim, 0,
"Trying to create TensorType with dim = 0 - should be a Scalar instead!"
Expand Down Expand Up @@ -277,8 +279,9 @@ impl OtherType {
tokens
);
}
let formatted_name = Type::format_name(name.as_ref());
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
name: Ident::new(&formatted_name, Span::call_site()),
ty: tokens,
}
}
Expand Down

0 comments on commit ff371de

Please sign in to comment.