From e32528792cc7ccfdc5dd4b10fecedeb907422261 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 21 May 2025 15:44:15 +0200 Subject: [PATCH] Switch to punica-sgmv kernel from the Hub (#3236) * Switch to punica-sgmv kernel from the Hub This also switches (temporarily) to the tgi-nix/kernel-builder merge branch, bumping up to CUDA 12.8 (same as non-Nix Torch). * nix: client depends on aiohttp This probably worked before the nixpkgs bump because a dependency propagated aiohttp. --- Dockerfile | 9 - flake.lock | 16 +- flake.nix | 2 +- nix/client.nix | 2 + nix/server.nix | 4 +- server/Makefile | 1 - server/Makefile-lorax-punica | 12 - server/kernels.lock | 58 ++++ server/pyproject.toml | 1 + .../text_generation_server/adapters/lora.py | 41 +-- server/text_generation_server/layers/lora.py | 34 ++- server/text_generation_server/utils/sgmv.py | 252 ------------------ 12 files changed, 115 insertions(+), 317 deletions(-) delete mode 100644 server/Makefile-lorax-punica delete mode 100644 server/text_generation_server/utils/sgmv.py diff --git a/Dockerfile b/Dockerfile index e72d9b70..869596d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile # Build specific version of transformers RUN . .venv/bin/activate && make build-awq -# Build Lorax Punica kernels -FROM kernel-builder AS lorax-punica-builder -WORKDIR /usr/src -COPY server/Makefile-lorax-punica Makefile -# Build specific version of transformers -RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica - # Build Transformers CUDA kernels FROM kernel-builder AS custom-kernels-builder WORKDIR /usr/src @@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages # Copy build artifacts from awq kernels builder COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages -# Copy build artifacts from lorax punica kernels builder -COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages diff --git a/flake.lock b/flake.lock index 4540a736..2c6e8063 100644 --- a/flake.lock +++ b/flake.lock @@ -718,16 +718,16 @@ }, "nixpkgs_6": { "locked": { - "lastModified": 1737453259, - "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", + "lastModified": 1746711195, + "narHash": "sha256-bSpM2ySq12PBOVN7jZdzXsc99iRoYOyolh5wz43+CjQ=", "owner": "danieldk", "repo": "nixpkgs", - "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", + "rev": "6b7a66b06ccb09ac95872ac6ddf952e0660672ab", "type": "github" }, "original": { "owner": "danieldk", - "ref": "outlines-v0.1.4-tgi", + "ref": "kernel-builder-cuda-12.9.0", "repo": "nixpkgs", "type": "github" } @@ -978,16 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1746795305, - "narHash": "sha256-4fpUT4j4w0NDKF22KvG7iGmwQTBPM5SrPEqt+N3fqF0=", + "lastModified": 1747733488, + "narHash": "sha256-LYov4H9zvqXXlFKdytcVcDioH416c+LWfyw/HWta0qw=", "owner": "huggingface", "repo": 
"text-generation-inference-nix", - "rev": "359cd25f31f0f2ad2cadfbf4e180780a7a06e3c5", + "rev": "61c730990efa58e64c652bf15253aae47dd0f7dd", "type": "github" }, "original": { "owner": "huggingface", - "ref": "torch-2.7", + "ref": "merge-with-kernel-builder", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index e405b60d..13f40054 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/merge-with-kernel-builder"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/nix/client.nix b/nix/client.nix index 351fd08a..be8e2fc7 100644 --- a/nix/client.nix +++ b/nix/client.nix @@ -1,6 +1,7 @@ { buildPythonPackage, poetry-core, + aiohttp, huggingface-hub, pydantic, }: @@ -15,6 +16,7 @@ buildPythonPackage { build-system = [ poetry-core ]; dependencies = [ + aiohttp huggingface-hub pydantic ]; diff --git a/nix/server.nix b/nix/server.nix index e6493531..a45f39cc 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -31,7 +31,7 @@ peft, pillow, prometheus-client, - punica-kernels, + punica-sgmv, py-cpuinfo, pydantic, quantization, @@ -107,7 +107,7 @@ buildPythonPackage { peft pillow prometheus-client - punica-kernels + punica-sgmv py-cpuinfo pydantic quantization diff --git a/server/Makefile b/server/Makefile index f4855392..a95a4ae5 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,7 +3,6 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-selective-scan -include Makefile-lorax-punica include Makefile-exllamav2 include Makefile-flashinfer diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica deleted file mode 100644 index 72f06f76..00000000 --- a/server/Makefile-lorax-punica +++ /dev/null @@ -1,12 +0,0 @@ -lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc - -build-lorax-punica: - if [ ! 
-d 'lorax-punica' ]; then \ - git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ - fi - cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) - cd lorax-punica && git submodule update --init --recursive - cd lorax-punica/server/punica_kernels && python setup.py build - -install-lorax-punica: build-lorax-punica - cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/server/kernels.lock b/server/kernels.lock index 1bce05c6..a06cbff3 100644 --- a/server/kernels.lock +++ b/server/kernels.lock @@ -163,6 +163,64 @@ } } }, + { + "repo_id": "kernels-community/punica-sgmv", + "sha": "9ae1b469cb39c33df9ddd61657c6359acc423714", + "variants": { + "torch26-cxx11-cu118-x86_64-linux": { + "hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu124-x86_64-linux": { + "hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-x86_64-linux": { + "hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu118-x86_64-linux": { + "hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu124-x86_64-linux": { + "hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-x86_64-linux": { + "hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1", + "hash_type": "git_lfs_concat" + } + } + }, { "repo_id": "kernels-community/quantization", "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00", diff --git a/server/pyproject.toml b/server/pyproject.toml index 5489b19d..7f2addb6 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta" [tool.kernels.dependencies] "kernels-community/paged-attention" = ">=0.0.2" "kernels-community/moe" = ">=0.1.1" +"kernels-community/punica-sgmv" = ">=0.0.1" "kernels-community/quantization" = ">=0.0.3" "kernels-community/quantization-eetq" = ">=0.0.1" "kernels-community/rotary" = ">=0.0.1" diff --git a/server/text_generation_server/adapters/lora.py 
b/server/text_generation_server/adapters/lora.py index 782d66e4..c8eb48a2 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -13,21 +13,20 @@ from torch.distributed import ProcessGroup from text_generation_server.utils.log import log_master from text_generation_server.adapters.config import AdapterConfig, ModuleMap - +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.kernels import load_kernel from text_generation_server.adapters.weights import ( AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, ) -from text_generation_server.utils.sgmv import ( - BGMV_MAX_RANK, - MAX_RANK_CUSTOM, - get_tmp_tensors, - orient_for_rank, - pad_rank, - use_cutlass_shrink, - has_sgmv, -) + +if SYSTEM == "cuda": + punica_sgmv = load_kernel( + module="punica_sgmv", repo_id="kernels-community/punica-sgmv" + ) +else: + punica_sgmv = None def get_start_stop_idxs_for_rank(offset, size, rank, world_size): @@ -129,11 +128,13 @@ class LoraWeights(AdapterWeights): self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 - self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r) self._is_transposed = False # [num_layers, hidden_size, r] - weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + weights_a = [ + punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a + ] self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] @@ -244,8 +245,12 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] - lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] + lora_a_list = [ + punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list + ] + lora_b_list = [ + punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list + ] if lora_a_list: # update rank if it was padded @@ -293,7 +298,7 @@ class BatchLoraWeights(BatchAdapterWeights): def can_vectorize(self, pg: ProcessGroup) -> bool: return all( - rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM for rank_data in self.rank_data.values() ) @@ -337,8 +342,8 @@ class BatchLoraWeights(BatchAdapterWeights): ) use_sgmv = False - if prefill or max_rank > BGMV_MAX_RANK: - if has_sgmv(): + if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK: + if punica_sgmv is not None: use_sgmv = True lora_a_ptr = torch.tensor( [ @@ -425,7 +430,7 @@ class BatchLoraWeights(BatchAdapterWeights): if use_sgmv: lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors( + tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors( lora_a_ptr_indices.size(0), rank, device ) segment_starts = meta.adapter_segments[indices] diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index a4537b55..abfb097d 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -5,14 +5,16 @@ import torch.distributed from torch import nn from torch.distributed import ProcessGroup -from text_generation_server.utils.sgmv import ( - add_lora_a_bgmv, - add_lora_b_bgmv, - has_sgmv, - lora_a_sgmv_cutlass, - 
lora_b_sgmv_cutlass, - orient_for_rank, -) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.kernels import load_kernel + +if SYSTEM == "cuda": + punica_sgmv = load_kernel( + module="punica_sgmv", repo_id="kernels-community/punica-sgmv" + ) +else: + punica_sgmv = None + if TYPE_CHECKING: from text_generation_server.adapters import AdapterBatchData @@ -41,7 +43,11 @@ class LoraLinear(nn.Module): return result data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) - if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + if ( + punica_sgmv is not None + and data is not None + and data.can_vectorize(self.process_group) + ): # In tensor-parallel configurations, each GPU processes a specific segment of the output. # The 'result' tensor represents the full output, which can vary in size based on # the layer type (e.g., attention vs. feed-forward layers). We define the current @@ -68,7 +74,7 @@ class LoraLinear(nn.Module): if data.use_sgmv: # Use SGMV for prefill - v = lora_a_sgmv_cutlass( + v = punica_sgmv.lora_a_sgmv_cutlass( input, rank_segments.tmp_shrink, lora_a_ptr, @@ -81,7 +87,7 @@ class LoraLinear(nn.Module): if self.process_group.size() > 1: v = self.collect_lora_a(v) - lora_b_sgmv_cutlass( + punica_sgmv.lora_b_sgmv_cutlass( proj, v, rank_segments.tmp_expand, @@ -96,7 +102,7 @@ class LoraLinear(nn.Module): (input.size(0), r), dtype=input.dtype, device=input.device ) # TODO: error with [-1, 0], but not [0, -1] - add_lora_a_bgmv( + punica_sgmv.add_lora_a_bgmv( v, input, lora_a_ptr, @@ -107,7 +113,7 @@ class LoraLinear(nn.Module): if self.process_group.size() > 1: v = self.collect_lora_a(v) - add_lora_b_bgmv( + punica_sgmv.add_lora_b_bgmv( proj, v, lora_b_ptr, @@ -142,7 +148,7 @@ class LoraLinear(nn.Module): lora_a = data.lora_a[adapter_index][self.layer_id, :, :] lora_b = data.lora_b[adapter_index][self.layer_id, :, :] - lora_a = orient_for_rank(lora_a, lora_b.size(0)) + lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0)) a_out = input @ lora_a if self.process_group.size() > 1: diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py deleted file mode 100644 index 2d0a73a5..00000000 --- a/server/text_generation_server/utils/sgmv.py +++ /dev/null @@ -1,252 +0,0 @@ -# Origin: https://github.com/predibase/lorax -# Path: lorax/server/lorax_server/utils/sgmv.py -# License: Apache License Version 2.0, January 2004 - -import os -import warnings -from functools import lru_cache -from typing import List, Tuple - -import torch -import torch.nn.functional as F - -try: - import punica_kernels as _kernels - - HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) -except ImportError: - warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") - _kernels = None - HAS_SGMV = False - - -MIN_SGMV_RANK = 8 -MIN_RANK_CUSTOM = 16 -MAX_RANK_CUSTOM = 128 -SGMV_BLOCK_SIZE = 16 -BGMV_MAX_RANK = 64 - - -def has_sgmv() -> bool: - return HAS_SGMV - - -def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: - """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" - if not has_sgmv(): - return t - - # tensor parallelism will result in effective rank being divided by world_size, - # so we need to scale the min rank to offset that effect - min_rank = MIN_SGMV_RANK * world_size - - # if we're at or below the min rank, pad up to the min rank - # otherwise, pad to the nearest multiple of the block 
size - current_rank = t.size(dim) - target_rank = ( - min_rank - if current_rank <= min_rank - else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE - ) - if current_rank == target_rank: - return t - - pad_size = target_rank - current_rank - - # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - pad = [0, 0] * t.dim() - pad[(t.dim() - dim - 1) * 2 + 1] = pad_size - pad = tuple(pad) - - return F.pad(t, pad, mode="constant", value=0.0) - - -def use_cutlass_shrink(lora_rank: int) -> bool: - return lora_rank < MIN_RANK_CUSTOM - - -def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: - if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: - return t.transpose(0, 1) - return t - - -# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py -def add_lora_sgmv_cutlass( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.Tensor, - s_end: torch.Tensor, - layer_idx: int, - lora_rank: int, -): - """ - Semantics: - y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H1]`. - wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H2]`. - s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. - s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. - layer_idx: Layer index of the weight matrices. - """ - if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: - # Custom SGMV shrink only supports rank 16, 32, 64, 128 - _add_lora_sgmv_cutlass_legacy( - y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank - ) - return - - tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) - tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) - - -def _add_lora_sgmv_cutlass_legacy( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -): - tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -@lru_cache(maxsize=1) -def get_tmp_tensor(device: torch.device) -> torch.Tensor: - return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) - - -@lru_cache(maxsize=32) -def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: - tmp_size = _kernels.sgmv_cutlass_tmp_size(size) - return torch.empty((tmp_size,), dtype=torch.uint8, device=device) - - -def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: - return torch.empty((size,), dtype=torch.uint8, device=device) - - -def get_tmp_expand_size(size: int) -> int: - 
return _kernels.sgmv_cutlass_tmp_size(size) - - -def get_tmp_tensors( - nsegments: int, lora_rank: int, device: torch.device -) -> Tuple[torch.Tensor, torch.Tensor]: - use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv() - has_sgmv_available = has_sgmv() - - if use_cutlass: - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - elif has_sgmv_available: - return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device) - else: - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - - -def lora_a_sgmv_cutlass( - x: torch.Tensor, - tmp: torch.Tensor, - wa_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -) -> torch.Tensor: - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - else: - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - return v - - -def lora_b_sgmv_cutlass( - y: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, -): - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -""" -Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - -Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - v: Shape: `[B, R]`. Temporary vector. - x: Shape: `[B, H1]`. Input vectors. - wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. - wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. -""" - - -def add_lora_a_bgmv( - v: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) - - -def add_lora_b_bgmv( - y: torch.Tensor, - v: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) - - -def segmented_matmul( - y: torch.Tensor, - x: torch.Tensor, - w: List[torch.Tensor], - b: List[torch.Tensor], - s_start: torch.IntTensor, - s_end: torch.IntTensor, -): - for i in range(len(w)): - if s_end[i] - s_start[i] <= 0: - continue - - xi = x[s_start[i] : s_end[i]] - wi = w[i] - bi = b[i] - y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
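
Note (not part of the applied diff): below is a minimal sketch of the loading pattern this patch converges on. On CUDA, the SGMV kernel is fetched from the Hub with load_kernel (pinned by server/kernels.lock and the "kernels-community/punica-sgmv" entry under [tool.kernels.dependencies] in pyproject.toml); on other systems the module stays None and call sites must guard on that, as the hunks in adapters/lora.py and layers/lora.py do. The helper pad_lora_weights is purely illustrative and does not exist in the tree.

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    # Fetches the prebuilt kernel variant matching the local Torch/CUDA/arch
    # combination; the revision is pinned by server/kernels.lock.
    punica_sgmv = load_kernel(
        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
    )
else:
    # No SGMV support outside CUDA; callers must check for None.
    punica_sgmv = None


def pad_lora_weights(lora_a_list, world_size):
    # Illustrative helper (hypothetical, not in the tree): call sites now reach
    # the kernel through the loaded module instead of importing from
    # text_generation_server.utils.sgmv, which this patch deletes.
    if punica_sgmv is None:
        return lora_a_list
    return [
        punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
    ]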