-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpack_quantized.py
More file actions
49 lines (36 loc) · 1.55 KB
/
pack_quantized.py
File metadata and controls
49 lines (36 loc) · 1.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import numpy as np
import torch
import numpy as np
def pack_int32(weight, quantization_args):
    """Pack low-bit integer values of a 2-D tensor into int32 words along columns.

    Every ``n = 32 // num_bits`` consecutive columns of ``weight`` are packed
    into one int32 column: input column ``c * n + k`` occupies bits
    ``[k * num_bits, (k + 1) * num_bits)`` of packed column ``c`` (the first
    column of each group lands in the least-significant bits).

    Only the low ``num_bits`` bits of each value are stored, so negative values
    are kept in two's-complement form and never spill into neighbouring fields.

    Args:
        weight: 2-D tensor of integer-valued weights (expected to be quantized
            already, e.g. to int8). Its column count must be divisible by
            ``32 // num_bits``.
        quantization_args: object exposing an integer ``num_bits`` attribute
            that divides 32.

    Returns:
        A CPU ``torch.int32`` tensor of shape
        ``(weight.shape[0], weight.shape[1] // (32 // num_bits))``.

    Raises:
        ValueError: if ``num_bits`` does not divide 32, ``weight`` is not 2-D,
            or its column count is not a multiple of ``32 // num_bits``.
    """
    num_bits = quantization_args.num_bits
    if num_bits <= 0 or 32 % num_bits != 0:
        raise ValueError(f"num_bits must divide 32, got {num_bits}")
    num_elements = 32 // num_bits
    mask = (1 << num_bits) - 1

    values = weight.detach().cpu().numpy()
    if values.ndim != 2:
        raise ValueError(f"expected a 2-D weight, got shape {tuple(values.shape)}")
    rows, cols = values.shape
    if cols % num_elements != 0:
        raise ValueError(
            f"number of columns ({cols}) must be divisible by {num_elements} "
            f"for {num_bits}-bit packing"
        )

    # Go through int64 so negative inputs are reduced to their low num_bits
    # bits instead of wrapping to huge uint32 values that clobber other fields.
    fields = (values.astype(np.int64) & mask).astype(np.uint32)
    # (rows, packed_cols, num_elements): one group of fields per packed word.
    fields = fields.reshape(rows, cols // num_elements, num_elements)
    shifts = np.arange(num_elements, dtype=np.uint32) * np.uint32(num_bits)
    packed = np.bitwise_or.reduce(fields << shifts, axis=-1)
    # Reinterpret the uint32 bit pattern as int32 without changing any bits.
    packed = np.ascontiguousarray(packed, dtype=np.uint32).view(np.int32)
    return torch.from_numpy(packed)
def unpack_int32(packed_weight, quantization_args):
    """Expand each int32 word of a 2-D tensor into ``32 // num_bits`` fields.

    With ``n = 32 // num_bits``, packed column ``c`` expands into output
    columns ``c * n .. c * n + n - 1``; output column ``c * n + k`` holds bits
    ``[k * num_bits, (k + 1) * num_bits)`` of the word (field 0 is the
    least-significant).

    Args:
        packed_weight: 2-D integer tensor of packed words (converted to int32).
        quantization_args: object exposing an integer ``num_bits`` attribute
            that divides 32.

    Returns:
        A ``torch.int8`` tensor of shape
        ``(packed_weight.shape[0], packed_weight.shape[1] * (32 // num_bits))``
        on ``packed_weight.device``. Each entry is the unsigned field value in
        ``[0, 2**num_bits - 1]`` cast to int8. For ``num_bits == 8``, values
        >= 128 therefore wrap to negative. For ``num_bits > 8``, only the low
        8 bits survive the cast.

    Raises:
        ValueError: if ``num_bits`` does not divide 32 or the input is not 2-D.
    """
    num_bits = quantization_args.num_bits
    if num_bits <= 0 or 32 % num_bits != 0:
        raise ValueError(f"num_bits must divide 32, got {num_bits}")
    num_elements = 32 // num_bits
    # For num_bits == 32 the mask 2**32 - 1 does not fit in int32; -1 has all
    # 32 bits set and is the equivalent (no-op) mask.
    mask = (1 << num_bits) - 1 if num_bits < 32 else -1

    words = packed_weight.to(torch.int32)
    if words.dim() != 2:
        raise ValueError(f"expected a 2-D packed tensor, got shape {tuple(words.shape)}")
    rows, cols = words.shape
    shifts = torch.arange(num_elements, dtype=torch.int32, device=words.device) * num_bits
    # (rows, cols, 1) >> (num_elements,) -> (rows, cols, num_elements).
    # The shift on signed int32 may copy the sign bit downwards; the mask
    # discards everything above the field.
    fields = (words.unsqueeze(-1) >> shifts) & mask
    unpacked = fields.reshape(rows, cols * num_elements)
    return unpacked.to(torch.int8)