mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixes.

commit ddc0dd57f7 (parent fe4ef95d92)
@@ -8,7 +8,6 @@ from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
-from text_generation_server.layers.layernorm import (
-    get_linear,
-    FastLinear,
-)
+
+# Just to add the `load` methods.
+from text_generation_server.layers.layernorm import load_layer_norm
server/text_generation_server/layers/awq/conversion_utils.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import torch
from typing import List


AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of packing, either "column" or "row"
    Returns:
        qmatrix (torch.Tensor): packed matrix of integers
    """
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)

    elif direction == "row":
        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
    Args:
        qmatrix (torch.Tensor): matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)

    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    return imatrix


def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = AWQ_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is AWQ_PACK_ORDER
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    if direction == "column":
        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
    elif direction == "row":
        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)

    return imatrix


def fast_awq_to_gptq(qweight, qzeros):
    # awq uses column packing for both weights and zeros
    izeros = unpack(qzeros, direction="column")
    iweights = unpack(qweight, direction="column")

    # Reverse the order of the iweight and izeros tensors
    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
    izeros = izeros - 1
    # exllama uses row packing for weights and column packing for zeros
    qzeros = pack(izeros, direction="column")
    qweight = pack(iweights, direction="row")

    return qweight, qzeros
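For reference, the conversion helpers above are plain tensor manipulations, so they can be exercised without any CUDA kernels. A minimal sketch of how they compose (not part of the commit; the shapes are illustrative and assume the server package is importable):

import torch

from text_generation_server.layers.awq.conversion_utils import pack, unpack, fast_awq_to_gptq

# pack/unpack round-trip: 8 4-bit values fit in one int32 along the packed axis.
imatrix = torch.randint(0, 16, (2, 16), dtype=torch.int32)
qmatrix = pack(imatrix, direction="column")               # -> shape (2, 2), dtype torch.int32
assert (unpack(qmatrix, direction="column") == imatrix).all()

# Hypothetical AWQ-style tensors for a layer with in_features=128, out_features=256, group_size=32;
# AWQ checkpoints column-pack both qweight and qzeros.
qweight = pack(torch.randint(0, 16, (128, 256), dtype=torch.int32), direction="column")  # (128, 32)
qzeros = pack(torch.randint(1, 16, (4, 256), dtype=torch.int32), direction="column")     # (4, 32)

# Repack for the GPTQ/exllama kernels: weights become row-packed, zeros stay column-packed
# after the pack order is reversed and the GPTQ "+1" offset is removed.
gptq_qweight, gptq_qzeros = fast_awq_to_gptq(qweight, qzeros)
print(gptq_qweight.shape, gptq_qzeros.shape)  # torch.Size([16, 256]) torch.Size([4, 32])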
server/text_generation_server/layers/awq/quantize/qmodule.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py

import math
import torch
import torch.nn as nn
import awq_inference_engine  # with CUDA kernels


# class ScaledActivation(nn.Module):
#     def __init__(self, module, scales):
#         super().__init__()
#         self.act = module
#         self.scales = nn.Parameter(scales.data)
#
#     def forward(self, x):
#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


class WQLinear(nn.Module):
    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else self.in_features
        # quick sanity check (make sure aligment)
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        if bias:
            self.bias = bias
        else:
            self.bias = None

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        out = awq_inference_engine.gemm_forward_cuda(
            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
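WQLinear only stores the already-packed tensors and defers the matmul to the awq_inference_engine CUDA extension, so it runs only on a GPU with that kernel installed. A rough construction sketch (not part of the commit; shapes are illustrative for in_features=4096, out_features=4096, group_size=128, and the tensor contents are placeholders):

import torch
from text_generation_server.layers.awq.quantize.qmodule import WQLinear

in_features, out_features, w_bit, group_size = 4096, 4096, 4, 128

# AWQ layout: qweight is (in_features, out_features * w_bit // 32); zeros and scales are per group.
qweight = torch.zeros(in_features, out_features * w_bit // 32, dtype=torch.int32, device="cuda")
qzeros = torch.zeros(in_features // group_size, out_features * w_bit // 32, dtype=torch.int32, device="cuda")
scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16, device="cuda")

layer = WQLinear(w_bit, group_size, qweight, qzeros, scales, bias=None)
x = torch.zeros(1, 8, in_features, dtype=torch.float16, device="cuda")
y = layer(x)  # -> (1, 8, out_features), computed by awq_inference_engine.gemm_forward_cuda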
@@ -36,4 +36,4 @@ elif CAN_EXLLAMA:
     except ImportError:
         pass

-from text_generation_server.layers.gptq.gptq.quant_linear import QuantLinear
+from text_generation_server.layers.gptq.quant_linear import QuantLinear
@@ -119,6 +119,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
                 none_tensor,
                 temp_dq,
             )
+    else:
+        RuntimeError("Cannot create handle")


 DEVICE = None
@@ -80,7 +80,7 @@ elif SYSTEM == "rocm":
                 hidden_states += residual
             residual = hidden_states

-            return super(FastLayerNorm, self).forward(hidden_states), residual
+            return super().forward(hidden_states), residual

 elif SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
@@ -96,50 +96,6 @@ elif SYSTEM == "xpu":
             return out, res_out


-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if SYSTEM == "xpu":
-            res_out = hidden_states
-            out = ipex.llm.functional.add_layer_norm(
-                residual, hidden_states, self.weight, self.bias, self.eps, True
-            )
-            if residual is not None:
-                res_out = residual
-            return out, res_out
-        elif hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
-
-
 class FastRMSNorm(nn.Module):
     def __init__(self, weight: torch.Tensor, eps: float):
         super().__init__()
@@ -1,4 +1,5 @@
 import torch
+from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM

@@ -97,7 +98,7 @@ def get_linear(weight, bias, quantize):

         if use_exllama:
             try:
-                from text_generation_server.utils.gptq.quant_linear import (
+                from text_generation_server.layers.gptq import (
                     ExllamaQuantLinear,
                 )
             except ImportError:
@@ -109,7 +110,7 @@ def get_linear(weight, bias, quantize):
                 qweight, qzeros, scales, g_idx, bias, bits, groupsize
             )
         else:
-            from text_generation_server.utils.gptq.quant_linear import QuantLinear
+            from text_generation_server.layers.gptq.quant_linear import QuantLinear

             linear = QuantLinear(
                 qweight,
@@ -133,7 +134,7 @@ def get_linear(weight, bias, quantize):
                 "to use Exllama/GPTQ kernels for AWQ inference."
             )
         try:
-            from text_generation_server.utils.awq.quantize.qmodule import WQLinear
+            from text_generation_server.layers.awq.quantize.qmodule import WQLinear

             linear = WQLinear(
                 w_bit=bits,
@@ -3,6 +3,10 @@ from torch import nn
 from typing import Tuple, Optional
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.layers.linear import FastLinear
+from text_generation_server.layers.tensor_parallel import (
+    TensorParallelHead,
+    TensorParallelColumnLinear,
+)


 class ResBlock(torch.nn.Module):
@@ -1,3 +1,4 @@
+import os
 import torch
 from torch import nn

@@ -48,6 +48,33 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

 FLASH_ATTENTION = True

+from text_generation_server.models.flash_rw import FlashRWSharded
+from text_generation_server.models.flash_neox import FlashNeoXSharded
+from text_generation_server.models.flash_llama import (
+    FlashLlama,
+)
+from text_generation_server.models.flash_qwen2 import (
+    FlashQwen2,
+)
+from text_generation_server.models.flash_cohere import (
+    FlashCohere,
+)
+from text_generation_server.models.flash_gemma import (
+    FlashGemma,
+)
+from text_generation_server.models.flash_santacoder import (
+    FlashSantacoderSharded,
+)
+from text_generation_server.models.idefics import IDEFICSSharded
+from text_generation_server.models.llava_next import LlavaNext
+from text_generation_server.models.idefics2 import Idefics2
+from text_generation_server.models.flash_mistral import FlashMistral
+from text_generation_server.models.flash_mixtral import FlashMixtral
+from text_generation_server.models.flash_phi import FlashPhi
+from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
+from text_generation_server.models.flash_dbrx import FlashDbrx
+from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
+
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -241,7 +241,7 @@ def _load_gqa(config, prefix: str, weights):
         log_once(
             logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
         )
-        from text_generation_server.layers.awq import (
+        from text_generation_server.layers.awq.conversion_utils import (
            fast_awq_to_gptq,
        )

@@ -31,9 +31,11 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )

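The remaining modelling-file hunks repeat this same mechanical change: layers that used to be re-exported from text_generation_server.layers are now imported from their dedicated submodules. As a summary collected from the hunks in this commit (not a single file in the repo), the new import locations are:

from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm, load_layer_norm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.gptq import (
    HAS_EXLLAMA,
    CAN_EXLLAMA,
    ExllamaQuantLinear,
    create_exllama_buffers,
    set_device,
)
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
from text_generation_server.layers.awq.conversion_utils import fast_awq_to_gptq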
@@ -31,9 +31,11 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )

@@ -34,10 +34,14 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    FastLayerNorm,
-    PositionRotaryEmbedding,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)


 def load_row(config, prefix: str, weights, bias: bool):
@@ -11,11 +11,15 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)


 class PhiConfig(PretrainedConfig):
@@ -10,9 +10,11 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )

@@ -13,10 +13,14 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    FastLayerNorm,
-    PositionRotaryEmbedding,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)


 def load_row(config, prefix: str, weights, bias: bool):
@@ -11,9 +11,11 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     SpeculativeHead,
     TensorParallelEmbedding,
-    FastLayerNorm,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)


 def load_multi_mqa(
@@ -80,7 +82,7 @@ def _load_multi_mqa_gptq(
         g_idx = g_idx.to(device=weights.device)
     elif quant_method == "awq":
         g_idx = None
-        from text_generation_server.layers.awq import (
+        from text_generation_server.layers.awq.conversion_utils import (
             fast_awq_to_gptq,
         )

@@ -31,11 +31,15 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+)
+from text_generation_server.layers.layernorm import (
     FastLayerNorm,
+    FastRMSNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
 )

@@ -52,9 +52,9 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     SpeculativeHead,
-    PositionRotaryEmbedding,
     FastLinear,
 )
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.utils.import_utils import SYSTEM

 if SYSTEM == "cuda":
@@ -85,7 +85,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             # When using GPTQ, Exllama kernels need some global kernels
             # For which we have the finale shapes only after the model has loaded
             # This will allocate those buffers.
-            from text_generation_server.layers import (
+            from text_generation_server.layers.gptq import (
                 create_exllama_buffers,
                 set_device,
             )
@@ -171,7 +171,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.layers.awq import (
+                from text_generation_server.layers.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )

@@ -227,7 +227,7 @@ class Weights:

             bits, groupsize, desc_act, quant_method = self._get_gptq_params()

-            from text_generation_server.layers import HAS_EXLLAMA
+            from text_generation_server.layers.gptq import HAS_EXLLAMA

             use_exllama = (
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
@@ -242,7 +242,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.layers.awq import (
+                from text_generation_server.layers.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )

@@ -321,7 +321,7 @@ class Weights:
                 # it would require to reorder input activations that are split unto several GPUs
                 use_exllama = False

-            from text_generation_server.layers import HAS_EXLLAMA, CAN_EXLLAMA
+            from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA

             if use_exllama:
                 if not HAS_EXLLAMA:
@@ -348,7 +348,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.layers.awq import (
+                from text_generation_server.layers.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
