Refactor jit quantized tensor representation (#2604)
* Remove q_shape to use TensorMetadata instead

* Fix spirv bool type

* Refactor burn-jit quantized tensor representation

* Remove dead comment

* Update cubecl rev

* Remove dead code

* Fix comments

* Fix clippy

* Remove unnecessary loop for input line size of 1

* Remove quantized kind remnant

* Remove no longer valid comment

* Get qparams values as tuple

* Move data into async context

* Fix ReprBackend handle type for JitBackend and Fusion

* Fusion client read takes ownership

* Fix clippy
laggui authored Dec 13, 2024
1 parent 834ff44 commit 0dd228c
Showing 42 changed files with 587 additions and 1,050 deletions.
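
The first two commit bullets carry the core idea: the dedicated q_shape op is gone, and callers read the shape through the TensorMetadata trait that every tensor primitive, quantized ones included, already implements. A minimal sketch of that usage, assuming TensorMetadata and Shape are re-exported from the burn_tensor crate root as in the imports further down:

use burn_tensor::{Shape, TensorMetadata};

// Generic code no longer needs a backend-specific q_shape op: any tensor
// primitive implementing TensorMetadata exposes its shape (and dtype) directly.
fn shape_of<T: TensorMetadata>(tensor: &T) -> Shape {
    tensor.shape()
}

This is why the per-backend q_shape implementations in the diffs below are plain deletions rather than replacements.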
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/qtensor.rs
@@ -33,10 +33,6 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
todo!()
}

fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
B::q_shape(tensor)
}

fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
B::q_device(tensor)
}
4 changes: 0 additions & 4 deletions crates/burn-candle/src/ops/qtensor.rs
@@ -29,10 +29,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
unimplemented!()
}

fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
super::base::shape(&tensor.qtensor)
}

fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
super::base::device(&tensor.qtensor)
}
4 changes: 0 additions & 4 deletions crates/burn-candle/src/tensor.rs
@@ -70,10 +70,6 @@ impl QTensorPrimitive for CandleQTensor {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}

fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}

impl TensorMetadata for CandleQTensor {
19 changes: 7 additions & 12 deletions crates/burn-fusion/src/backend.rs
@@ -1,10 +1,8 @@
use crate::{
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, QFusionTensor,
};
use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor};
use burn_tensor::{
backend::{Backend, DeviceOps},
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle},
repr::{OperationDescription, ReprBackend, TensorHandle},
Device, Element,
};
use serde::{de::DeserializeOwned, Serialize};
@@ -37,7 +35,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {

type BoolElem = B::BoolElem;

type QuantizedTensorPrimitive = QFusionTensor<B::FusionRuntime>;
type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;

type QuantizedEncoding = B::QuantizedEncoding;

@@ -184,11 +182,8 @@ impl<B: FusionBackend> ReprBackend for Fusion<B> {
handle.handle
}

fn quantized_tensor(
_handles: QuantizedKind<TensorHandle<Self::Handle>>,
_scheme: burn_tensor::quantization::QuantizationScheme,
) -> QuantizedTensor<Self> {
todo!() // not as simple
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
handle.handle
}

fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
@@ -203,7 +198,7 @@ impl<B: FusionBackend> ReprBackend for Fusion<B> {
tensor
}

fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
todo!() // not as simple
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
tensor
}
}
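
The Fusion changes hinge on representing a quantized tensor as a single FusionTensor rather than a bundle of per-component handles, which is why quantized_tensor and quantized_tensor_handle above collapse to the same identity mapping already used for the other tensor kinds. A rough before/after sketch of the handle layout; the field names of the removed QuantizedKind are assumptions, not taken from this diff:

// Before: one handle per component of the quantized representation
// (offset only present for affine quantization schemes).
struct QuantizedHandlesOld<H> {
    tensor: H,
    scale: H,
    offset: Option<H>,
}

// After: a quantized tensor maps to a single handle like any other tensor
// kind, so ReprBackend needs no quantization-specific plumbing.
type QuantizedHandleNew<H> = H;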
22 changes: 11 additions & 11 deletions crates/burn-fusion/src/client/base.rs
@@ -2,10 +2,10 @@ use std::future::Future;

use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor,
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
repr::{OperationDescription, TensorDescription, TensorId},
DType, TensorData,
};

@@ -36,33 +36,33 @@
) -> FusionTensor<R>;
/// Read the values contained by a float tensor.
fn read_tensor_float<B>(
&self,
self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by an int tensor.
fn read_tensor_int<B>(
&self,
self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a bool tensor.
fn read_tensor_bool<B>(
&self,
self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a quantized tensor.
fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
self,
tensor: TensorDescription,
streams: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
@@ -108,10 +108,10 @@ where
/// Change the client of the given quantized tensor.
fn change_client_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
tensor: TensorDescription,
client: Self,
streams: Vec<StreamId>,
) -> QFusionTensor<R>
stream: StreamId,
) -> FusionTensor<R>
where
B: FusionBackend<FusionRuntime = R>;
/// Drop the tensor with the given [tensor id](TensorId).
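
Two changes run through the client trait above: the read methods now consume self, and the quantized read takes a plain TensorDescription with a single StreamId. A minimal caller-side sketch of the ownership change, written as if inside burn-fusion; the module paths and the extra Clone bound are assumptions. Because the clone is cheap and the read consumes it, the returned future is 'static and Send and can be moved into an async context without borrowing the caller (see the "Move data into async context" bullet above).

use crate::{client::FusionClient, stream::StreamId, FusionBackend, FusionRuntime};
use burn_tensor::{repr::TensorDescription, TensorData};

// The clone is consumed by the read; the original client stays usable and the
// returned future owns everything it needs.
async fn read_float_owned<R, C, B>(client: &C, desc: TensorDescription) -> TensorData
where
    R: FusionRuntime,
    C: FusionClient<R> + Clone,
    B: FusionBackend<FusionRuntime = R>,
{
    client
        .clone()
        .read_tensor_float::<B>(desc, StreamId::current())
        .await
}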
62 changes: 15 additions & 47 deletions crates/burn-fusion/src/client/mutex.rs
@@ -1,11 +1,10 @@
use super::FusionClient;
use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime,
FusionServer, FusionTensor, QFusionTensor,
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
repr::{OperationDescription, TensorDescription, TensorId},
DType,
};
use spin::Mutex;
@@ -80,7 +79,7 @@ where
}

fn read_tensor_float<B>(
&self,
self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = burn_tensor::TensorData> + 'static
@@ -92,7 +91,7 @@ where
}

fn read_tensor_int<B>(
&self,
self,
tensor: TensorDescription,
id: StreamId,
) -> impl Future<Output = burn_tensor::TensorData> + 'static
@@ -103,7 +102,7 @@ where
}

fn read_tensor_bool<B>(
&self,
self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = burn_tensor::TensorData> + 'static
@@ -114,14 +113,14 @@ where
}

fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_quantized::<B>(tensor, streams)
self.server.lock().read_quantized::<B>(tensor, stream)
}

fn change_client_float<B>(
@@ -190,55 +189,24 @@ where

fn change_client_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
tensor: TensorDescription,
client: Self,
streams: Vec<StreamId>,
) -> QFusionTensor<R>
stream: StreamId,
) -> FusionTensor<R>
where
B: FusionBackend<FusionRuntime = R>,
{
let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
for stream in streams {
server_current.drain_stream(stream);
}
server_current.drain_stream(stream);

let mut ids =
let id =
server_current.change_server_quantized::<B>(&tensor, &client.device, &mut server_other);

core::mem::drop(server_other);
core::mem::drop(server_current);

// NOTE: the expected order is known [qtensor, scale, <offset>]
let offset = tensor.qparams.offset.map(|desc| {
FusionTensor::new(
ids.pop().unwrap(),
desc.shape,
desc.dtype,
client.clone(),
StreamId::current(),
)
});
let scale = FusionTensor::new(
ids.pop().unwrap(),
tensor.qparams.scale.shape,
tensor.qparams.scale.dtype,
client.clone(),
StreamId::current(),
);
let qtensor = FusionTensor::new(
ids.pop().unwrap(),
tensor.tensor.shape,
tensor.tensor.dtype,
client,
StreamId::current(),
);

QFusionTensor {
qtensor,
scheme: tensor.scheme,
qparams: FusionQuantizationParameters { scale, offset },
}
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
}

fn register_orphan(&self, id: &TensorId) {