_custom_C.LLMM1 and HIP_FORCE_DEV_KERNARG=1

This commit is contained in:
fxmarty 2024-04-19 11:50:01 +00:00
parent f723e5ccb5
commit 1b4c8b4b3e
2 changed files with 63 additions and 2 deletions

View File

@ -141,7 +141,8 @@ FROM base as base-copy
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80 \
HIP_FORCE_DEV_KERNARG=1
# Copy builds artifacts from triton builder # Copy builds artifacts from triton builder
COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

View File

@ -67,6 +67,11 @@ try:
except ImportError: except ImportError:
pass pass
if IS_ROCM_SYSTEM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
# Monkey patching # Monkey patching
@classmethod @classmethod
@ -324,9 +329,64 @@ def warn_deprecate_bnb():
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
) )
class FastLinearROCm(nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = nn.Parameter(weight)
if bias is not None:
self.bias = nn.Parameter(bias)
else:
self.bias = None
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(weight, bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
weight = self.weight
bias = self.bias
if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
batched = False
if inp.dim() == 3:
inp = inp.view(-1, inp.size(-1))
batched = True
m, k = weight.shape[0], inp.shape[1]
out = torch.empty(inp.shape[0],
weight.shape[0],
dtype=inp.dtype,
device='cuda')
if (k == 8192 and
(m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
_custom_C.LLMM1(weight, inp, out, 8)
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
if batched:
out = out.view(inp.shape[0], inp.shape[1], weight.shape[0])
if bias is not None:
out = out + bias
return out
return F.linear(inp, self.weight, self.bias)
def get_linear(weight, bias, quantize): def get_linear(weight, bias, quantize):
if quantize is None: if quantize is None:
if IS_ROCM_SYSTEM:
linear = FastLinearROCm(weight, bias)
else:
linear = FastLinear(weight, bias) linear = FastLinear(weight, bias)
elif quantize == "eetq": elif quantize == "eetq":
if HAS_EETQ: if HAS_EETQ: