mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 13:02:12 +00:00
cleanup
This commit is contained in:
parent
67a46b7361
commit
67d687609b
@ -21,17 +21,13 @@ from text_generation_server.utils.layers import (
|
||||
get_linear,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
|
||||
from custom_kernels.exllama import prepare_buffers, set_tuning_params
|
||||
|
||||
def load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
if config.quantize == "gptq":
|
||||
layer = _load_multi_mqa_gptq(
|
||||
return _load_multi_mqa_gptq(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
)
|
||||
return layer
|
||||
else:
|
||||
return _load_multi_mqa(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
@ -190,21 +186,18 @@ def load_col(config, prefix: str, weights, bias: bool):
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
quantize = config.quantize
|
||||
|
||||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
else:
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=quantize)
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelRowLinear(
|
||||
get_linear(weight, bias, quantize), process_group=weights.process_group
|
||||
get_linear(weight, bias, config.quantize), process_group=weights.process_group
|
||||
)
|
||||
|
||||
|
||||
@ -272,8 +265,6 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# Expand from 1 to num_heads
|
||||
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
|
@ -6,8 +6,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
import torch
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from custom_kernels.exllama import make_q4, q4_matmul
|
||||
except Exception as e:
|
||||
|
@ -18,7 +18,7 @@ from accelerate import init_empty_weights
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
|
||||
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
# Monkey patching
|
||||
@classmethod
|
||||
def load_layer_norm(cls, prefix, weights, eps):
|
||||
|
Loading…
Reference in New Issue
Block a user