From 3664c6ac698950cef651438be4d7b124622ad3df Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 28 Aug 2024 10:55:10 -0400 Subject: [PATCH] Fix tensor data elem type conversion in book (#2211) --- burn-book/src/basic-workflow/data.md | 13 +++++++------ examples/guide/src/data.rs | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/burn-book/src/basic-workflow/data.md b/burn-book/src/basic-workflow/data.md index 962d6f8b3f..4e3683c219 100644 --- a/burn-book/src/basic-workflow/data.md +++ b/burn-book/src/basic-workflow/data.md @@ -68,8 +68,8 @@ impl Batcher> for MnistBatcher { fn batch(&self, items: Vec) -> MnistBatch { let images = items .iter() - .map(|item| TensorData::from(item.image)) - .map(|data| Tensor::::from_data(data.convert(), &self.device)) + .map(|item| TensorData::from(item.image).convert::()) + .map(|data| Tensor::::from_data(data, &self.device)) .map(|tensor| tensor.reshape([1, 28, 28])) // Normalize: make between [0,1] and make the mean=0 and std=1 // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example @@ -119,8 +119,8 @@ images. ```rust, ignore let images = items // take items Vec .iter() // create an iterator over it - .map(|item| TensorData::from(item.image)) // for each item, convert the image to float32 data struct - .map(|data| Tensor::::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device + .map(|item| TensorData::from(item.image).convert::()) // for each item, convert the image to float data struct + .map(|data| Tensor::::from_data(data, &self.device)) // for each data struct, create a tensor on the device .map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W] .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization .collect(); // consume the resulting iterator & collect the values into a new vector @@ -138,5 +138,6 @@ a targets tensor that contains the indexes of the correct digit class. The first the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate tensor storage information without being specific for a backend. When creating a tensor from data, we often need to convert the data precision to the current backend in use. This can be done with the -`.convert()` method. While importing the `burn::tensor::ElementConversion` trait, you can call -`.elem()` on a specific number to convert it to the current backend element type in use. +`.convert()` method (in this example, the data is converted backend's float element type +`B::FloatElem`). While importing the `burn::tensor::ElementConversion` trait, you can call `.elem()` +on a specific number to convert it to the current backend element type in use. diff --git a/examples/guide/src/data.rs b/examples/guide/src/data.rs index 78794d4372..dd73209c4a 100644 --- a/examples/guide/src/data.rs +++ b/examples/guide/src/data.rs @@ -24,8 +24,8 @@ impl Batcher> for MnistBatcher { fn batch(&self, items: Vec) -> MnistBatch { let images = items .iter() - .map(|item| TensorData::from(item.image)) - .map(|data| Tensor::::from_data(data.convert::(), &self.device)) + .map(|item| TensorData::from(item.image).convert::()) + .map(|data| Tensor::::from_data(data, &self.device)) .map(|tensor| tensor.reshape([1, 28, 28])) // normalize: make between [0,1] and make the mean = 0 and std = 1 // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example