Felix Marty 2023-07-12 16:16:58 +00:00
parent 67a46b7361
commit 67d687609b
3 changed files with 4 additions and 15 deletions

View File

@@ -21,17 +21,13 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
-from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
-from custom_kernels.exllama import prepare_buffers, set_tuning_params
 
 
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
     if config.quantize == "gptq":
-        layer = _load_multi_mqa_gptq(
+        return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
         )
-        return layer
     else:
         return _load_multi_mqa(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -190,21 +186,18 @@ def load_col(config, prefix: str, weights, bias: bool):
 
 def load_row(config, prefix: str, weights, bias: bool):
-    quantize = config.quantize
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_multi_weights_row(prefix, quantize=quantize)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
     return TensorParallelRowLinear(
-        get_linear(weight, bias, quantize), process_group=weights.process_group
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )
@@ -272,8 +265,6 @@ class FlashMQAttention(torch.nn.Module):
 
         # Expand from 1 to num_heads
         key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-        # output
-        attn_output = torch.empty_like(query)
         # flash attention
         flash_attn_cuda.fwd(
             query,

View File

@@ -6,8 +6,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 import torch
-from loguru import logger
 
 try:
     from custom_kernels.exllama import make_q4, q4_matmul
 except Exception as e:

View File

@@ -18,7 +18,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
 from typing import Optional
-from loguru import logger
 
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):