Merge branch 'main' into xpu_fix

commit 0b3f71c6f6
Author: Wang, Yi A
Date:   2024-06-03 01:44:38 -07:00

4 changed files with 12 additions and 6 deletions
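The net change across all four files is the same: the module-level imports of `Exl2Weight` and `GPTQWeight` move into the quantize-specific branches that actually use them, so the modules themselves can be imported on systems (such as xpu) where those backends are unavailable. Below is a minimal, self-contained sketch of that deferred-import pattern; the function body and return values are placeholders, not the real `get_linear` implementation:

```python
# Sketch of the deferred-import pattern applied in this merge
# (placeholder logic, not the actual text_generation_server code).

def get_linear(weight, quantize):
    if quantize == "gptq":
        # Deferred import: evaluated only when a gptq layer is requested.
        # After the first call, Python serves the module from sys.modules,
        # so the per-call overhead is a dictionary lookup.
        from text_generation_server.layers.gptq import GPTQWeight

        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError("The passed weight is not `gptq` compatible")
        return "gptq linear"  # placeholder for the real quantized layer
    # The default path never touches the optional backend.
    return "plain linear"  # placeholder for the unquantized layer


if __name__ == "__main__":
    # Succeeds even on a machine without the gptq kernels installed,
    # because the import above is never executed on this path.
    print(get_linear(object(), None))
```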


@@ -2,8 +2,6 @@ from typing import Optional
 import torch
 from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 
 if SYSTEM == "rocm":
     try:
@@ -155,6 +153,8 @@ def get_linear(weight, bias, quantize):
             quant_type="nf4",
         )
     elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2Weight
+
         if not isinstance(weight, Exl2Weight):
             raise NotImplementedError(
                 f"The passed weight is not `exl2` compatible, loader needs to be updated."
@@ -165,6 +165,8 @@ def get_linear(weight, bias, quantize):
         linear = ExllamaQuantLinear(weight, bias)
 
     elif quantize == "gptq":
+        from text_generation_server.layers.gptq import GPTQWeight
+
         if not isinstance(weight, GPTQWeight):
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."


@@ -21,7 +21,6 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "xpu":
@@ -198,6 +197,8 @@ def _load_gqa(config, prefix: str, weights):
     v_stop = v_offset + (rank + 1) * kv_block_size
 
     if config.quantize in ["gptq", "awq"]:
+        from text_generation_server.layers.gptq import GPTQWeight
+
         try:
             qweight_slice = weights._get_slice(f"{prefix}.qweight")
             q_qweight = qweight_slice[:, q_start:q_stop]


@@ -5,7 +5,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -39,6 +38,8 @@ def load_multi_mqa(
 def _load_multi_mqa_gptq(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    from text_generation_server.layers.gptq import GPTQWeight
+
     if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
         world_size = weights.process_group.size()
         rank = weights.process_group.rank()


@@ -7,8 +7,6 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_once
@@ -221,6 +219,8 @@ class Weights:
     def get_weights_col(self, prefix: str, quantize: str):
         if quantize == "exl2":
+            from text_generation_server.layers.exl2 import Exl2Weight
+
             try:
                 q_weight = self.get_tensor(f"{prefix}.q_weight")
             except RuntimeError:
@@ -247,6 +247,8 @@ class Weights:
         if quantize == "exl2":
             raise ValueError("get_multi_weights_col is not supported for exl2")
         elif quantize in ["gptq", "awq"]:
+            from text_generation_server.layers.gptq import GPTQWeight
+
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
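
A note on the design choice: because Python caches imported modules in `sys.modules`, these function-local imports are effectively free after the first call. The trade-off is that a missing backend now surfaces as an `ImportError` at first use rather than at module import time, which is precisely what lets these four files load on xpu builds that ship without the exl2/gptq kernels.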