-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelpers.py
More file actions
274 lines (225 loc) · 10 KB
/
helpers.py
File metadata and controls
274 lines (225 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import torch
import tqdm
import contextlib
import inspect
import numpy as np
import json
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoConfig, AutoTokenizer
from typing import Dict, List
from GPTQ.compressed_linear import Compressed_linear
from GPTQ.QuantizationConfig import QuantizationConfig
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)  # switch to INFO to suppress debug messages
# getLogger() returns the same object every time this module is executed
# (importlib.reload, notebook re-runs), so attach the StreamHandler only once;
# otherwise handlers stack up and every message is printed several times.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)5s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
else:
    # Keep the module-level `handler` name bound on re-execution.
    handler = logger.handlers[0]
# ───
class ExitEarlyException(Exception):
    """Carries a hooked layer's call arguments out of an aborted forward pass.

    Raised from a forward pre-hook (see ``stop_early``) so that the caller can
    catch it and read the arguments the layer was about to be called with.

    Attributes:
        args: the layer's positional arguments. Stored through
            ``BaseException.args``, which always coerces to a tuple.
        kwargs: the layer's keyword arguments.
    """

    def __init__(self, args, kwargs):
        super().__init__(args, kwargs)
        # Deliberately overwrite BaseException.args with the layer's positional
        # args so callers can read `e.args` / `e.kwargs` symmetrically.
        self.args = args
        self.kwargs = kwargs

    def __reduce__(self):
        # BaseException's default reduce rebuilds the object as cls(*self.args),
        # which does not match this __init__ signature and makes pickle/copy
        # fail. Rebuild from the two stored pieces instead.
        return (type(self), (self.args, self.kwargs))
def initialize_scales_zeroes(module: nn.Module):
    """Register uninitialised, frozen ``scales`` and ``zeros`` parameters on ``module``.

    Both parameters have shape ``(module.weight.shape[1], 1)``, live on the same
    device as ``module.weight`` and have ``requires_grad=False``. ``scales`` uses
    the weight's dtype; ``zeros`` is stored as ``torch.int8``. Contents come from
    ``torch.empty`` and are presumably filled in later (e.g. by
    ``update_parameter``).

    NOTE(review): for ``nn.Linear`` the weight is ``(out_features, in_features)``,
    so ``shape[1]`` gives one entry per *input* column; per-output-channel
    quantization would normally use ``shape[0]``. Confirm against whatever
    consumes these parameters.
    NOTE: offloaded / meta-device weights are not handled here yet.
    """
    weight = module.weight
    device = weight.device
    shape = (weight.shape[1], 1)
    logger.debug(f"Initializing scales/zeros for {module} on device={device}, shape={shape}")

    def _frozen(dtype):
        # Non-trainable placeholder allocated next to the weight.
        return nn.Parameter(torch.empty(shape, dtype=dtype, device=device), requires_grad=False)

    scales = _frozen(weight.dtype)
    zeros = _frozen(torch.int8)  # int8 storage; bit width is 4 for now
    module.register_parameter("scales", scales)
    module.register_parameter("zeros", zeros)
    logger.info(f"Registered `scales` and `zeros` on {module}")
def move_inputs_to_device(inputs, device):
    """Recursively move every tensor found in ``inputs`` to ``device``.

    Supported values:
        * ``None``, ``bool``, ``int``, ``float``, ``str``: returned unchanged.
        * ``torch.Tensor``: returned as ``inputs.to(device)``.
        * mappings (``dict`` and other ``collections.abc.Mapping`` types): a
          plain ``dict`` with every value converted.
        * ``list``: a new list with every item converted.
        * ``tuple``: a new tuple with every item converted. Namedtuples keep
          their type.

    Raises:
        ValueError: for any other type.
    """
    from collections.abc import Mapping

    if inputs is None or isinstance(inputs, (bool, int, float, str)):
        return inputs
    if isinstance(inputs, torch.Tensor):
        return inputs.to(device)
    if isinstance(inputs, Mapping):
        return {key: move_inputs_to_device(value, device) for key, value in inputs.items()}
    if isinstance(inputs, list):
        return [move_inputs_to_device(item, device) for item in inputs]
    if isinstance(inputs, tuple):
        moved = [move_inputs_to_device(item, device) for item in inputs]
        # Namedtuples take their fields positionally; plain tuples take an iterable.
        if hasattr(inputs, "_fields"):
            return type(inputs)(*moved)
        return tuple(moved)
    # No stdout print here: the exception message carries the offending type.
    raise ValueError(
        f"Unsupported input type: {type(inputs)}. "
        "Expected None, a scalar, a str, a Tensor, or a dict/list/tuple of those."
    )
@contextlib.contextmanager
def disable_cache(model):
    """Temporarily set ``model.config.use_cache = False``.

    The previous value is restored on exit, including when the ``with`` body
    raises. If ``model`` has no ``config``, or the config has no ``use_cache``
    attribute, the context manager does nothing, but the body still runs.
    """
    config = getattr(model, "config", None)
    if config is None or not hasattr(config, "use_cache"):
        # Nothing to toggle, but a @contextmanager generator must still yield.
        yield
        return
    previous = config.use_cache
    config.use_cache = False
    try:
        yield
    finally:
        config.use_cache = previous
def find_blocks(model: nn.Module):
    """Return the submodules of ``model`` whose class is a "no-split" block.

    The block class names come from ``model._get_no_split_modules("auto")``.
    That is a private method, presumably Hugging Face ``PreTrainedModel``'s, so
    ``model`` is expected to be a transformers model. Matching is done on the
    class name. Results are in ``model.modules()`` traversal order.
    """
    # A set gives O(1) membership tests while walking every module.
    block_class_names = set(model._get_no_split_modules("auto"))
    return [
        module
        for module in model.modules()
        if module.__class__.__name__ in block_class_names
    ]
def update_parameter(module: nn.Module, name: str, value: torch.Tensor):
    """Copy ``value`` in place into the existing tensor attribute ``module.<name>``.

    ``value`` is cast to the parameter's dtype, then copied into
    ``param.data``. The parameter object itself is kept, so existing references
    to it stay valid.

    Raises:
        AttributeError: if ``module`` has no attribute ``name``.
        AssertionError: if ``value.shape`` differs from the parameter's shape.
    """
    if not hasattr(module, name):
        logger.error(f"Module {module} has no parameter named `{name}`")
        raise AttributeError(f"{module} missing parameter `{name}`")
    param = getattr(module, name)
    if value.shape != param.shape:
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, after which copy_() would silently *broadcast* a
        # mismatched tensor into the parameter. AssertionError is kept so the
        # exception type callers may already expect is unchanged.
        raise AssertionError(f"Shape mismatch updating {name}: {value.shape} vs {param.shape}")
    logger.debug(f"Updating parameter `{name}` of {module} with shape={param.shape}")
    # Cast explicitly to the parameter's dtype (added to prevent a runtime error).
    param.data.copy_(value.to(param.dtype))
    logger.info(f"Parameter `{name}` updated on {module}")
def add_quantization_config(model: nn.Module, config: QuantizationConfig):
    """Attach quantization settings to every targeted layer of ``model``.

    A submodule is configured when its class name is in ``config.targets`` and
    its qualified name is not in ``config.ignore``. Each configured layer gets:

        * ``quantization_scheme`` = ``config.scheme``
        * ``quantization_args``   = ``config.quantization_args``
        * empty ``scales`` / ``zeros`` parameters (``initialize_scales_zeroes``)

    Afterwards ``model.status`` and ``model.format`` are copied from ``config``.

    The configuration is expected to look like::

        quantization_config: {
            group: {
                targets: ["Linear"],
                weights: {num_bits: 4, type: "int"},
            },
            format: "pack",
            scheme: "gptq",
            ignore: ["lm_head"],
            status: "unquantized",
        }

    Layers are not swapped for ``Compressed_linear`` here; that happens at save
    time.
    TODO: move the target/ignore test into a dedicated helper.
    """
    skipped_names = config.ignore
    target_types = config.targets
    logger.info("Adding quantization config to model")
    for module_name, submodule in model.named_modules():
        type_name = submodule.__class__.__name__
        if module_name in skipped_names or type_name not in target_types:
            continue
        logger.debug(f" Configuring `{module_name}` ({type_name})")
        submodule.quantization_scheme = config.scheme
        submodule.quantization_args = config.quantization_args
        initialize_scales_zeroes(submodule)
    model.status = config.status
    model.format = config.format
    logger.info(f"Model quantization status set to `{model.status}`")
def extract_config(config_file_path):
    """Load a JSON config file and return its ``"quantization_config"`` entry.

    Returns:
        The value stored under ``"quantization_config"``, or ``None`` if the key
        is absent.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the top-level JSON value is not an object.
    """
    logger.info(f"Loading quantization config from `{config_file_path}`")
    with open(config_file_path, "r", encoding="utf-8") as config_file:
        # json.load parses a file object. The previous json.dump(config_file)
        # is the *serializer* and raised TypeError (missing target argument).
        data = json.load(config_file)
    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object at top level of `{config_file_path}`")
    return data.get("quantization_config", None)
def find_layers(model: torch.nn.Module):
    """Return every submodule of ``model`` whose ``quantization_args`` is not None.

    In this file that attribute is set by ``add_quantization_config``, so the
    result is typically the layers it configured, in ``model.modules()`` order.
    """
    quantizable = [
        submodule
        for submodule in model.modules()
        if getattr(submodule, "quantization_args", None) is not None
    ]
    logger.info(f"Discovered {len(quantizable)} quantizable layers")
    return quantizable
@contextlib.contextmanager
def stop_early(module):
    """Make any call of ``module`` abort with ``ExitEarlyException`` while active.

    A forward pre-hook (registered with ``with_kwargs=True``) raises
    ``ExitEarlyException(args, kwargs)`` before ``module.forward`` runs, so the
    caller can catch it and read the intercepted arguments. The hook is removed
    on exit, even if the body raises.
    """
    def _abort_on_entry(_hooked, positional, keyword):
        raise ExitEarlyException(positional, keyword)

    removable = module.register_forward_pre_hook(_abort_on_entry, with_kwargs=True)
    logger.debug(f"Installed early-stop hook on {module}")
    try:
        yield
    finally:
        removable.remove()
        logger.debug(f"Removed early-stop hook from {module}")
def packs_args(module: nn.Module, args, kwargs=None):
    """Pack a call's positional and keyword arguments into one name -> value dict.

    Positional ``args`` are paired, in order, with the parameter names of
    ``module.forward``. ``self`` is not among them because ``module.forward`` is
    bound. ``kwargs`` is then merged in and wins on name clashes. The result is
    intended to be replayed as ``module(**packed)``.

    Args:
        module: module whose ``forward`` signature supplies the names.
        args: positional arguments, e.g. those captured by a forward pre-hook.
        kwargs: optional keyword arguments to merge in.

    Returns:
        dict mapping parameter name to value.
    """
    args = tuple(args)
    parameter_names = list(inspect.signature(module.forward).parameters)
    if len(args) > len(parameter_names):
        # zip() below silently drops the surplus values; make that visible.
        logger.warning(
            f"{len(args) - len(parameter_names)} positional argument(s) for "
            f"{type(module).__name__} have no matching parameter name and were dropped"
        )
    layer_args = dict(zip(parameter_names, args))
    if kwargs is not None:
        layer_args.update(kwargs)
    # Log the keys only: formatting the dict renders every tensor value, which is
    # slow and floods the log while this logger is at DEBUG level.
    logger.debug(f"Packed args keys for {module}: {list(layer_args)}")
    return layer_args
@contextlib.contextmanager
def load_to_device(module: nn.Module, device=None):
    """Move ``module`` to a compute device for the ``with`` block, then back to CPU.

    Args:
        module: module to move (in place, via ``module.to``).
        device: target device. Defaults to ``"cuda"`` when available, else
            ``"cpu"``.

    On exit the module is always moved to ``"cpu"``, also when the body raises,
    so an exception cannot leave a large model resident on the GPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.debug(f"Loading module `{module}` to device `{device}`")
    module.to(device)
    try:
        yield
    finally:
        logger.debug(f"Unloading module `{module}` from device `{device}`")
        module.to("cpu")
@contextlib.contextmanager
def eval_context(module: torch.nn.Module):
    """Run the ``with`` body with ``module`` in eval mode.

    On exit, ``module.train(<previous module.training>)`` is called, even if the
    body raises. Note that this applies the top-level flag to every submodule.
    """
    was_training = module.training
    try:
        module.eval()  # same as module.train(False)
        yield
    finally:
        module.train(was_training)
def findIntermediate(module: torch.nn.Module, layer: torch.nn.Module, dataloader):
    """Capture, for every calibration batch, the arguments that reach ``layer``.

    Each batch (a dict of keyword arguments for ``module``) has its values moved
    to the device ``module`` was loaded onto. It is then fed to
    ``module(**batch)``. The forward pre-hook installed by ``stop_early`` raises
    ``ExitEarlyException`` as soon as ``layer`` is entered, so the rest of the
    forward pass is skipped. The intercepted positional and keyword arguments
    are packed by parameter name with ``packs_args``.

    The pass runs under ``torch.no_grad()`` and in eval mode (so e.g. dropout
    cannot perturb the captured activations). ``module`` is moved by
    ``load_to_device`` and ``use_cache`` is disabled for the duration.

    Args:
        module: full model to run.
        layer: submodule of ``module`` whose inputs should be captured.
        dataloader: iterable of dict batches.

    Returns:
        One packed-argument dict per batch, in dataloader order.

    Raises:
        ValueError: if a batch completes the forward pass without reaching ``layer``.
    """
    logger.info(f"Capturing inputs to layer {layer}")
    intermediate_input = []
    with torch.no_grad(), eval_context(module), stop_early(layer), load_to_device(module), disable_cache(module):
        # TODO: revisit how the inputs are read / placed.
        device = next(module.parameters()).device
        for batch in tqdm.tqdm(dataloader):
            # Per-value move tolerates non-tensor entries (None, scalars) that
            # a bare `.to(device)` would crash on.
            batch = {key: move_inputs_to_device(value, device) for key, value in batch.items()}
            try:
                module(**batch)
            except ExitEarlyException as exc:
                intermediate_input.append(packs_args(layer, exc.args, exc.kwargs))
            else:
                raise ValueError("Early stop exception not raised for intermediate input")
    logger.info(f"Captured {len(intermediate_input)} intermediate inputs")
    return intermediate_input
def update_intermediate(module: nn.Module, intermediate_input: List, batch_index, output, inputs):
    """Store in ``intermediate_input[batch_index]`` the inputs for the next block.

    The next block is called with the same keyword arguments ``module`` received
    (``inputs``, a packed dict as built by ``packs_args``). The only change is
    that the first parameter of ``module.forward`` is replaced by ``module``'s
    output. That parameter is presumably the hidden states, as in Hugging Face
    decoder layers; confirm for other architectures.

    Args:
        module: block that produced ``output``.
        intermediate_input: per-batch list of packed inputs, updated in place.
        batch_index: index of the entry to replace.
        output: ``module``'s return value. This is either a tensor, or a
            tuple/list whose first element is the new hidden states.
        inputs: packed dict ``module`` was called with for this batch.

    Raises:
        ValueError: if ``module.forward`` takes no parameters.
    """
    # Previously `packs_args(module, output)` zipped *every* output element onto
    # the forward signature. That had three problems:
    #   * a second output (e.g. attention weights / cache) landed in the second
    #     parameter;
    #   * a bare-tensor output was split along its first dimension;
    #   * every input except `position_embeddings` (attention mask, position
    #     ids, ...) was dropped.
    parameter_names = list(inspect.signature(module.forward).parameters)
    if not parameter_names:
        raise ValueError(f"{type(module).__name__}.forward takes no parameters; cannot route its output")
    new_hidden = output[0] if isinstance(output, (tuple, list)) else output
    layer_args = dict(inputs)
    layer_args[parameter_names[0]] = new_hidden
    intermediate_input[batch_index] = layer_args
    logger.debug(f"Updated intermediate for batch {batch_index} at {module}")
def compress_and_save(model, config_file_path, quantization_config, output_dir = None):
    """Swap GPTQ-quantized layers for ``Compressed_linear`` wrappers, update the
    JSON config and optionally save the model.

    Every submodule whose ``quantization_scheme`` is ``"gptq"`` is replaced *in
    its parent* by ``Compressed_linear.from_linear(module, model.format,
    module.quantization_args)``. The JSON config at ``config_file_path`` is then
    rewritten via ``update_config``. If ``output_dir`` is truthy,
    ``model.save_pretrained(output_dir)`` is called.
    """
    logger.info("Applying compressed_linear wrappers for GPTQ modules")
    # Snapshot the targets first: replacing children while named_modules() is
    # still walking the tree would mutate the traversal.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if getattr(module, 'quantization_scheme', None) == "gptq"
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        if not child_name:
            # name == "" is the root model itself; there is no parent to swap it into.
            logger.warning("Root module is marked for GPTQ; it cannot be replaced in place and is skipped")
            continue
        logger.debug(f" Wrapping module `{name}`")
        compressed_linear = Compressed_linear.from_linear(module, model.format, module.quantization_args)
        # The wrapper must be attached to the *parent* under the child's short
        # name. setattr(module, name, ...) would only hang it off the old layer
        # under a dotted name, leaving the original layer in use.
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, compressed_linear)
    update_config(model, config_file_path, quantization_config)
    if output_dir:
        logger.info(f"Saving model to `{output_dir}`")
        model.save_pretrained(output_dir)
def update_config(model: nn.Module, config_file_path, quantization_config):
    """Write a quantization section into the JSON config file at ``config_file_path``.

    ``quantization_config`` may be either the wrapper form
    ``{"quantization_config": {...}}`` (original contract) or the bare section
    itself, which is the shape ``extract_config`` returns. The existing file is
    read, its ``"quantization_config"`` key replaced, and the result written
    back with ``indent=2``. ``model`` is currently unused.

    The new JSON is fully serialized before the file is reopened for writing.
    A serialization error therefore cannot leave a truncated config behind.
    """
    logger.info(f"Updating JSON config at `{config_file_path}`")
    with open(config_file_path, 'r', encoding="utf-8") as config_file:
        config_data = json.load(config_file)
    try:
        new_section = quantization_config["quantization_config"]
    except KeyError:
        # Bare section (no wrapper key): use it as-is.
        new_section = quantization_config
    config_data["quantization_config"] = new_section
    logger.debug("New config data written into memory")
    serialized = json.dumps(config_data, indent=2)
    with open(config_file_path, "w", encoding="utf-8") as config_file:
        config_file.write(serialized)
    logger.info("JSON config file updated")