Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 13:02:12 +00:00)

commit 67d687609b (parent 67a46b7361)

    cleanup
@@ -21,17 +21,13 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 
-from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
-from custom_kernels.exllama import prepare_buffers, set_tuning_params
-
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
     if config.quantize == "gptq":
-        layer = _load_multi_mqa_gptq(
+        return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
         )
-        return layer
     else:
         return _load_multi_mqa(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
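Read together, the new side of this hunk simplifies the dispatch so each branch returns its loader's result directly. A sketch of load_multi_mqa as it reads after the cleanup, reconstructed from the hunk above; the final closing parenthesis falls outside the visible hunk and is assumed:

def load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    if config.quantize == "gptq":
        # Return the GPTQ loader's result directly instead of binding it to a
        # temporary `layer` variable first.
        return _load_multi_mqa_gptq(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
        )
    else:
        return _load_multi_mqa(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
        )  # closing parenthesis assumed; not shown in the hunk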
@@ -190,21 +186,18 @@ def load_col(config, prefix: str, weights, bias: bool):
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    quantize = config.quantize
-
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_multi_weights_row(prefix, quantize=quantize)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
-
     return TensorParallelRowLinear(
-        get_linear(weight, bias, quantize), process_group=weights.process_group
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )
 
 
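Assembled from the new side of this hunk, load_row now passes config.quantize straight through instead of keeping a local quantize alias. A reconstruction from the hunk contents only, not the full file:

def load_row(config, prefix: str, weights, bias: bool):
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
    else:
        # The local `quantize = config.quantize` alias is gone; pass it directly.
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

    if bias and weights.process_group.rank() == 0:
        # Rank is only on the first rank process
        bias = weights.get_tensor(f"{prefix}.bias")
    else:
        bias = None
    return TensorParallelRowLinear(
        get_linear(weight, bias, config.quantize), process_group=weights.process_group
    )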
@@ -272,8 +265,6 @@ class FlashMQAttention(torch.nn.Module):
         # Expand from 1 to num_heads
         key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
-        # output
-        attn_output = torch.empty_like(query)
         # flash attention
         flash_attn_cuda.fwd(
             query,
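As a side note on the expand call kept above: multi-query attention stores a single shared key/value head, and torch.Tensor.expand broadcasts it across all query heads as a view without copying. A toy illustration with assumed shapes (the actual tensor layout in this file is not fully visible in the hunk):

import torch

# Toy shapes for illustration only; the real layout comes from the model code.
num_tokens, num_heads, head_size = 4, 8, 64

# One shared K/V head per token, stacked as (tokens, 2, 1, head_size).
key_value = torch.randn(num_tokens, 2, 1, head_size)

# expand() broadcasts the singleton head dimension to num_heads as a view.
expanded = key_value.expand(-1, 2, num_heads, head_size)

assert expanded.shape == (num_tokens, 2, num_heads, head_size)
# No data was copied: the expanded view shares storage with the original tensor.
assert expanded.data_ptr() == key_value.data_ptr()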
@@ -6,8 +6,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 import torch
 
-from loguru import logger
-
 try:
     from custom_kernels.exllama import make_q4, q4_matmul
 except Exception as e:
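The guarded import kept above treats the exllama CUDA kernels as optional. A minimal sketch of that pattern, assuming a fallback flag and except-branch handling that are invented here, since the hunk does not show the body of the except branch:

# Sketch only: the flag name and except-branch handling are assumptions,
# not taken from the repository.
try:
    from custom_kernels.exllama import make_q4, q4_matmul

    HAS_EXLLAMA_KERNELS = True
except Exception as import_error:
    # Fall back gracefully when the custom kernels are not built or installed.
    HAS_EXLLAMA_KERNELS = False
    make_q4 = None
    q4_matmul = None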
@@ -18,7 +18,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
 
 from typing import Optional
-from loguru import logger
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):