Simpler exllama

Nicolas Patry 2023-07-20 15:36:53 +00:00
parent 6bf7090ecd
commit 5ca0508d02
29 changed files with 223 additions and 276 deletions

View File

@ -56,6 +56,3 @@ run-bloom:
run-bloom-quantize:
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
clean:
rm -rf target aml

View File

@ -230,19 +230,16 @@ def launcher(event_loop):
shard_uds_path,
]
env = os.environ
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize is not None:
args.append("--quantize")
args.append(quantize)
if quantize == "gptq":
env["GPTQ_GROUPSIZE"] = "128"
env["GPTQ_BITS"] = "4"
if trust_remote_code:
args.append("--trust-remote-code")
env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug"
if not use_flash_attention:

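A note on the fixture change above: the launcher arguments and environment are assembled conditionally. The sketch below is illustrative only; the helper name and defaults are not part of the commit, and whether the fixture still sets the GPTQ_* fallback after this change is an assumption here.

import os

def build_launcher_cmd(model_id, num_shard=None, quantize=None, trust_remote_code=False, port=8080):
    # Optional flags are appended only when set; GPTQ runs may rely on
    # GPTQ_BITS / GPTQ_GROUPSIZE as a fallback when the checkpoint ships no qparams.
    args = ["text-generation-launcher", "--model-id", model_id, "--port", str(port)]
    env = os.environ.copy()
    if num_shard is not None:
        args.extend(["--num-shard", str(num_shard)])
    if quantize is not None:
        args.extend(["--quantize", quantize])
        if quantize == "gptq":
            env["GPTQ_GROUPSIZE"] = "128"
            env["GPTQ_BITS"] = "4"
    if trust_remote_code:
        args.append("--trust-remote-code")
    env["LOG_LEVEL"] = "info,text_generation_router=debug"
    return args, env
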
View File

@ -1,102 +1,103 @@
{
"generated_text": ", and I am going to visit the Louvre",
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"seed": null,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 20628,
"text": "Today",
"logprob": -11.2265625
"logprob": -10.328125,
"text": "Today"
},
{
"id": 306,
"text": "I",
"logprob": -4.1757812
"logprob": -2.390625,
"text": "I"
},
{
"id": 626,
"text": "am",
"logprob": -1.9746094
"logprob": -1.8857422,
"text": "am"
},
{
"id": 297,
"text": "in",
"logprob": -5.4648438
"logprob": -4.4765625,
"text": "in"
},
{
"id": 3444,
"text": "France",
"logprob": -9.03125
"logprob": -7.0703125,
"text": "France"
}
],
"seed": null,
"tokens": [
{
"id": 29892,
"text": ",",
"logprob": -0.31298828,
"special": false
"logprob": -1.2910156,
"special": false,
"text": ","
},
{
"id": 322,
"text": " and",
"logprob": -1.4345703,
"special": false
},
{
"id": 306,
"text": " I",
"logprob": -0.32080078,
"special": false
},
{
"id": 626,
"text": " am",
"logprob": -1.3798828,
"special": false
},
{
"id": 2675,
"text": " going",
"logprob": -1.2304688,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": -0.0014791489,
"special": false
},
{
"id": 6493,
"text": " visit",
"logprob": -1.1503906,
"special": false
"id": 297,
"logprob": -1.9394531,
"special": false,
"text": " in"
},
{
"id": 278,
"text": " the",
"logprob": -0.41259766,
"special": false
"logprob": -0.7597656,
"special": false,
"text": " the"
},
{
"id": 4562,
"text": " Lou",
"logprob": -1.8134766,
"special": false
"id": 7062,
"logprob": -2.9121094,
"special": false,
"text": " south"
},
{
"id": 12675,
"text": "vre",
"logprob": -0.000767231,
"special": false
"id": 310,
"logprob": -1.0302734,
"special": false,
"text": " of"
},
{
"id": 278,
"logprob": -0.58203125,
"special": false,
"text": " the"
},
{
"id": 4234,
"logprob": -0.2944336,
"special": false,
"text": " country"
},
{
"id": 29892,
"logprob": -0.7011719,
"special": false,
"text": ","
},
{
"id": 297,
"logprob": -1.1054688,
"special": false,
"text": " in"
},
{
"id": 278,
"logprob": -0.52490234,
"special": false,
"text": " the"
}
]
}
},
"generated_text": ", in the south of the country, in the"
}

View File

@ -1,10 +1,9 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
with launcher("TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq") as handle:
with launcher("huggingface/llama-7b-gptq", num_shard=4, quantize="gptq") as handle:
yield handle

View File

@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def flash_starcoder_gptq_handle(launcher):
with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
yield handle
@ -46,4 +46,4 @@ async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, re
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
assert responses == response_snapshot

View File

@ -3,7 +3,7 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
// #include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

View File

@ -14,7 +14,7 @@ setup(
sources=["custom_kernels/fused_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
CppExtension(
CUDAExtension(
name="custom_kernels.exllama",
sources=[
"custom_kernels/exllama/exllama_ext.cpp",

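The switch from CppExtension to CUDAExtension above matters because CUDAExtension compiles .cu sources with nvcc and links against the CUDA runtime, which a plain CppExtension does not. A sketch of the resulting entry, with the rest of the source list left elided as in the hunk (constructing it requires a CUDA toolkit, i.e. a valid CUDA_HOME):

from torch.utils.cpp_extension import CUDAExtension

exllama_extension = CUDAExtension(
    name="custom_kernels.exllama",
    sources=[
        "custom_kernels/exllama/exllama_ext.cpp",
        # ... remaining exllama CUDA sources are truncated in the hunk ...
    ],
)
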
View File

@ -500,7 +500,6 @@ class CausalLM(Model):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -378,6 +378,7 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.config = config
process_group = weights.process_group
self.tp_rank = process_group.rank()
@ -448,7 +449,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.config = config
self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,

View File

@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
)
from safetensors import SafetensorError
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
@ -71,12 +72,19 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
bits, groupsize = weights.get_gptq_qparams()
try:
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
qweight = qweight.to(weights.device)
qzeros = qzeros.to(weights.device)
scales = scales.to(weights.device)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@ -90,8 +98,6 @@ def _load_multi_mqa_gptq(
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else:
raise NotImplementedError("Gptq loading with santacoder is not implemented")
@ -355,7 +361,7 @@ class Block(nn.Module):
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlen_prefill,

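For reference, the bits/groupsize resolution with its GPTQ_BITS / GPTQ_GROUPSIZE environment fallback, written as a standalone sketch; the function name is illustrative, and in the commit this logic sits on the Weights class:

import os
from safetensors import SafetensorError

def resolve_gptq_qparams(weights):
    # Prefer the qparams stored in the checkpoint itself ...
    try:
        bits = weights.get_tensor("gptq_bits").item()
        groupsize = weights.get_tensor("gptq_groupsize").item()
    except SafetensorError as e:
        # ... and fall back to the environment variables set by the launcher.
        try:
            bits = int(os.getenv("GPTQ_BITS"))
            groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
        except Exception:
            raise e
    return bits, groupsize
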
View File

@ -6,8 +6,9 @@ import torch.distributed
import numpy as np
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
@ -20,7 +21,6 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16
@ -684,7 +684,6 @@ class FlashCausalLM(Model):
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
config: PretrainedConfig,
num_layers: int,
num_kv_heads: int,
head_size: int,
@ -700,7 +699,6 @@ class FlashCausalLM(Model):
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=False,
dtype=dtype,
device=device,

View File

@ -68,7 +68,6 @@ class FlashLlama(FlashCausalLM):
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,

View File

@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
super(FlashNeoXSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,

View File

@ -65,7 +65,6 @@ class FlashRWSharded(FlashCausalLM):
super(FlashRWSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,

View File

@ -66,7 +66,6 @@ class FlashSantacoderSharded(FlashCausalLM):
super(FlashSantacoderSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,

View File

@ -198,7 +198,6 @@ class GalacticaSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -63,7 +63,6 @@ class GPTNeoxSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -3,27 +3,19 @@ import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
from custom_kernels.exllama import prepare_buffers, set_tuning_params
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear
)
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
config: PretrainedConfig,
requires_padding: bool,
dtype: torch.dtype,
device: torch.device,
@ -46,47 +38,6 @@ class Model(ABC):
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.config = config
if config.quantize == "gptq":
# Buffers need to be persistent to avoid any bug.
self.buffers = {}
use_exllama_act_order = False
max_dq_buffer_size = 1
max_inner_outer_dim = 1
for name, submodule in model.named_modules():
if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
if submodule.linear.act_order:
max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
use_exllama_act_order = True
if use_exllama_act_order:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
max_total_tokens = 2048
else:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
torch.cuda.empty_cache()
self.check_initialized()

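The buffer bookkeeping removed from Model.__init__ above (along with the stored config) now lives in the new exllama module added later in this commit. As a standalone illustration of the sizing rule it used, assuming 4-bit packing (8 values per int32) and a hypothetical helper name:

def exllama_buffer_shapes(qweights, max_total_tokens=2048):
    # temp_dq must be able to hold the largest fully dequantized weight matrix,
    # temp_state must hold one activation row per token for the widest matrix.
    max_dq_buffer_size = 1
    max_inner_outer_dim = 1
    for qweight in qweights:
        max_dq_buffer_size = max(max_dq_buffer_size, qweight.numel() * 8)
        height = qweight.shape[0] * 8  # unpacked input dimension
        width = qweight.shape[1]       # output dimension
        max_inner_outer_dim = max(max_inner_outer_dim, height, width)
    return (max_total_tokens, max_inner_outer_dim), (1, max_dq_buffer_size)
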
View File

@ -86,7 +86,6 @@ class MPTSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=False,
dtype=dtype,
device=device,

View File

@ -61,7 +61,6 @@ class OPTSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -58,7 +58,6 @@ class RW(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -63,7 +63,6 @@ class SantaCoder(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -542,7 +542,6 @@ class Seq2SeqLM(Model):
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -73,7 +73,6 @@ class T5Sharded(Seq2SeqLM):
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -140,6 +140,13 @@ def serve(
logger.exception("Error when initializing model")
raise
try:
from text_generation_server.utils.gptq.exllama import create_buffers
create_buffers()
logger.info("Created exllama GPTQ buffers !")
except ImportError:
pass
server = aio.server(
interceptors=[
ExceptionInterceptor(),

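The guarded call above runs only after the model has been built, because each Ex4bitLinear records its buffer requirements at construction time. The same pattern as a reusable helper (the helper name is illustrative):

from loguru import logger

def maybe_create_exllama_buffers():
    # Allocate the shared exllama scratch tensors only if the CUDA extension was built.
    try:
        from text_generation_server.utils.gptq.exllama import create_buffers
    except ImportError:
        return False
    create_buffers()
    logger.info("Created exllama GPTQ buffers !")
    return True
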
View File

@ -0,0 +1,89 @@
import torch
from custom_kernels.exllama import make_q4, q4_matmul, set_tuning_params, prepare_buffers
from loguru import logger
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
import os
RANK = os.getenv("RANK", "0")
DEVICE = torch.device(f"cuda:{RANK}")
MAX_TOTAL_TOKENS = 1
MAX_INNER_OUTER_DIM = 0
MAX_DQ_BUFFER_SIZE = 0
def create_buffers():
temp_state = torch.zeros((MAX_TOTAL_TOKENS, MAX_INNER_OUTER_DIM), dtype=torch.float16, device=DEVICE)
temp_dq = torch.zeros((1, MAX_DQ_BUFFER_SIZE), dtype=torch.float16, device=DEVICE)
logger.info(f"Creating buffers {temp_state.shape} - {temp_dq.shape} - {DEVICE}")
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, bias, bits):
assert bits == 4, "We cannot run exllama GPTQ kernels if bits != 4"
global MAX_INNER_OUTER_DIM, MAX_DQ_BUFFER_SIZE
dq = qweight.numel() * 8
if dq > MAX_DQ_BUFFER_SIZE:
MAX_DQ_BUFFER_SIZE = dq
width = qweight.shape[1]
if width > MAX_INNER_OUTER_DIM:
MAX_INNER_OUTER_DIM = width
height = qweight.shape[0] * 8
if height > MAX_INNER_OUTER_DIM:
MAX_INNER_OUTER_DIM = height
# prepare_buffers(DEVICE, TEMP_STATE, TEMP_DQ)
self.q4 = make_q4(
qweight,
qzeros,
scales,
# Never send g_idx, it MUST be like act_order=False, the exllama kernel does not expect it
torch.zeros((0, 0), device=torch.device("meta")),
DEVICE.index
)
self.bias = bias if bias is not None else None
self.width = width
# # Infer groupsize from height of qzeros
# self.groupsize = None
# if self.qzeros.shape[0] > 1:
# self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
# if self.groupsize is not None:
# assert groupsize == self.groupsize
# # Handle act-order matrix
# if self.g_idx is not None:
# if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
# self.act_order = True
# else:
# self.act_order = False
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

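A minimal usage sketch of the new module, assuming the custom_kernels.exllama extension is built and a CUDA device is available. The shapes follow the usual 4-bit GPTQ packing (8 nibbles per int32 along the input dimension, one qzeros/scales row per group); the all-zero weights are purely illustrative:

import torch
from text_generation_server.utils.gptq.exllama import Ex4bitLinear, create_buffers

in_features, out_features, groupsize = 4096, 4096, 128
device = torch.device("cuda:0")

qweight = torch.zeros((in_features // 8, out_features), dtype=torch.int32, device=device)
qzeros = torch.zeros((in_features // groupsize, out_features // 8), dtype=torch.int32, device=device)
scales = torch.ones((in_features // groupsize, out_features), dtype=torch.float16, device=device)

# Constructing the layer updates the module-level maxima used to size the buffers ...
layer = Ex4bitLinear(qweight, qzeros, scales, bias=None, bits=4)

# ... so create_buffers() must be called after every Ex4bitLinear has been built.
create_buffers()

x = torch.randn((1, in_features), dtype=torch.float16, device=device)
y = layer.forward(x)  # shape (1, out_features)
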
View File

@ -8,11 +8,6 @@ import torch
from loguru import logger
try:
from custom_kernels.exllama import make_q4, q4_matmul
except Exception as e:
logger.error(f"The CUDA kernels custom_kernels.exllama not installed, got the error: {e}")
try:
import triton
import triton.language as tl
@ -368,76 +363,3 @@ class QuantLinear(nn.Module):
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device = "meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
assert bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View File

@ -1,5 +1,6 @@
import torch
import torch.distributed
from loguru import logger
from torch import nn
from torch.nn import functional as F
@ -15,7 +16,14 @@ except ImportError:
from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
from text_generation_server.utils.gptq.quant_linear import QuantLinear
HAS_EXLLAMA = True
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
except ImportError:
logger.error("The CUDA kernels custom_kernels.exllama are not installed, using the base triton kernel")
HAS_EXLLAMA = False
from typing import Optional
@ -145,13 +153,15 @@ def get_linear(weight, bias, quantize):
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
if use_triton_kernel or bits != 4:
if can_exllama and HAS_EXLLAMA:
linear = Ex4bitLinear(qweight, qzeros, scales, bias, bits)
else:
linear = QuantLinear(
qweight,
qzeros,
@ -161,8 +171,6 @@ def get_linear(weight, bias, quantize):
bits,
groupsize,
)
else:
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear

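Taken together, the gptq branch of get_linear now dispatches on the per-tensor flag packed by the loader and on whether the exllama kernels imported successfully. A standalone sketch (the function name is illustrative and the QuantLinear argument order is assumed from the truncated call in the hunk):

from text_generation_server.utils.gptq.quant_linear import QuantLinear

try:
    from text_generation_server.utils.gptq.exllama import Ex4bitLinear
    HAS_EXLLAMA = True
except ImportError:
    HAS_EXLLAMA = False

def build_gptq_linear(weight, bias):
    # The loader packs a 7-tuple; the last flag says whether this particular
    # tensor is exllama-compatible (e.g. no act-order under row parallelism).
    qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama = weight
    if can_exllama and HAS_EXLLAMA:
        # The exllama path drops g_idx entirely.
        return Ex4bitLinear(qweight, qzeros, scales, bias, bits)
    # The triton fallback keeps g_idx so it can handle act-order checkpoints.
    return QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
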
View File

@ -3,7 +3,6 @@ from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
class Weights:
def __init__(
self,
@ -126,9 +125,15 @@ class Weights:
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
can_exllama = True
bits, groupsize = self._get_gptq_qparams()
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# The exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering the input activations that are split onto several GPUs
can_exllama = False
bits, groupsize = self.get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
bits, groupsize = self._get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
@ -136,52 +141,32 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_triton_kernel = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
_, groupsize = self.get_gptq_qparams()
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# The exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering the input activations that are split onto several GPUs
use_triton_kernel = True
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
bits, groupsize = self._get_gptq_qparams()
g_idx = self.get_tensor(f"{prefix}.g_idx")
bits, groupsize = self.get_gptq_qparams()
can_exllama = True
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# The exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering the input activations that are split onto several GPUs
can_exllama = False
if use_triton_kernel:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
else:
if groupsize >= 16:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
# For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def get_gptq_qparams(self) -> Tuple[int, int]:
def _get_gptq_qparams(self) -> Tuple[int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
@ -194,4 +179,4 @@ class Weights:
except Exception:
raise e
return bits, groupsize
return bits, groupsize
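The act-order test used in both weight loaders above can be factored into a small predicate; a sketch with an illustrative name:

import torch

def exllama_compatible(g_idx: torch.Tensor, groupsize: int) -> bool:
    # A trivial g_idx (all zeros, or exactly i // groupsize) means the checkpoint was
    # not quantized with act-order, so the exllama kernel can use the sharded weights.
    trivial = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), trivial)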