mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
_custom_C.LLMM1 and HIP_FORCE_DEV_KERNARG=1
This commit is contained in:
parent
f723e5ccb5
commit
1b4c8b4b3e
@ -141,7 +141,8 @@ FROM base as base-copy
|
|||||||
# Text Generation Inference base env
|
# Text Generation Inference base env
|
||||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
PORT=80
|
PORT=80 \
|
||||||
|
HIP_FORCE_DEV_KERNARG=1
|
||||||
|
|
||||||
# Copy builds artifacts from triton builder
|
# Copy builds artifacts from triton builder
|
||||||
COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
@ -67,6 +67,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# On ROCm systems the custom vLLM kernels are mandatory: fail at import time
# with an explicit message instead of failing later on first use.
if IS_ROCM_SYSTEM:
    try:
        from vllm import _custom_C
    except Exception as e:
        # Chain the original exception so its type and traceback stay visible
        # in the report instead of being flattened into the message string.
        raise ImportError(
            f"Could not load `vllm._custom_C`. Full error: {e}"
        ) from e
|
||||||
|
|
||||||
# Monkey patching
|
# Monkey patching
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -324,9 +329,64 @@ def warn_deprecate_bnb():
|
|||||||
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class FastLinearROCm(nn.Module):
    """Linear layer that uses the `_custom_C.LLMM1` kernel for single-row inputs.

    When `IS_ROCM_SYSTEM` is true and the input holds exactly one row (all
    leading dimensions multiply to 1), the matrix product is dispatched to
    `_custom_C.LLMM1` for the weight shapes accepted below. Every other case
    goes through `torch.nn.functional.linear`.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Wrap `weight` (and `bias`, when not None) as module parameters."""
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from `weights`.

        Reads `{prefix}.weight`, and `{prefix}.bias` when `bias` is truthy,
        through `weights.get_tensor`. `config` is accepted but not used here.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply the linear transform to `inp`.

        The result has the shape of `inp` with its last dimension replaced by
        `weight.shape[0]`, on both the custom-kernel and the fallback path.
        """
        weight = self.weight
        bias = self.bias

        # numel // last-dim == 1 means the input contains exactly one row.
        if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
            # Remember the caller's shape BEFORE flattening: the output must
            # be restored from it, not from the flattened 2-D view.
            inp_shape = inp.shape
            rows = inp.reshape(-1, inp_shape[-1])
            m, k = weight.shape[0], rows.shape[1]

            # The trailing LLMM1 argument is selected per weight shape.
            # NOTE(review): presumably a kernel tuning parameter (rows per
            # block); confirm against the vLLM ROCm `_custom_C` sources.
            if (k == 8192 and m in (1280, 7168)) or (k == 3584 and m == 8192):
                kernel_arg = 8
            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
                kernel_arg = 4
            else:
                kernel_arg = None

            if kernel_arg is None:
                out = F.linear(rows, weight)
            else:
                # Allocate the output buffer only when the kernel will fill
                # it, and on the input's device rather than a fixed 'cuda'.
                out = torch.empty(
                    rows.shape[0],
                    m,
                    dtype=inp.dtype,
                    device=inp.device,
                )
                _custom_C.LLMM1(weight, rows, out, kernel_arg)

            # Restore the leading dimensions of the original input.
            out = out.reshape(*inp_shape[:-1], m)
            if bias is not None:
                out = out + bias
            return out
        return F.linear(inp, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
def get_linear(weight, bias, quantize):
|
def get_linear(weight, bias, quantize):
|
||||||
if quantize is None:
|
if quantize is None:
|
||||||
|
if IS_ROCM_SYSTEM:
|
||||||
|
linear = FastLinearROCm(weight, bias)
|
||||||
|
else:
|
||||||
linear = FastLinear(weight, bias)
|
linear = FastLinear(weight, bias)
|
||||||
elif quantize == "eetq":
|
elif quantize == "eetq":
|
||||||
if HAS_EETQ:
|
if HAS_EETQ:
|
||||||
|
Loading…
Reference in New Issue
Block a user