This commit is contained in:
Felix Marty 2023-07-12 16:16:58 +00:00
parent 67a46b7361
commit 67d687609b
3 changed files with 4 additions and 15 deletions

View File

@ -21,17 +21,13 @@ from text_generation_server.utils.layers import (
get_linear,
)
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
from custom_kernels.exllama import prepare_buffers, set_tuning_params
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if config.quantize == "gptq":
layer = _load_multi_mqa_gptq(
return _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
return layer
else:
return _load_multi_mqa(
config, prefix, weights, bias, head_size, num_heads, hidden_size
@ -190,21 +186,18 @@ def load_col(config, prefix: str, weights, bias: bool):
def load_row(config, prefix: str, weights, bias: bool):
quantize = config.quantize
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
weight = weights.get_multi_weights_row(prefix, quantize=quantize)
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, quantize), process_group=weights.process_group
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
@ -272,8 +265,6 @@ class FlashMQAttention(torch.nn.Module):
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,

View File

@ -6,8 +6,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
import torch
from loguru import logger
try:
from custom_kernels.exllama import make_q4, q4_matmul
except Exception as e:

View File

@ -18,7 +18,7 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
from typing import Optional
from loguru import logger
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):