Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Merge branch 'main' into xpu_fix
commit 0b3f71c6f6
@@ -2,8 +2,6 @@ from typing import Optional
 import torch
 from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 
 if SYSTEM == "rocm":
     try:
@@ -155,6 +153,8 @@ def get_linear(weight, bias, quantize):
             quant_type="nf4",
         )
     elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2Weight
+
         if not isinstance(weight, Exl2Weight):
             raise NotImplementedError(
                 f"The passed weight is not `exl2` compatible, loader needs to be updated."
@@ -165,6 +165,8 @@ def get_linear(weight, bias, quantize):
         linear = ExllamaQuantLinear(weight, bias)
 
     elif quantize == "gptq":
+        from text_generation_server.layers.gptq import GPTQWeight
+
         if not isinstance(weight, GPTQWeight):
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
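These hunks drop the module-level `Exl2Weight` and `GPTQWeight` imports and re-import them inside the matching `elif quantize == ...` branches of `get_linear`, so the module can still be imported on platforms (such as XPU) where those backends are unavailable. Below is a minimal, self-contained sketch of that branch-local deferred-import pattern; `get_weight_sketch` is a hypothetical name and `decimal`/`fractions` are only stand-ins for the optional quantization backends, not code from this repository.

import sys

def get_weight_sketch(quantize=None):
    # Each branch imports its own optional dependency, so importing the module
    # that defines this function never requires every backend to be installed.
    if quantize is None:
        return "unquantized"
    elif quantize == "gptq":
        from decimal import Decimal      # stand-in for the GPTQ backend import
        return Decimal(0)
    elif quantize == "exl2":
        from fractions import Fraction   # stand-in for the exl2 backend import
        return Fraction(0)
    raise NotImplementedError(f"Quantization {quantize!r} is not supported")

print("decimal" in sys.modules)      # typically False before the branch runs
print(get_weight_sketch("gptq"))     # triggers the deferred import
print("decimal" in sys.modules)      # True once the gptq branch has executed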
@@ -21,7 +21,6 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "xpu":
@@ -198,6 +197,8 @@ def _load_gqa(config, prefix: str, weights):
     v_stop = v_offset + (rank + 1) * kv_block_size
 
     if config.quantize in ["gptq", "awq"]:
+        from text_generation_server.layers.gptq import GPTQWeight
+
         try:
             qweight_slice = weights._get_slice(f"{prefix}.qweight")
             q_qweight = qweight_slice[:, q_start:q_stop]
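The `_load_gqa` hunk also shows the tensor-parallel slicing that surrounds the deferred import: each rank takes one contiguous block of columns out of the packed `qweight` via bounds like `q_start:q_stop`. A small, self-contained sketch of that per-rank column slicing follows; the world size and tensor shape are made up for illustration, whereas the real bounds come from the process group, head sizes, and offsets.

import torch

# Hypothetical layout: 4-way tensor parallelism, this process is rank 1.
world_size, rank = 4, 1
hidden, packed_cols = 128, 64            # made-up packed qweight shape

qweight = torch.arange(hidden * packed_cols).reshape(hidden, packed_cols)

# Each rank owns one contiguous block of columns, mirroring
# q_qweight = qweight_slice[:, q_start:q_stop] in the hunk above.
block_size = packed_cols // world_size
q_start = rank * block_size
q_stop = (rank + 1) * block_size
q_qweight = qweight[:, q_start:q_stop]

print(q_qweight.shape)                   # torch.Size([128, 16])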
@@ -5,7 +5,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -39,6 +38,8 @@ def load_multi_mqa(
 def _load_multi_mqa_gptq(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    from text_generation_server.layers.gptq import GPTQWeight
+
     if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
         world_size = weights.process_group.size()
         rank = weights.process_group.rank()
@@ -7,8 +7,6 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_once
 
 
@@ -221,6 +219,8 @@ class Weights:
 
     def get_weights_col(self, prefix: str, quantize: str):
         if quantize == "exl2":
+            from text_generation_server.layers.exl2 import Exl2Weight
+
             try:
                 q_weight = self.get_tensor(f"{prefix}.q_weight")
             except RuntimeError:
@@ -247,6 +247,8 @@ class Weights:
         if quantize == "exl2":
             raise ValueError("get_multi_weights_col is not supported for exl2")
         elif quantize in ["gptq", "awq"]:
+            from text_generation_server.layers.gptq import GPTQWeight
+
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
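A note on the cost of the pattern applied across all four files: after the first call, a function-local import resolves through the `sys.modules` cache, so repeating it on every call to `get_linear` or `get_weights_col` only costs a dictionary lookup. A small, self-contained check of that behaviour, using `json` as a stand-in module:

import sys
import timeit

def lazy_import():
    # After the first call this is a sys.modules lookup, not a fresh import.
    import json
    return json

lazy_import()
assert "json" in sys.modules
print(timeit.timeit(lazy_import, number=100_000))  # fast: each call is a cache lookup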