Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-24 04:22:10 +00:00)
Switch to punica-sgmv kernel from the Hub (#3236)
* Switch to punica-sgmv kernel from the Hub

  This also switches (temporarily) to the tgi-nix/kernel-builder merge
  branch, bumping up to CUDA 12.8 (same as non-Nix Torch).

* nix: client depends on aiohttp

  This probably worked before the nixpkgs bump because a dependency
  propagated aiohttp.
This commit is contained in:
parent 43b1b07fb9
commit e32528792c
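Note: the core of this change is that the Punica SGMV kernels are no longer compiled from source at image-build time; they are loaded as a prebuilt kernel from the Hugging Face Hub. The loading pattern introduced by the Python hunks below is:

# Loading pattern from this commit (see the adapters/lora.py and
# layers/lora.py hunks below). On non-CUDA systems the module stays None
# and callers fall back to the plain PyTorch loop path.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    punica_sgmv = load_kernel(
        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
    )
else:
    punica_sgmv = None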
Dockerfile

@@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile
 # Build specific version of transformers
 RUN . .venv/bin/activate && make build-awq
 
-# Build Lorax Punica kernels
-FROM kernel-builder AS lorax-punica-builder
-WORKDIR /usr/src
-COPY server/Makefile-lorax-punica Makefile
-# Build specific version of transformers
-RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
-
 # Build Transformers CUDA kernels
 FROM kernel-builder AS custom-kernels-builder
 WORKDIR /usr/src

@@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from awq kernels builder
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
-# Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
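With the lorax-punica builder stage and its COPY removed, the image no longer ships a compiled punica wheel; the kernel must resolve from the Hub (or a local cache) at runtime. A quick smoke test inside the built image, assuming the `kernels` package's `get_kernel` API and Hub (or cache) access:

# Hypothetical smoke test: get_kernel() downloads (or reuses a cached copy
# of) the build variant matching the local Torch/CUDA/platform combination.
from kernels import get_kernel

punica_sgmv = get_kernel("kernels-community/punica-sgmv")
print(punica_sgmv)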
flake.lock (16 changed lines)

@@ -718,16 +718,16 @@
     },
     "nixpkgs_6": {
       "locked": {
-        "lastModified": 1737453259,
-        "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
+        "lastModified": 1746711195,
+        "narHash": "sha256-bSpM2ySq12PBOVN7jZdzXsc99iRoYOyolh5wz43+CjQ=",
         "owner": "danieldk",
         "repo": "nixpkgs",
-        "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
+        "rev": "6b7a66b06ccb09ac95872ac6ddf952e0660672ab",
         "type": "github"
       },
       "original": {
         "owner": "danieldk",
-        "ref": "outlines-v0.1.4-tgi",
+        "ref": "kernel-builder-cuda-12.9.0",
         "repo": "nixpkgs",
         "type": "github"
       }

@@ -978,16 +978,16 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1746795305,
-        "narHash": "sha256-4fpUT4j4w0NDKF22KvG7iGmwQTBPM5SrPEqt+N3fqF0=",
+        "lastModified": 1747733488,
+        "narHash": "sha256-LYov4H9zvqXXlFKdytcVcDioH416c+LWfyw/HWta0qw=",
         "owner": "huggingface",
         "repo": "text-generation-inference-nix",
-        "rev": "359cd25f31f0f2ad2cadfbf4e180780a7a06e3c5",
+        "rev": "61c730990efa58e64c652bf15253aae47dd0f7dd",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "torch-2.7",
+        "ref": "merge-with-kernel-builder",
         "repo": "text-generation-inference-nix",
         "type": "github"
       }
flake.nix

@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix/merge-with-kernel-builder";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
nix/client.nix

@@ -1,6 +1,7 @@
 {
   buildPythonPackage,
   poetry-core,
+  aiohttp,
   huggingface-hub,
   pydantic,
 }:

@@ -15,6 +16,7 @@ buildPythonPackage {
   build-system = [ poetry-core ];
 
   dependencies = [
+    aiohttp
     huggingface-hub
     pydantic
   ];
nix/server.nix

@@ -31,7 +31,7 @@
   peft,
   pillow,
   prometheus-client,
-  punica-kernels,
+  punica-sgmv,
   py-cpuinfo,
   pydantic,
   quantization,

@@ -107,7 +107,7 @@ buildPythonPackage {
     peft
     pillow
     prometheus-client
-    punica-kernels
+    punica-sgmv
     py-cpuinfo
     pydantic
     quantization
server/Makefile

@@ -3,7 +3,6 @@ include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
 include Makefile-selective-scan
-include Makefile-lorax-punica
 include Makefile-exllamav2
 include Makefile-flashinfer
 
server/Makefile-lorax-punica (deleted)

@@ -1,12 +0,0 @@
-lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
-
-build-lorax-punica:
-	if [ ! -d 'lorax-punica' ]; then \
-		git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
-	fi
-	cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
-	cd lorax-punica && git submodule update --init --recursive
-	cd lorax-punica/server/punica_kernels && python setup.py build
-
-install-lorax-punica: build-lorax-punica
-	cd lorax-punica/server/punica_kernels && python setup.py install
server/kernels.lock

@@ -163,6 +163,64 @@
       }
     }
   },
+  {
+    "repo_id": "kernels-community/punica-sgmv",
+    "sha": "9ae1b469cb39c33df9ddd61657c6359acc423714",
+    "variants": {
+      "torch26-cxx11-cu118-x86_64-linux": {
+        "hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx11-cu124-x86_64-linux": {
+        "hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx11-cu126-aarch64-linux": {
+        "hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx11-cu126-x86_64-linux": {
+        "hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu118-x86_64-linux": {
+        "hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu124-x86_64-linux": {
+        "hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu126-aarch64-linux": {
+        "hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch26-cxx98-cu126-x86_64-linux": {
+        "hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu118-x86_64-linux": {
+        "hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu126-aarch64-linux": {
+        "hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu126-x86_64-linux": {
+        "hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu128-aarch64-linux": {
+        "hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a",
+        "hash_type": "git_lfs_concat"
+      },
+      "torch27-cxx11-cu128-x86_64-linux": {
+        "hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1",
+        "hash_type": "git_lfs_concat"
+      }
+    }
+  },
   {
     "repo_id": "kernels-community/quantization",
     "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",
server/pyproject.toml

@@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta"
 [tool.kernels.dependencies]
 "kernels-community/paged-attention" = ">=0.0.2"
 "kernels-community/moe" = ">=0.1.1"
+"kernels-community/punica-sgmv" = ">=0.0.1"
 "kernels-community/quantization" = ">=0.0.3"
 "kernels-community/quantization-eetq" = ">=0.0.1"
 "kernels-community/rotary" = ">=0.0.1"
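The kernels.lock entry above pins `kernels-community/punica-sgmv` to one revision with one artifact per build variant, while the pyproject.toml constraint declares the dependency range; together they make kernel resolution reproducible. Variant names encode the Torch version, C++ ABI, CUDA version, CPU architecture, and OS. A hypothetical helper (not part of the repo) to illustrate:

# Hypothetical: decompose a kernels.lock variant name into its parts.
def parse_variant(name: str) -> dict:
    torch_tag, cxx_abi, cuda_tag, arch, os_name = name.split("-")
    return {
        "torch": torch_tag,  # e.g. "torch27" -> Torch 2.7
        "cxx_abi": cxx_abi,  # "cxx11" or "cxx98"
        "cuda": cuda_tag,    # e.g. "cu128" -> CUDA 12.8
        "arch": arch,        # "x86_64" or "aarch64"
        "os": os_name,       # "linux"
    }

print(parse_variant("torch27-cxx11-cu128-x86_64-linux"))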
server/text_generation_server/adapters/lora.py

@@ -13,21 +13,20 @@ from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
 
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
     AdapterWeights,
     BatchAdapterWeights,
 )
-from text_generation_server.utils.sgmv import (
-    BGMV_MAX_RANK,
-    MAX_RANK_CUSTOM,
-    get_tmp_tensors,
-    orient_for_rank,
-    pad_rank,
-    use_cutlass_shrink,
-    has_sgmv,
-)
+
+if SYSTEM == "cuda":
+    punica_sgmv = load_kernel(
+        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
+    )
+else:
+    punica_sgmv = None
 
 
 def get_start_stop_idxs_for_rank(offset, size, rank, world_size):

@@ -129,11 +128,13 @@ class LoraWeights(AdapterWeights):
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
 
-        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+        self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
 
         # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        weights_a = [
+            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
+        ]
         self._weights_a = torch.stack(weights_a)
 
         # [num_layers, r, hidden_size]

@@ -244,8 +245,12 @@ class LoraWeights(AdapterWeights):
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
-        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+        lora_a_list = [
+            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
+        ]
+        lora_b_list = [
+            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
+        ]
 
         if lora_a_list:
             # update rank if it was padded

@@ -293,7 +298,7 @@ class BatchLoraWeights(BatchAdapterWeights):
 
     def can_vectorize(self, pg: ProcessGroup) -> bool:
         return all(
-            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+            rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM
             for rank_data in self.rank_data.values()
         )
 

@@ -337,8 +342,8 @@ class BatchLoraWeights(BatchAdapterWeights):
         )
 
         use_sgmv = False
-        if prefill or max_rank > BGMV_MAX_RANK:
-            if has_sgmv():
+        if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
+            if punica_sgmv is not None:
                 use_sgmv = True
                 lora_a_ptr = torch.tensor(
                     [

@@ -425,7 +430,7 @@ class BatchLoraWeights(BatchAdapterWeights):
 
         if use_sgmv:
             lora_a_ptr_indices = lora_a_ptr[indices]
-            tmp_shrink, tmp_expand = get_tmp_tensors(
+            tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(
                 lora_a_ptr_indices.size(0), rank, device
             )
             segment_starts = meta.adapter_segments[indices]
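The `pad_rank` calls above pad each adapter's LoRA rank so the SGMV kernel's block constraints hold. Assuming the Hub kernel keeps the semantics of the deleted sgmv.py helper reproduced at the end of this diff (MIN_SGMV_RANK = 8, SGMV_BLOCK_SIZE = 16), the target-rank arithmetic works out as:

# Target-rank arithmetic behind pad_rank (constants from the deleted
# sgmv.py below); tensor parallelism divides the effective rank across
# ranks, so the minimum is scaled by world_size.
MIN_SGMV_RANK = 8
SGMV_BLOCK_SIZE = 16

def target_rank(current_rank: int, world_size: int) -> int:
    min_rank = MIN_SGMV_RANK * world_size
    if current_rank <= min_rank:
        return min_rank
    return (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE

assert target_rank(4, world_size=1) == 8    # small ranks pad to the minimum
assert target_rank(24, world_size=1) == 32  # larger ranks round up to a block
assert target_rank(4, world_size=2) == 16   # the minimum scales with world size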
server/text_generation_server/layers/lora.py

@@ -5,14 +5,16 @@ import torch.distributed
 from torch import nn
 from torch.distributed import ProcessGroup
 
-from text_generation_server.utils.sgmv import (
-    add_lora_a_bgmv,
-    add_lora_b_bgmv,
-    has_sgmv,
-    lora_a_sgmv_cutlass,
-    lora_b_sgmv_cutlass,
-    orient_for_rank,
-)
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
+
+if SYSTEM == "cuda":
+    punica_sgmv = load_kernel(
+        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
+    )
+else:
+    punica_sgmv = None
 
 
 if TYPE_CHECKING:
     from text_generation_server.adapters import AdapterBatchData

@@ -41,7 +43,11 @@ class LoraLinear(nn.Module):
             return result
         data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
 
-        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+        if (
+            punica_sgmv is not None
+            and data is not None
+            and data.can_vectorize(self.process_group)
+        ):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.
             # The 'result' tensor represents the full output, which can vary in size based on
             # the layer type (e.g., attention vs. feed-forward layers). We define the current

@@ -68,7 +74,7 @@ class LoraLinear(nn.Module):
 
                 if data.use_sgmv:
                     # Use SGMV for prefill
-                    v = lora_a_sgmv_cutlass(
+                    v = punica_sgmv.lora_a_sgmv_cutlass(
                         input,
                         rank_segments.tmp_shrink,
                         lora_a_ptr,

@@ -81,7 +87,7 @@ class LoraLinear(nn.Module):
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
 
-                    lora_b_sgmv_cutlass(
+                    punica_sgmv.lora_b_sgmv_cutlass(
                         proj,
                         v,
                         rank_segments.tmp_expand,

@@ -96,7 +102,7 @@ class LoraLinear(nn.Module):
                         (input.size(0), r), dtype=input.dtype, device=input.device
                     )
                     # TODO: error with [-1, 0], but not [0, -1]
-                    add_lora_a_bgmv(
+                    punica_sgmv.add_lora_a_bgmv(
                         v,
                         input,
                         lora_a_ptr,

@@ -107,7 +113,7 @@ class LoraLinear(nn.Module):
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
 
-                    add_lora_b_bgmv(
+                    punica_sgmv.add_lora_b_bgmv(
                         proj,
                         v,
                         lora_b_ptr,

@@ -142,7 +148,7 @@ class LoraLinear(nn.Module):
             lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
             lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
 
-            lora_a = orient_for_rank(lora_a, lora_b.size(0))
+            lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0))
 
             a_out = input @ lora_a
             if self.process_group.size() > 1:
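The forward path above chooses between two kernels: SGMV (segmented GEMV, one matmul per run of tokens sharing an adapter, used for prefill or when the max rank exceeds BGMV_MAX_RANK) and BGMV (batched GEMV with a per-token adapter index, used for small-rank decode). A toy distillation of the decision rule (not the exact control flow, which also depends on `data.can_vectorize`):

# Toy decision rule distilled from the hunks above.
BGMV_MAX_RANK = 64  # constant from the punica kernels (see the deleted sgmv.py)

def pick_kernel(prefill: bool, max_rank: int, kernel_available: bool) -> str:
    if not kernel_available:
        return "loop"  # plain PyTorch fallback when punica_sgmv is None
    if prefill or max_rank > BGMV_MAX_RANK:
        return "sgmv"
    return "bgmv"

assert pick_kernel(prefill=True, max_rank=16, kernel_available=True) == "sgmv"
assert pick_kernel(prefill=False, max_rank=16, kernel_available=True) == "bgmv"
assert pick_kernel(prefill=False, max_rank=16, kernel_available=False) == "loop"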
server/text_generation_server/utils/sgmv.py (deleted)

@@ -1,252 +0,0 @@
-# Origin:   https://github.com/predibase/lorax
-# Path:     lorax/server/lorax_server/utils/sgmv.py
-# License:  Apache License Version 2.0, January 2004
-
-import os
-import warnings
-from functools import lru_cache
-from typing import List, Tuple
-
-import torch
-import torch.nn.functional as F
-
-try:
-    import punica_kernels as _kernels
-
-    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
-except ImportError:
-    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
-    _kernels = None
-    HAS_SGMV = False
-
-
-MIN_SGMV_RANK = 8
-MIN_RANK_CUSTOM = 16
-MAX_RANK_CUSTOM = 128
-SGMV_BLOCK_SIZE = 16
-BGMV_MAX_RANK = 64
-
-
-def has_sgmv() -> bool:
-    return HAS_SGMV
-
-
-def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
-    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
-    if not has_sgmv():
-        return t
-
-    # tensor parallelism will result in effective rank being divided by world_size,
-    # so we need to scale the min rank to offset that effect
-    min_rank = MIN_SGMV_RANK * world_size
-
-    # if we're at or below the min rank, pad up to the min rank
-    # otherwise, pad to the nearest multiple of the block size
-    current_rank = t.size(dim)
-    target_rank = (
-        min_rank
-        if current_rank <= min_rank
-        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
-    )
-    if current_rank == target_rank:
-        return t
-
-    pad_size = target_rank - current_rank
-
-    # see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
-    pad = [0, 0] * t.dim()
-    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
-    pad = tuple(pad)
-
-    return F.pad(t, pad, mode="constant", value=0.0)
-
-
-def use_cutlass_shrink(lora_rank: int) -> bool:
-    return lora_rank < MIN_RANK_CUSTOM
-
-
-def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
-    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
-        return t.transpose(0, 1)
-    return t
-
-
-# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
-def add_lora_sgmv_cutlass(
-    y: torch.Tensor,
-    x: torch.Tensor,
-    wa_ptr: torch.Tensor,
-    wb_ptr: torch.Tensor,
-    s_start: torch.Tensor,
-    s_end: torch.Tensor,
-    layer_idx: int,
-    lora_rank: int,
-):
-    """
-    Semantics:
-        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
-
-    Args:
-        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-        x: Shape: `[B, H1]`. Input vectors.
-        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.
-            Weight matrix shape: `[num_layers, R, H1]`.
-        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.
-            Weight matrix shape: `[num_layers, R, H2]`.
-        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
-        s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
-        layer_idx: Layer index of the weight matrices.
-    """
-    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
-        # Custom SGMV shrink only supports rank 16, 32, 64, 128
-        _add_lora_sgmv_cutlass_legacy(
-            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
-        )
-        return
-
-    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
-    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
-    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
-    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
-    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
-
-
-def _add_lora_sgmv_cutlass_legacy(
-    y: torch.Tensor,
-    x: torch.Tensor,
-    wa_ptr: torch.Tensor,
-    wb_ptr: torch.Tensor,
-    s_start: torch.IntTensor,
-    s_end: torch.IntTensor,
-    layer_idx: int,
-    lora_rank: int,
-):
-    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
-    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
-    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
-    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
-
-
-@lru_cache(maxsize=1)
-def get_tmp_tensor(device: torch.device) -> torch.Tensor:
-    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
-
-
-@lru_cache(maxsize=32)
-def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
-    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
-    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
-
-
-def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
-    return torch.empty((size,), dtype=torch.uint8, device=device)
-
-
-def get_tmp_expand_size(size: int) -> int:
-    return _kernels.sgmv_cutlass_tmp_size(size)
-
-
-def get_tmp_tensors(
-    nsegments: int, lora_rank: int, device: torch.device
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
-    has_sgmv_available = has_sgmv()
-
-    if use_cutlass:
-        tmp = get_tmp_tensor_for_size(nsegments, device)
-        return tmp, tmp
-    elif has_sgmv_available:
-        return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
-    else:
-        tmp = get_tmp_tensor_for_size(nsegments, device)
-        return tmp, tmp
-
-
-def lora_a_sgmv_cutlass(
-    x: torch.Tensor,
-    tmp: torch.Tensor,
-    wa_ptr: torch.Tensor,
-    s_start: torch.IntTensor,
-    s_end: torch.IntTensor,
-    layer_idx: int,
-    lora_rank: int,
-) -> torch.Tensor:
-    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
-    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
-        _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
-    else:
-        _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
-    return v
-
-
-def lora_b_sgmv_cutlass(
-    y: torch.Tensor,
-    v: torch.Tensor,
-    tmp: torch.Tensor,
-    wb_ptr: torch.Tensor,
-    s_start: torch.IntTensor,
-    s_end: torch.IntTensor,
-    layer_idx: int,
-):
-    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
-
-
-"""
-Semantics:
-    y[i] += (
-        x[i].unsqueeze(0)
-        @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-        @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-        * scale
-    ).squeeze(0)
-
-Args:
-    y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-    v: Shape: `[B, R]`. Temporary vector.
-    x: Shape: `[B, H1]`. Input vectors.
-    wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
-    wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
-    indicies: Shape: `[B]`. Indices of the LoRA weights.
-    layer_idx: Layer index of LoRA weights.
-    scale: Scaling factor.
-"""
-
-
-def add_lora_a_bgmv(
-    v: torch.Tensor,
-    x: torch.Tensor,
-    wa_T_all: torch.Tensor,
-    indicies: torch.LongTensor,
-    layer_idx: int,
-):
-    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
-
-
-def add_lora_b_bgmv(
-    y: torch.Tensor,
-    v: torch.Tensor,
-    wb_T_all: torch.Tensor,
-    indicies: torch.LongTensor,
-    layer_idx: int,
-):
-    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
-
-
-def segmented_matmul(
-    y: torch.Tensor,
-    x: torch.Tensor,
-    w: List[torch.Tensor],
-    b: List[torch.Tensor],
-    s_start: torch.IntTensor,
-    s_end: torch.IntTensor,
-):
-    for i in range(len(w)):
-        if s_end[i] - s_start[i] <= 0:
-            continue
-
-        xi = x[s_start[i] : s_end[i]]
-        wi = w[i]
-        bi = b[i]
-        y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
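The deleted file's docstring states the SGMV semantics exactly: y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]). For reference, a minimal pure-PyTorch equivalent of those semantics (a sketch only; per-segment weight tensors stand in for the kernel's device pointers):

# Pure-PyTorch reference for the SGMV semantics documented above.
from typing import List

import torch

def sgmv_reference(
    y: torch.Tensor,         # [B, H2], updated in place
    x: torch.Tensor,         # [B, H1]
    wa: List[torch.Tensor],  # per-segment LoRA A, each [R, H1]
    wb: List[torch.Tensor],  # per-segment LoRA B, each [R, H2]
    s_start: torch.Tensor,   # [S] segment start indices
    s_end: torch.Tensor,     # [S] segment end indices
) -> None:
    for i in range(len(wa)):
        lo, hi = int(s_start[i]), int(s_end[i])
        if hi <= lo:
            continue
        # shrink to rank R, then expand: y[lo:hi] += x[lo:hi] @ A_i.T @ B_i
        y[lo:hi] += x[lo:hi] @ wa[i].T @ wb[i]

B, H1, H2, R = 6, 8, 10, 4
x = torch.randn(B, H1)
y = torch.zeros(B, H2)
wa = [torch.randn(R, H1), torch.randn(R, H1)]
wb = [torch.randn(R, H2), torch.randn(R, H2)]
sgmv_reference(y, x, wa, wb, torch.tensor([0, 3]), torch.tensor([3, 6]))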