mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Do not initialize scratch space when there are no ExLlamaV2 layers (#2015)
# What does this PR do?

Do not attempt to allocate ExLlamaV2 scratch buffers when there are no ExLlamaV2 layers. Avoids a crash in warmup for models that cannot use exllama when ExLlamaV2 is installed.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
353a9669ba
commit
cdd120ac02
@ -1,10 +1,15 @@
|
|||||||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from text_generation_server.layers.exl2 import Exl2Weight
|
||||||
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -15,6 +20,15 @@ except ImportError:
|
|||||||
none_tensor = torch.empty((1, 1), device="meta")
|
none_tensor = torch.empty((1, 1), device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ExtraTensors:
|
||||||
|
"""Additional generated quantizer tensors."""
|
||||||
|
|
||||||
|
q_group_map: Optional[torch.Tensor] = None
|
||||||
|
q_invperm: Optional[torch.Tensor] = None
|
||||||
|
q_perm: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||||
"""Matrix multiplication, returns x @ q4"""
|
"""Matrix multiplication, returns x @ q4"""
|
||||||
output_shape = x.shape[:-1] + (q4_width,)
|
output_shape = x.shape[:-1] + (q4_width,)
|
||||||
@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
|||||||
return output.view(output_shape)
|
return output.view(output_shape)
|
||||||
|
|
||||||
|
|
||||||
# Group map needed for irregular group sizes
|
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
|
||||||
|
|
||||||
|
|
||||||
def make_group_map(q_groups, num_qrows):
|
|
||||||
|
|
||||||
gr = q_groups.tolist()
|
gr = q_groups.tolist()
|
||||||
group_map = []
|
group_map = []
|
||||||
num_groups = len(gr) // 2
|
num_groups = len(gr) // 2
|
||||||
@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows):
|
|||||||
# Create Q matrix
|
# Create Q matrix
|
||||||
|
|
||||||
|
|
||||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
def ext_make_q_matrix(
|
||||||
|
w: Exl2Weight | GPTQWeight,
|
||||||
|
extra: _ExtraTensors,
|
||||||
|
temp_dq,
|
||||||
|
key: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create Q matrix
|
Create Q matrix
|
||||||
"""
|
"""
|
||||||
# EXL2
|
# EXL2
|
||||||
# won't work as the moment because the tensors are not the same.
|
if isinstance(w, Exl2Weight):
|
||||||
if "q_weight" in w:
|
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
|
||||||
w["q_scale_max"] /= 256
|
extra.q_perm = torch.argsort(w.q_invperm).short()
|
||||||
w["q_perm"] = w["q_perm"].short()
|
|
||||||
w["q_invperm"] = w["q_invperm"].short()
|
|
||||||
|
|
||||||
if "q_group_map" not in w:
|
|
||||||
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
|
|
||||||
|
|
||||||
return make_q_matrix(
|
return make_q_matrix(
|
||||||
w["q_weight"],
|
w.q_weight,
|
||||||
w["q_perm"],
|
extra.q_perm,
|
||||||
w["q_invperm"],
|
w.q_invperm,
|
||||||
w["q_scale"],
|
w.q_scale,
|
||||||
w["q_scale_max"],
|
w.q_scale_max,
|
||||||
w["q_groups"],
|
w.q_groups,
|
||||||
w["q_group_map"],
|
extra.q_group_map,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
temp_dq,
|
temp_dq,
|
||||||
)
|
)
|
||||||
# GPTQ
|
# GPTQ
|
||||||
elif "qweight" in w:
|
elif isinstance(w, GPTQWeight):
|
||||||
if w["scales"].dtype == torch.float:
|
if w.scales.dtype == torch.float:
|
||||||
w["scales"] = w["scales"].half()
|
w.scales = w.scales.half()
|
||||||
|
|
||||||
# GPTQ with g_idx (act_order)
|
# GPTQ with g_idx (act_order)
|
||||||
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
|
if w.g_idx is not None and not (w.g_idx == 0).all().item():
|
||||||
w["q_perm"] = torch.empty(
|
extra.q_perm = torch.empty(
|
||||||
(w["qweight"].shape[0] * 8,),
|
(w.qweight.shape[0] * 8,),
|
||||||
dtype=torch.short,
|
dtype=torch.short,
|
||||||
device=w["qweight"].device,
|
device=w.qweight.device,
|
||||||
)
|
)
|
||||||
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
extra.q_invperm = torch.empty_like(extra.q_perm)
|
||||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||||
return make_q_matrix(
|
return make_q_matrix(
|
||||||
w["qweight"],
|
w.qweight,
|
||||||
w["q_perm"],
|
extra.q_perm,
|
||||||
w["q_invperm"],
|
extra.q_invperm,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
w["qzeros"],
|
w.qzeros,
|
||||||
w["scales"],
|
w.scales,
|
||||||
w["g_idx"].cpu(),
|
w.g_idx.cpu(),
|
||||||
temp_dq,
|
temp_dq,
|
||||||
)
|
)
|
||||||
# GPTQ without g_idx
|
# GPTQ without g_idx
|
||||||
else:
|
else:
|
||||||
return make_q_matrix(
|
return make_q_matrix(
|
||||||
w["qweight"],
|
w.qweight,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
w["qzeros"],
|
w.qzeros,
|
||||||
w["scales"],
|
w.scales,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
temp_dq,
|
temp_dq,
|
||||||
)
|
)
|
||||||
@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
|||||||
|
|
||||||
|
|
||||||
DEVICE = None
|
DEVICE = None
|
||||||
FIXED_BYTES = 0
|
|
||||||
LAYERS = []
|
LAYERS = []
|
||||||
|
|
||||||
|
|
||||||
@ -134,8 +143,19 @@ def set_device(device):
|
|||||||
|
|
||||||
|
|
||||||
def create_exllama_buffers(max_total_tokens: int):
|
def create_exllama_buffers(max_total_tokens: int):
|
||||||
global FIXED_BYTES, LAYERS, DEVICE
|
global LAYERS, DEVICE
|
||||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
|
|
||||||
|
# No need to initialize scratch space if there are no layers
|
||||||
|
# that use ExLLamav2.
|
||||||
|
if len(LAYERS) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find the size of the scratch space.
|
||||||
|
scratch_bytes = max(
|
||||||
|
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
|
||||||
|
for layer in LAYERS
|
||||||
|
)
|
||||||
|
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
|
||||||
|
|
||||||
for layer in LAYERS:
|
for layer in LAYERS:
|
||||||
layer.post_init(temp_dq)
|
layer.post_init(temp_dq)
|
||||||
@ -146,49 +166,48 @@ class QuantLinear(nn.Module):
|
|||||||
|
|
||||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||||
|
|
||||||
# def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
def __init__(
|
||||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
self,
|
||||||
|
weight: Exl2Weight | GPTQWeight,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if bits != 4:
|
|
||||||
raise ValueError(
|
|
||||||
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
|
|
||||||
)
|
|
||||||
self.q_handle = None
|
self.q_handle = None
|
||||||
self.q_tensors = None
|
self.q_tensors = weight
|
||||||
self.bits = bits
|
self.extra_tensors = _ExtraTensors()
|
||||||
self.maxq = 2**self.bits - 1
|
|
||||||
self.infeatures = qweight.shape[0] // self.bits * 32
|
if isinstance(weight, Exl2Weight):
|
||||||
self.outfeatures = qweight.shape[1]
|
self.infeatures = weight.q_invperm.shape[0]
|
||||||
|
self.outfeatures = weight.q_weight.shape[1]
|
||||||
|
elif isinstance(weight, GPTQWeight):
|
||||||
|
if weight.bits != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
|
||||||
|
self.outfeatures = weight.qweight.shape[1]
|
||||||
|
|
||||||
self.padding = -self.outfeatures % 32
|
self.padding = -self.outfeatures % 32
|
||||||
self.outfeatures = self.outfeatures + self.padding
|
self.outfeatures = self.outfeatures + self.padding
|
||||||
|
|
||||||
self.device = qweight.device
|
self.device = weight.device
|
||||||
self.qweight = qweight
|
|
||||||
self.qzeros = qzeros
|
|
||||||
self.scales = scales
|
|
||||||
self.g_idx = g_idx
|
|
||||||
self.bias = bias if bias is not None else None
|
self.bias = bias if bias is not None else None
|
||||||
self.group_size = groupsize
|
|
||||||
|
|
||||||
global FIXED_BYTES, LAYERS
|
global LAYERS
|
||||||
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
|
||||||
LAYERS.append(self)
|
LAYERS.append(self)
|
||||||
|
|
||||||
def post_init(self, temp_dq):
|
def post_init(self, temp_dq):
|
||||||
assert self.qweight.device.type == "cuda"
|
device = self.q_tensors.device
|
||||||
assert self.qweight.device.index is not None
|
assert device.type == "cuda"
|
||||||
self.q_tensors = {
|
assert device.index is not None
|
||||||
"qweight": self.qweight,
|
|
||||||
"qzeros": self.qzeros,
|
|
||||||
"scales": self.scales,
|
|
||||||
"g_idx": self.g_idx,
|
|
||||||
}
|
|
||||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||||
|
|
||||||
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
|
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
|
||||||
# and `Memory access fault by GPU node-2` will EAT you.
|
# and `Memory access fault by GPU node-2` will EAT you.
|
||||||
self.temp_dq = temp_dq
|
self.temp_dq = temp_dq
|
||||||
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
|
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
|
||||||
|
|
||||||
def forward(self, x, force_cuda=False):
|
def forward(self, x, force_cuda=False):
|
||||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||||
@ -203,7 +222,7 @@ class QuantLinear(nn.Module):
|
|||||||
def temp_fwd_size(self, max_input_len, max_batch_size):
|
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||||
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||||
|
|
||||||
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
|
def scratch_space_fixed(self, max_input_len, max_batch_size):
|
||||||
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user