mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Simpler exllama
This commit is contained in:
parent
6bf7090ecd
commit
5ca0508d02
3
Makefile
3
Makefile
@ -56,6 +56,3 @@ run-bloom:
|
||||
|
||||
run-bloom-quantize:
|
||||
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
|
||||
|
||||
clean:
|
||||
rm -rf target aml
|
||||
|
@ -230,19 +230,16 @@ def launcher(event_loop):
|
||||
shard_uds_path,
|
||||
]
|
||||
|
||||
env = os.environ
|
||||
|
||||
if num_shard is not None:
|
||||
args.extend(["--num-shard", str(num_shard)])
|
||||
if quantize is not None:
|
||||
args.append("--quantize")
|
||||
args.append(quantize)
|
||||
if quantize == "gptq":
|
||||
env["GPTQ_GROUPSIZE"] = "128"
|
||||
env["GPTQ_BITS"] = "4"
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
env = os.environ
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
|
||||
if not use_flash_attention:
|
||||
|
@ -1,102 +1,103 @@
|
||||
{
|
||||
"generated_text": ", and I am going to visit the Louvre",
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"seed": null,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"text": "</s>",
|
||||
"logprob": null
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 20628,
|
||||
"text": "Today",
|
||||
"logprob": -11.2265625
|
||||
"logprob": -10.328125,
|
||||
"text": "Today"
|
||||
},
|
||||
{
|
||||
"id": 306,
|
||||
"text": "I",
|
||||
"logprob": -4.1757812
|
||||
"logprob": -2.390625,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 626,
|
||||
"text": "am",
|
||||
"logprob": -1.9746094
|
||||
"logprob": -1.8857422,
|
||||
"text": "am"
|
||||
},
|
||||
{
|
||||
"id": 297,
|
||||
"text": "in",
|
||||
"logprob": -5.4648438
|
||||
"logprob": -4.4765625,
|
||||
"text": "in"
|
||||
},
|
||||
{
|
||||
"id": 3444,
|
||||
"text": "France",
|
||||
"logprob": -9.03125
|
||||
"logprob": -7.0703125,
|
||||
"text": "France"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29892,
|
||||
"text": ",",
|
||||
"logprob": -0.31298828,
|
||||
"special": false
|
||||
"logprob": -1.2910156,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"text": " and",
|
||||
"logprob": -1.4345703,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 306,
|
||||
"text": " I",
|
||||
"logprob": -0.32080078,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 626,
|
||||
"text": " am",
|
||||
"logprob": -1.3798828,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 2675,
|
||||
"text": " going",
|
||||
"logprob": -1.2304688,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"text": " to",
|
||||
"logprob": -0.0014791489,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 6493,
|
||||
"text": " visit",
|
||||
"logprob": -1.1503906,
|
||||
"special": false
|
||||
"id": 297,
|
||||
"logprob": -1.9394531,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 278,
|
||||
"text": " the",
|
||||
"logprob": -0.41259766,
|
||||
"special": false
|
||||
"logprob": -0.7597656,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 4562,
|
||||
"text": " Lou",
|
||||
"logprob": -1.8134766,
|
||||
"special": false
|
||||
"id": 7062,
|
||||
"logprob": -2.9121094,
|
||||
"special": false,
|
||||
"text": " south"
|
||||
},
|
||||
{
|
||||
"id": 12675,
|
||||
"text": "vre",
|
||||
"logprob": -0.000767231,
|
||||
"special": false
|
||||
"id": 310,
|
||||
"logprob": -1.0302734,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 278,
|
||||
"logprob": -0.58203125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 4234,
|
||||
"logprob": -0.2944336,
|
||||
"special": false,
|
||||
"text": " country"
|
||||
},
|
||||
{
|
||||
"id": 29892,
|
||||
"logprob": -0.7011719,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 297,
|
||||
"logprob": -1.1054688,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 278,
|
||||
"logprob": -0.52490234,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"generated_text": ", in the south of the country, in the"
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_gptq_handle(launcher):
|
||||
with launcher("TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq") as handle:
|
||||
with launcher("huggingface/llama-7b-gptq", num_shard=4, quantize="gptq") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_starcoder_gptq_handle(launcher):
|
||||
with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
|
||||
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@ -46,4 +46,4 @@ async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, re
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
||||
assert responses == response_snapshot
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
// #include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
@ -14,7 +14,7 @@ setup(
|
||||
sources=["custom_kernels/fused_attention_cuda.cu"],
|
||||
extra_compile_args=["-arch=compute_80", "-std=c++17"],
|
||||
),
|
||||
CppExtension(
|
||||
CUDAExtension(
|
||||
name="custom_kernels.exllama",
|
||||
sources=[
|
||||
"custom_kernels/exllama/exllama_ext.cpp",
|
||||
|
@ -500,7 +500,6 @@ class CausalLM(Model):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -378,6 +378,7 @@ class FlashLlamaLayer(nn.Module):
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
@ -448,7 +449,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
|
@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
|
||||
)
|
||||
from safetensors import SafetensorError
|
||||
|
||||
|
||||
def load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
@ -71,12 +72,19 @@ def _load_multi_mqa_gptq(
|
||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
bits, groupsize = weights.get_gptq_qparams()
|
||||
try:
|
||||
bits = weights.get_tensor("gptq_bits").item()
|
||||
groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
except SafetensorError as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
qweight = qweight.to(weights.device)
|
||||
qzeros = qzeros.to(weights.device)
|
||||
scales = scales.to(weights.device)
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
|
||||
if bias:
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
@ -90,8 +98,6 @@ def _load_multi_mqa_gptq(
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
|
||||
bias = bias.to(weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
else:
|
||||
raise NotImplementedError("Gptq loading with santacoder is not implemented")
|
||||
@ -355,7 +361,7 @@ class Block(nn.Module):
|
||||
max_s,
|
||||
):
|
||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||
|
||||
|
||||
hidden_states = self.attn(
|
||||
hidden_states,
|
||||
cu_seqlen_prefill,
|
||||
|
@ -6,8 +6,9 @@ import torch.distributed
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing import Optional, Tuple, List, Type, Union, Dict
|
||||
|
||||
from text_generation_server.models import Model
|
||||
@ -20,7 +21,6 @@ from text_generation_server.models.types import (
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
@ -684,7 +684,6 @@ class FlashCausalLM(Model):
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
config: PretrainedConfig,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
@ -700,7 +699,6 @@ class FlashCausalLM(Model):
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -68,7 +68,6 @@ class FlashLlama(FlashCausalLM):
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
|
@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||
super(FlashNeoXSharded, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.gpt_neox.layers),
|
||||
num_kv_heads=model.gpt_neox.num_heads,
|
||||
head_size=model.gpt_neox.head_size,
|
||||
|
@ -65,7 +65,6 @@ class FlashRWSharded(FlashCausalLM):
|
||||
super(FlashRWSharded, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=model.transformer.cache_size,
|
||||
head_size=model.transformer.head_size,
|
||||
|
@ -66,7 +66,6 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||
super(FlashSantacoderSharded, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=1,
|
||||
head_size=model.transformer.head_size,
|
||||
|
@ -198,7 +198,6 @@ class GalacticaSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -63,7 +63,6 @@ class GPTNeoxSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -3,27 +3,19 @@ import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from text_generation_server.models.types import Batch, GeneratedText
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
|
||||
from custom_kernels.exllama import prepare_buffers, set_tuning_params
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear
|
||||
)
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
|
||||
class Model(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
config: PretrainedConfig,
|
||||
requires_padding: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
@ -46,47 +38,6 @@ class Model(ABC):
|
||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||
is not None
|
||||
)
|
||||
self.config = config
|
||||
|
||||
if config.quantize == "gptq":
|
||||
# Buffers need to be persistent to avoid any bug.
|
||||
self.buffers = {}
|
||||
use_exllama_act_order = False
|
||||
max_dq_buffer_size = 1
|
||||
max_inner_outer_dim = 1
|
||||
for name, submodule in model.named_modules():
|
||||
if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
|
||||
|
||||
max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
|
||||
|
||||
if submodule.linear.act_order:
|
||||
max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
|
||||
|
||||
use_exllama_act_order = True
|
||||
|
||||
if use_exllama_act_order:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
# is handled by the client but it appears the model is initialized by the server.
|
||||
# An alternative could be to initialize the buffers during warmup.
|
||||
max_total_tokens = 2048
|
||||
else:
|
||||
max_total_tokens = 1
|
||||
|
||||
# This temp_state buffer is required to reorder X in the act-order case.
|
||||
self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
|
||||
|
||||
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
|
||||
|
||||
prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
|
||||
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.check_initialized()
|
||||
|
||||
|
@ -86,7 +86,6 @@ class MPTSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -61,7 +61,6 @@ class OPTSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -58,7 +58,6 @@ class RW(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -63,7 +63,6 @@ class SantaCoder(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -542,7 +542,6 @@ class Seq2SeqLM(Model):
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -73,7 +73,6 @@ class T5Sharded(Seq2SeqLM):
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -140,6 +140,13 @@ def serve(
|
||||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
try:
|
||||
from text_generation_server.utils.gptq.exllama import create_buffers
|
||||
create_buffers()
|
||||
logger.info("Created exllama GPTQ buffers !")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
|
89
server/text_generation_server/utils/gptq/exllama.py
Normal file
89
server/text_generation_server/utils/gptq/exllama.py
Normal file
@ -0,0 +1,89 @@
|
||||
import torch
|
||||
from custom_kernels.exllama import make_q4, q4_matmul, set_tuning_params, prepare_buffers
|
||||
from loguru import logger
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
|
||||
import os
|
||||
RANK = os.getenv("RANK", "0")
|
||||
DEVICE = torch.device(f"cuda:{RANK}")
|
||||
MAX_TOTAL_TOKENS = 1
|
||||
MAX_INNER_OUTER_DIM = 0
|
||||
MAX_DQ_BUFFER_SIZE = 0
|
||||
|
||||
|
||||
def create_buffers():
|
||||
temp_state = torch.zeros((MAX_TOTAL_TOKENS, MAX_INNER_OUTER_DIM), dtype=torch.float16, device=DEVICE)
|
||||
temp_dq = torch.zeros((1, MAX_DQ_BUFFER_SIZE), dtype=torch.float16, device=DEVICE)
|
||||
logger.info(f"Creating buffers {temp_state.shape} - {temp_dq.shape} - {DEVICE}")
|
||||
|
||||
prepare_buffers(DEVICE, temp_state, temp_dq)
|
||||
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
class Ex4bitLinear:
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
def __init__(self, qweight, qzeros, scales, bias, bits):
|
||||
assert bits == 4, "We cannot run exllama GPTQ kernels if bits != 4"
|
||||
|
||||
global MAX_INNER_OUTER_DIM, MAX_DQ_BUFFER_SIZE
|
||||
dq = qweight.numel() * 8
|
||||
if dq > MAX_DQ_BUFFER_SIZE:
|
||||
MAX_DQ_BUFFER_SIZE = dq
|
||||
|
||||
width = qweight.shape[1]
|
||||
if width > MAX_INNER_OUTER_DIM:
|
||||
MAX_INNER_OUTER_DIM = width
|
||||
height = qweight.shape[0] * 8
|
||||
if height > MAX_INNER_OUTER_DIM:
|
||||
MAX_INNER_OUTER_DIM = height
|
||||
|
||||
# prepare_buffers(DEVICE, TEMP_STATE, TEMP_DQ)
|
||||
|
||||
|
||||
self.q4 = make_q4(
|
||||
qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
# Never send g_idx, it MUST be like act_order=False, the exllama kernel does not expect it
|
||||
torch.zeros((0, 0), device=torch.device("meta")),
|
||||
DEVICE.index
|
||||
)
|
||||
self.bias = bias if bias is not None else None
|
||||
self.width = width
|
||||
|
||||
# # Infer groupsize from height of qzeros
|
||||
# self.groupsize = None
|
||||
# if self.qzeros.shape[0] > 1:
|
||||
# self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
# if self.groupsize is not None:
|
||||
# assert groupsize == self.groupsize
|
||||
|
||||
# # Handle act-order matrix
|
||||
# if self.g_idx is not None:
|
||||
# if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
|
||||
# self.act_order = True
|
||||
# else:
|
||||
# self.act_order = False
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x, self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
|
@ -8,11 +8,6 @@ import torch
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from custom_kernels.exllama import make_q4, q4_matmul
|
||||
except Exception as e:
|
||||
logger.error(f"The CUDA kernels custom_kernels.exllama not installed, got the error: {e}")
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@ -368,76 +363,3 @@ class QuantLinear(nn.Module):
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device = "meta")
|
||||
|
||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
"""Construct Q4Matrix, return handle"""
|
||||
return make_q4(qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx if g_idx is not None else none_tensor,
|
||||
device)
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
|
||||
class Ex4bitLinear:
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
assert bits == 4
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx.cpu() if g_idx is not None else None
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
|
||||
self.empty_g_idx = True
|
||||
self.g_idx = None
|
||||
|
||||
assert self.device.type == "cuda"
|
||||
assert self.device.index is not None
|
||||
|
||||
self.q4 = ext_make_q4(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
self.scales,
|
||||
self.g_idx,
|
||||
self.device.index
|
||||
)
|
||||
|
||||
self.height = qweight.shape[0] * 8
|
||||
self.width = qweight.shape[1]
|
||||
|
||||
# Infer groupsize from height of qzeros
|
||||
self.groupsize = None
|
||||
if self.qzeros.shape[0] > 1:
|
||||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
|
||||
self.act_order = True
|
||||
else:
|
||||
self.act_order = False
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x, self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
|
||||
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
from loguru import logger
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@ -15,7 +16,14 @@ except ImportError:
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
|
||||
HAS_EXLLAMA = True
|
||||
try:
|
||||
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
|
||||
except ImportError:
|
||||
logger.error(f"The CUDA kernels custom_kernels.exllama not installed using base triton kernel")
|
||||
HAS_EXLLAMA = False
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@ -145,13 +153,15 @@ def get_linear(weight, bias, quantize):
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
if use_triton_kernel or bits != 4:
|
||||
if can_exllama and HAS_EXLLAMA:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, bias, bits)
|
||||
else:
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
qzeros,
|
||||
@ -161,8 +171,6 @@ def get_linear(weight, bias, quantize):
|
||||
bits,
|
||||
groupsize,
|
||||
)
|
||||
else:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
@ -3,7 +3,6 @@ from typing import List, Dict, Optional, Tuple
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
self,
|
||||
@ -126,9 +125,15 @@ class Weights:
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
can_exllama = True
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
can_exllama = False
|
||||
|
||||
bits, groupsize = self.get_gptq_qparams()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
@ -136,52 +141,32 @@ class Weights:
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
use_triton_kernel = False
|
||||
if self.process_group.size() > 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
_, groupsize = self.get_gptq_qparams()
|
||||
|
||||
if g_idx is not None:
|
||||
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_triton_kernel = True
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
|
||||
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
|
||||
bits, groupsize = self.get_gptq_qparams()
|
||||
can_exllama = True
|
||||
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
can_exllama = False
|
||||
|
||||
if use_triton_kernel:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
else:
|
||||
if groupsize >= 16:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered.
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
# For tp > 1, at this point we know we do not use act-order
|
||||
if self.process_group.size() == 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def get_gptq_qparams(self) -> Tuple[int, int]:
|
||||
def _get_gptq_qparams(self) -> Tuple[int, int]:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
@ -194,4 +179,4 @@ class Weights:
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return bits, groupsize
|
||||
return bits, groupsize
|
||||
|
Loading…
Reference in New Issue
Block a user