-
Notifications
You must be signed in to change notification settings - Fork 0
/
Quantization.py
41 lines (30 loc) · 1.25 KB
/
Quantization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
###################
# Quantization
###################
import numpy as np
def uniform_quantization(weights, num_bits):
    """Quantize ``weights`` onto a signed ``num_bits``-bit integer grid.

    Uses an affine (min/max) mapping: the smallest value in ``weights`` maps
    to ``q_min = -2**(num_bits - 1)`` and the largest to
    ``q_max = 2**(num_bits - 1) - 1``. Values in between are scaled linearly
    and rounded to the nearest integer with ``np.round``.

    Args:
        weights: Array-like of real numbers (any shape); converted with
            ``np.asarray`` so plain lists are accepted.
        num_bits: Bit width of the signed integer grid; must be >= 1.

    Returns:
        A tuple ``(quantized_weights, scale, min_val)``:
        ``quantized_weights`` is an integer array with the same shape as
        ``weights`` and values in ``[q_min, q_max]``; ``scale`` is the real
        width of one quantization step; ``min_val`` is the real value that
        maps to ``q_min``. Reconstruct with
        ``(quantized_weights - q_min) * scale + min_val``.

    Raises:
        ValueError: If ``num_bits < 1`` or ``weights`` is empty.
    """
    if num_bits < 1:
        raise ValueError(f"num_bits must be >= 1, got {num_bits}")
    weights = np.asarray(weights)
    if weights.size == 0:
        raise ValueError("cannot quantize an empty array")

    q_min, q_max = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    min_val, max_val = np.min(weights), np.max(weights)
    scale = (max_val - min_val) / (q_max - q_min)
    if scale == 0:
        # All weights are identical. Any positive step is valid: every value
        # maps to q_min, which dequantizes back to min_val exactly. Without
        # this, the division below would produce NaN/inf codes.
        scale = 1.0

    quantized_weights = np.round((weights - min_val) / scale + q_min)
    # Floating-point error in the division can nudge an endpoint just past
    # the grid; clip so every code is representable in num_bits.
    quantized_weights = np.clip(quantized_weights, q_min, q_max)
    return quantized_weights.astype(int), scale, min_val
def dequantize_weights(quantized_weights, scale, min_val, q_min):
    """Map integer codes back to approximate real values.

    Undoes the affine map used by ``uniform_quantization``. Each code is
    shifted so that ``q_min`` becomes zero, multiplied by the step size
    ``scale``, and offset by ``min_val``.

    Args:
        quantized_weights: Integer codes (array or scalar).
        scale: Real-valued width of one quantization step.
        min_val: Real value corresponding to the code ``q_min``.
        q_min: Smallest code of the integer grid.

    Returns:
        The reconstructed values, broadcast like ``quantized_weights``.
    """
    steps_above_floor = quantized_weights - q_min
    return steps_above_floor * scale + min_val
# Demo: quantize a small 2x3 weight matrix to 8-bit signed integers, then
# reconstruct it, printing each stage so the rounding error is visible.
weights = np.array([
    [0.2, -0.5, 0.8],
    [0.6, 0.1, -0.3],
])
num_bits = 8  # width of the signed integer grid

integer_quantized_weights, scale, min_val = uniform_quantization(weights, num_bits)

# uniform_quantization does not return the bottom of its integer grid, so
# rebuild it here with the same formula it uses internally.
q_min = -2 ** (num_bits - 1)
dequantized_weights = dequantize_weights(
    integer_quantized_weights, scale, min_val, q_min
)

print("Original weights:", weights, sep="\n")
print("\nInteger quantized weights:", integer_quantized_weights, sep="\n")
print("\nDequantized weights:", dequantized_weights, sep="\n")