Use kernels from the kernel hub (#2988)

* Use Hub kernels for Marlin and cutlass quantization kernels

* Use hub kernels for MoE/GPTQ-Marlin MoE

* Use attention kernels from the Hub

* Cache the kernels in the Docker image

* Update moe kernels

* Support loading local kernels for development

* Support latest moe kernels

* Update to moe 0.1.1

* CI: download locked kernels for server tests

* Fixup some imports

* CI: activate venv

* Fix unused imports

* Nix: add attention/moe/quantization kernels

* Update hf-kernels to 0.1.5

* Update kernels

* Update tgi-nix flake for hf-kernels

* Fix EOF

* Take `load_kernel` out of a frequently-called function

* Hoist another case of kernel loading out of a somewhat hot function

* marlin-kernels -> quantization

* attention -> paged-attention

* EOF fix

* Update hf-kernels, fixup Docker

* ipex fix

* Remove outdated TODO
Daniël de Kok 2025-02-10 19:19:25 +01:00 committed by GitHub
parent 4b8cda684b
commit 571ac9b507
22 changed files with 7206 additions and 257 deletions
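
A minimal sketch of the loading pattern this commit converges on across the touched layers (module and repo names taken from the diff below): kernels are resolved once at module import time through the new load_kernel helper, guarded on SYSTEM == "cuda", instead of importing marlin_kernels, attention_kernels or moe_kernels inside frequently-called functions.

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

# Resolve the Hub kernel once at import time rather than inside a hot path;
# non-CUDA systems fall back to None and keep their existing backends.
if SYSTEM == "cuda":
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    quantization = None

# Call sites then assert availability and invoke the Hub kernel directly,
# e.g. quantization.scaled_int8_quant(...) or quantization.cutlass_scaled_mm(...).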

View File

@ -48,6 +48,10 @@ jobs:
uv venv
source ./.venv/bin/activate
make install-cpu
- name: Download locked kernels
run: |
source ./.venv/bin/activate
hf-kernels download server
- name: Run server tests
run: |
source ./.venv/bin/activate

View File

@ -206,11 +206,13 @@ COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \
pip install -U pip uv && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project && \
. ./.venv/bin/activate && \
make gen-server-raw
make gen-server-raw && \
hf-kernels download .
RUN cd server && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines && \

View File

@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1737685583,
"narHash": "sha256-p+NVABRpGi+pT+xxf9HcLcFVxG6L+vEEy+NwzB9T0f8=",
"lastModified": 1738549608,
"narHash": "sha256-GdyT9QEUSx5k/n8kILuNy83vxxdyUfJ8jL5mMpQZWfw=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "eb64cbcc8eee0fa87ebded92805280d2ec97415a",
"rev": "35c6f8c4352f995ecd53896200769f80a3e8f22d",
"type": "github"
},
"original": {
@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1738323634,
"narHash": "sha256-lKPzgEm7pEuQJVhacsxFHqg1MOtrUMZvr+9IuJzC5J4=",
"lastModified": 1738769628,
"narHash": "sha256-hgHf1mscFbH9XtT3dYtFQcxRfict9N+Vi6QSW1c+FjU=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "eb5fede2756f544f75e01f55a4097f9c9a8c5005",
"rev": "9a5a58219dead9704d83d9d32f105b6b90bd31f2",
"type": "github"
},
"original": {

View File

@ -90,7 +90,7 @@ mkShell {
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . )
( cd server ; python -m pip install --no-build-isolation --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';

View File

@ -3,7 +3,6 @@
buildPythonPackage,
poetry-core,
mypy-protobuf,
attention-kernels,
awq-inference-engine,
causal-conv1d,
compressed-tensors,
@ -19,22 +18,24 @@
grpcio-reflection,
grpcio-status,
grpcio-tools,
hf-kernels,
hf-transfer,
loguru,
mamba-ssm,
marlin-kernels,
moe-kernels,
moe,
opentelemetry-api,
opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc,
opentelemetry-semantic-conventions,
outlines,
paged-attention,
peft,
pillow,
prometheus-client,
punica-kernels,
py-cpuinfo,
pydantic,
quantization,
safetensors,
tokenizers,
torch,
@ -78,7 +79,6 @@ buildPythonPackage {
pythonRemoveDeps = [ "scipy" ];
dependencies = [
attention-kernels
awq-inference-engine
eetq
causal-conv1d
@ -93,22 +93,24 @@ buildPythonPackage {
grpcio-reflection
grpcio-status
grpcio-tools
hf-kernels
hf-transfer
loguru
mamba-ssm
marlin-kernels
moe-kernels
moe
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc
opentelemetry-semantic-conventions
outlines
paged-attention
peft
pillow
prometheus-client
punica-kernels
py-cpuinfo
pydantic
quantization
safetensors
sentencepiece
tokenizers

server/hf-kernels.lock (new file, 6740 lines)

File diff suppressed because it is too large.

View File

@ -14,6 +14,7 @@ dependencies = [
"grpcio>=1.67.0",
"grpcio-reflection>=1.67.0",
"grpcio-status>=1.67.0",
"hf-kernels>=0.1.5",
"hf-transfer>=0.1.8",
"loguru>=0.7.3",
"numpy>=1.26,<3",
@ -33,6 +34,15 @@ dependencies = [
"transformers>=4.48.0"
]
[build-system]
requires = ["hf-kernels>=0.1.2", "setuptools"]
build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]
"kernels-community/paged-attention" = ">=0.0.2"
"kernels-community/moe" = ">=0.1.1"
"kernels-community/quantization" = ">=0.0.3"
[project.scripts]
text-generation-server = "text_generation_server.cli:app"
@ -60,24 +70,11 @@ quantize = [
"texttable>=1.6.7,<2",
"datasets>=2.21,<3",
]
moe = [ "moe-kernels" ]
attention = [ "attention-kernels" ]
marlin = [ "marlin-kernels" ]
gen = [
"grpcio-tools>=1.69.0",
"mypy-protobuf>=3.6.0",
]
[tool.uv.sources]
attention-kernels.url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
]
moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

View File

@ -1,6 +1,7 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
@ -13,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
if SYSTEM == "cuda":
try:
paged_attention_kernels = load_kernel(
module="paged_attention", repo_id="kernels-community/paged-attention"
)
except Exception as e:
raise ImportError(
f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
)
else:
paged_attention_kernels = None
def paged_attention(
query: torch.Tensor,
@ -107,7 +120,6 @@ def paged_attention(
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
import attention_kernels
out = torch.empty_like(query)
@ -117,7 +129,7 @@ def paged_attention(
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
attention_kernels.paged_attention_v1(
paged_attention_kernels.paged_attention_v1(
out,
query,
kv_cache.key,
@ -130,8 +142,8 @@ def paged_attention(
max_s,
None,
kv_cache_dtype,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
)
else:
# Run PagedAttention V2.
@ -148,7 +160,7 @@ def paged_attention(
)
max_logits = torch.empty_like(exp_sums)
attention_kernels.paged_attention_v2(
paged_attention_kernels.paged_attention_v2(
out,
exp_sums,
max_logits,
@ -164,8 +176,8 @@ def paged_attention(
max_s,
None,
kv_cache_dtype,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
)
return out

View File

@ -7,9 +7,22 @@ import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
if SYSTEM == "cuda":
try:
paged_attention = load_kernel(
module="paged_attention", repo_id="kernels-community/paged-attention"
)
except Exception as e:
raise ImportError(
f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
)
else:
paged_attention = None
@dataclass
class KVScales:
@ -119,7 +132,7 @@ class KVCache:
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif self.dtype == torch.float8_e4m3fn and (
(ATTENTION == "flashinfer" and SYSTEM == "cuda")
(ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM == "rocm")
):
log_once(logger.info, "Using FP8 KV cache scales")
@ -220,19 +233,19 @@ def paged_reshape_and_cache(
):
if SYSTEM == "cuda":
try:
import attention_kernels
except Exception as e:
raise ImportError(
f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
)
kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8"
attention_kernels.reshape_and_cache(
key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
paged_attention.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_cache_dtype,
torch.tensor(k_scale),
torch.tensor(v_scale),
)
elif SYSTEM == "rocm":
try:

View File

@ -6,13 +6,17 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if SYSTEM == "cuda":
quantization = load_kernel(
module="quantization", repo_id="kernels-community/quantization"
)
else:
quantization = None
class W8A8IntLoader(WeightsLoader):
@ -159,8 +163,8 @@ class Int8Weight(Weight):
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
assert marlin_kernels is not None
qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
assert quantization is not None
qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
return W8A8IntLinear(
bias=bias,
input_symmetric=self.input_symmetric,
@ -204,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
assert quantization is not None
qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
input=input,
scale=None,
azp=None,
@ -214,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
)
if self.input_symmetric:
return marlin_kernels.cutlass_scaled_mm(
return quantization.cutlass_scaled_mm(
a=qinput,
b=self.weight,
scale_a=input_scale,
@ -229,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
and (self.input_symmetric or input_zero_point is not None)
)
return marlin_kernels.cutlass_scaled_mm_azp(
return quantization.cutlass_scaled_mm_azp(
a=qinput,
b=self.weight,
scale_a=input_scale,

View File

@ -6,6 +6,7 @@ import torch
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
@ -14,10 +15,12 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.utils.log import log_once
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if SYSTEM == "cuda":
quantization = load_kernel(
module="quantization", repo_id="kernels-community/quantization"
)
else:
quantization = None
try:
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
@ -29,9 +32,9 @@ quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)
if SYSTEM == "cuda" and marlin_kernels is not None:
if SYSTEM == "cuda" and quantization is not None:
major, minor = torch.cuda.get_device_capability()
CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
major * 10 + minor
)
else:
@ -143,11 +146,10 @@ def fp8_quantize(
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
if marlin_kernels is not None:
if quantization is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
qweight, scale = quantization.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=quant_dtype,
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
@ -527,7 +529,7 @@ class Fp8Linear(torch.nn.Module):
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound, scalar=False
)
return marlin_kernels.cutlass_scaled_mm(
return quantization.cutlass_scaled_mm(
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
)

View File

@ -8,11 +8,15 @@ from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
permute_scales,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if SYSTEM == "cuda":
quantization = load_kernel(
module="quantization", repo_id="kernels-community/quantization"
)
else:
quantization = None
MARLIN_TILE_SIZE = 16
@ -32,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
assert quantization is not None
scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
@ -69,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
assert quantization is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
C = quantization.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
@ -134,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)

View File

@ -12,13 +12,17 @@ from text_generation_server.layers.marlin.util import (
unpack_cols,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if SYSTEM == "cuda":
quantization = load_kernel(
module="quantization", repo_id="kernels-community/quantization"
)
else:
quantization = None
try:
major, _minor = torch.cuda.get_device_capability()
@ -37,7 +41,7 @@ def can_use_gptq_marlin(
) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and quantization is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
@ -287,7 +291,7 @@ def repack_gptq_for_marlin(
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert marlin_kernels is not None
assert quantization is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
@ -330,7 +334,7 @@ def repack_gptq_for_marlin(
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
if quant_method == "awq":
repacked = marlin_kernels.awq_marlin_repack(
repacked = quantization.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
@ -342,7 +346,7 @@ def repack_gptq_for_marlin(
)
else:
repacked = marlin_kernels.gptq_marlin_repack(
repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
@ -379,13 +383,26 @@ class GPTQMarlinLinear(nn.Module):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
assert quantization is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits
if weight.bits not in (4, 8):
raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")
if weight.qzeros.numel() > 0:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4
else:
self.quant_type = quantization.scalar_types.uint8
else:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
self.is_full_k = weight.is_full_k
self.qweight = weight.qweight
@ -403,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
assert quantization is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm(
C = quantization.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
@ -414,7 +431,7 @@ class GPTQMarlinLinear(nn.Module):
self.g_idx,
self.perm,
self.workspace,
self.bits,
self.quant_type,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],

View File

@ -3,13 +3,18 @@ from typing import List, Optional, Union
import torch
import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if SYSTEM == "cuda":
quantization = load_kernel(
module="quantization", repo_id="kernels-community/quantization"
)
else:
quantization = None
class MarlinWeightsLoader(WeightsLoader):
@ -187,7 +192,7 @@ class MarlinLinear(nn.Module):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
assert quantization is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
@ -216,9 +221,9 @@ class MarlinLinear(nn.Module):
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
assert quantization is not None
C = marlin_kernels.marlin_gemm(
C = quantization.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
@ -277,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
assert quantization is not None
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join(
@ -303,8 +308,11 @@ class GPTQMarlin24Linear(nn.Module):
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
self.bits = weight.bits
weights_per_int32 = 32 // self.bits
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
weights_per_int32 = 32 // weight.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
@ -336,15 +344,15 @@ class GPTQMarlin24Linear(nn.Module):
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
assert quantization is not None
C = marlin_kernels.gptq_marlin_24_gemm(
C = quantization.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.weight_packed,
self.meta,
self.scale_packed,
self.workspace,
self.bits,
self.quant_type,
A.shape[0],
self.scale_packed.shape[1],
A.shape[1],

View File

@ -4,11 +4,14 @@ from typing import List, Tuple
import numpy
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if SYSTEM == "cuda":
quantization = load_kernel(
module="quantization", repo_id="kernels-community/quantization"
)
else:
quantization = None
try:
major, _minor = torch.cuda.get_device_capability()
@ -23,7 +26,7 @@ def _check_marlin_kernels():
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if marlin_kernels is None:
if quantization is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)

View File

@ -18,6 +18,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
@ -27,6 +28,10 @@ from text_generation_server.utils.weights import (
if SYSTEM == "ipex":
from .fused_moe_ipex import fused_topk, grouped_topk
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
fused_topk = moe_kernels.fused_topk
grouped_topk = moe_kernels.grouped_topk
else:
from moe_kernels.fused_moe import fused_topk, grouped_topk

View File

@ -12,7 +12,7 @@ from text_generation_server.layers.fp8 import (
)
try:
from moe_kernels.fused_moe import fused_moe
from .unquantized import fused_moe
except Exception:
fused_moe = None

View File

@ -1,10 +1,12 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import Callable, List, Optional
import torch
import torch.nn as nn
from text_generation_server.layers import moe
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight,
@ -12,9 +14,9 @@ from text_generation_server.layers.marlin.gptq import (
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else:
fused_marlin_moe = None
moe_kernels = None
try:
@ -32,7 +34,7 @@ def can_use_marlin_moe_gemm(
):
return (
SYSTEM == "cuda"
and fused_marlin_moe is not None
and moe is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
@ -230,3 +232,111 @@ def _pack_weight(
moe_weight.perm[expert] = weight.perm
return moe_weight
def fused_marlin_moe(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
sort_indices1: torch.Tensor,
sort_indices2: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
is_k_full: bool,
topk: int,
renormalize: bool,
num_bits: int = 8,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- sort_indices1 (torch.Tensor): The first act_order input permutation.
- sort_indices2 (torch.Tensor): The second act_order input permutation.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- num_bits (int): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2
), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
# DeepSeek-V2 uses grouped_topk
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = moe_kernels.grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = moe_kernels.fused_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
return moe_kernels.fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
gating_output=gating_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=w1_zeros,
w2_zeros=w2_zeros,
num_bits=num_bits,
is_k_full=is_k_full,
)

View File

@ -1,15 +1,18 @@
from typing import Optional
from typing import Callable, List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else:
from moe_kernels.fused_moe import fused_moe
import moe_kernels
class UnquantizedSparseMoELayer(nn.Module):
@ -63,7 +66,17 @@ class UnquantizedSparseMoELayer(nn.Module):
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "ipex":
if SYSTEM == "rocm":
return moe_kernels.fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
@ -146,3 +159,110 @@ def _load_expert_weights_row(
assert all_weight is not None
return all_weight
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use int8 weights and bf16/fp16 activations
 to compute the inner products for w1 and w2. Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
from loguru import logger
import inspect
logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
topk_weights, topk_ids = moe_kernels.grouped_topk(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
elif custom_routing_function is None:
topk_weights, topk_ids = moe_kernels.fused_topk(
hidden_states, gating_output, topk, renormalize
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize
)
return moe_kernels.fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)

View File

@ -22,11 +22,14 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else:
from moe_kernels.fused_moe import fused_moe
import moe_kernels
from text_generation_server.layers.attention import (
paged_attention,
@ -510,7 +513,7 @@ class BlockSparseMoE(nn.Module):
topk_group=None,
)
else:
out = fused_moe(
out = moe_kernels.fused_moe(
x,
self.wv1,
self.w2,

View File

@ -0,0 +1,22 @@
import importlib
from loguru import logger
from hf_kernels import load_kernel as hf_load_kernel
from text_generation_server.utils.log import log_once
def load_kernel(*, module: str, repo_id: str):
"""
Load a kernel. First try to load it as the given module (e.g. for
local development), falling back to a locked Hub kernel.
"""
try:
m = importlib.import_module(module)
log_once(logger.info, f"Using local module for `{module}`")
return m
except ModuleNotFoundError:
return hf_load_kernel(repo_id=repo_id)
__all__ = ["load_kernel"]
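
A usage sketch for the helper above, mirroring the call sites elsewhere in this commit: a locally importable module wins (useful for kernel development), otherwise the locked Hub build pinned under [tool.kernels.dependencies] is loaded.

from text_generation_server.utils.kernels import load_kernel

# Prefers a locally installed `moe` module (e.g. a development checkout
# installed with `pip install -e .`); otherwise loads the locked
# kernels-community/moe build from the Hub.
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")

# The returned module exposes the ops used throughout this diff, e.g.
# moe_kernels.fused_topk(...), moe_kernels.grouped_topk(...) and
# moe_kernels.fused_experts(...).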

View File

@ -168,20 +168,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 },
]
[[package]]
name = "attention-kernels"
version = "0.2.0.post2"
source = { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" }
dependencies = [
{ name = "torch" },
]
wheels = [
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl", hash = "sha256:863e02dda4b30e9d04ef6cf4d17d16c154f54bdcb8a8b87b8b46075eabf62d25" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "attrs"
version = "24.3.0"
@ -676,6 +662,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/f6b198152399d17247d962340947728fb1b06da6bc0c0a542446b2ffee49/grpcio_tools-1.69.0-cp39-cp39-win_amd64.whl", hash = "sha256:5d47abf7e0662dd5dbb9cc252c3616e5fbc5f71d34e3f6332cd24bcdf2940abd", size = 1114931 },
]
[[package]]
name = "hf-kernels"
version = "0.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "packaging" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/01/fe/5aa3ea1b66bcc7d81aff19683ea04d4a9cd414c8d4ff05b150fc1f196ccd/hf_kernels-0.1.6.tar.gz", hash = "sha256:5effee5046552ce226ff86d3870a799f4ecae399bcb2beb4046c28c2dd736d2f", size = 8704 }
[[package]]
name = "hf-transfer"
version = "0.1.9"
@ -906,86 +904,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 },
]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.13'",
]
dependencies = [
{ name = "torch", marker = "python_full_version >= '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b2/82/886d1eece474ef23668c4780f5053ea654999704a0195aadc651631b740d/marlin-kernels-0.3.7.tar.gz", hash = "sha256:8be8a65fd9ae21b2406afba9e460e3922582479b85a1372096e87e3a15684a77", size = 15662 }
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl" }
resolution-markers = [
"python_full_version == '3.10.*'",
]
dependencies = [
{ name = "torch", marker = "python_full_version == '3.10.*'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", hash = "sha256:dd91a4e2c3b5e954833c5c34b0322e4c02cd92a967eb94654b6bbcece131340b" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl" }
resolution-markers = [
"python_full_version == '3.11.*'",
]
dependencies = [
{ name = "torch", marker = "python_full_version == '3.11.*'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", hash = "sha256:b24d92135fbd156c55ce43158ab4a90fa880ba0df965528895cf1870b03a64bf" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl" }
resolution-markers = [
"python_full_version == '3.12.*'",
]
dependencies = [
{ name = "torch", marker = "python_full_version == '3.12.*'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", hash = "sha256:8a407f1435a571a8d4ca3b9f533da83fde323043a9836b739cf8018c77782d49" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl" }
resolution-markers = [
"python_full_version < '3.10'",
]
dependencies = [
{ name = "torch", marker = "python_full_version < '3.10'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", hash = "sha256:bf7003753c364c504b3998fffdfcf619a42ab04f908903dbad8d54347b6b142b" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "mdurl"
version = "0.1.2"
@ -995,26 +913,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
]
[[package]]
name = "moe-kernels"
version = "0.8.2"
source = { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" }
dependencies = [
{ name = "nvidia-ml-py" },
{ name = "torch" },
{ name = "triton" },
]
wheels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl", hash = "sha256:1ed5b26f52339d25ea2513e99e8b6239cf1921af3eac54e03a46bb8f8efb380b" },
]
[package.metadata]
requires-dist = [
{ name = "nvidia-ml-py" },
{ name = "torch" },
{ name = "triton" },
]
[[package]]
name = "mpmath"
version = "1.3.0"
@ -1308,6 +1206,7 @@ name = "nvidia-cublas-cu12"
version = "12.4.5.8"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 },
{ url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 },
]
@ -1316,6 +1215,7 @@ name = "nvidia-cuda-cupti-cu12"
version = "12.4.127"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 },
{ url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 },
]
@ -1324,6 +1224,7 @@ name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 },
{ url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 },
]
@ -1332,6 +1233,7 @@ name = "nvidia-cuda-runtime-cu12"
version = "12.4.127"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 },
{ url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 },
]
@ -1354,6 +1256,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 },
{ url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
]
@ -1362,6 +1265,7 @@ name = "nvidia-curand-cu12"
version = "10.3.5.147"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 },
{ url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 },
]
@ -1375,6 +1279,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 },
{ url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
]
@ -1386,18 +1291,10 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 },
{ url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
]
[[package]]
name = "nvidia-ml-py"
version = "12.560.30"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/53/10/5f34de4a71db8b2b7ec4269f4a33287f24c23e2857ea3187c977b7bc3604/nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97", size = 39194 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/f3/a69ce0b1a1e12fbf6b2ad9f4c14c9999fdbdf15f2478d210f0fd501ddc98/nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73", size = 40526 },
]
[[package]]
name = "nvidia-nccl-cu12"
version = "2.21.5"
@ -1411,6 +1308,7 @@ name = "nvidia-nvjitlink-cu12"
version = "12.4.127"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 },
{ url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 },
]
@ -1419,6 +1317,7 @@ name = "nvidia-nvtx-cu12"
version = "12.4.127"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 },
{ url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
]
@ -2653,6 +2552,7 @@ dependencies = [
{ name = "grpcio" },
{ name = "grpcio-reflection" },
{ name = "grpcio-status" },
{ name = "hf-kernels" },
{ name = "hf-transfer" },
{ name = "loguru" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
@ -2678,9 +2578,6 @@ dependencies = [
accelerate = [
{ name = "accelerate" },
]
attention = [
{ name = "attention-kernels" },
]
bnb = [
{ name = "bitsandbytes" },
]
@ -2695,16 +2592,6 @@ gen = [
{ name = "grpcio-tools" },
{ name = "mypy-protobuf" },
]
marlin = [
{ name = "marlin-kernels", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl" }, marker = "python_full_version == '3.10.*'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl" }, marker = "python_full_version == '3.11.*'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl" }, marker = "python_full_version == '3.12.*'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl" }, marker = "python_full_version < '3.10'" },
]
moe = [
{ name = "moe-kernels" },
]
outlines = [
{ name = "outlines" },
]
@ -2719,7 +2606,6 @@ quantize = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'accelerate'", specifier = ">=1.2.1,<2" },
{ name = "attention-kernels", marker = "extra == 'attention'", url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" },
{ name = "bitsandbytes", marker = "extra == 'bnb'", specifier = ">=0.45.0" },
{ name = "compressed-tensors", marker = "extra == 'compressed-tensors'", specifier = ">=0.9.0" },
{ name = "datasets", marker = "extra == 'quantize'", specifier = ">=2.21,<3" },
@ -2730,14 +2616,9 @@ requires-dist = [
{ name = "grpcio-status", specifier = ">=1.67.0" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.51.1,<2.0" },
{ name = "grpcio-tools", marker = "extra == 'gen'", specifier = ">=1.69.0" },
{ name = "hf-kernels", specifier = ">=0.1.5" },
{ name = "hf-transfer", specifier = ">=0.1.8" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "marlin-kernels", marker = "python_full_version == '3.9.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl" },
{ name = "marlin-kernels", marker = "(python_full_version < '3.9' and extra == 'marlin') or (python_full_version >= '3.13' and extra == 'marlin')" },
{ name = "marlin-kernels", marker = "python_full_version == '3.10.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl" },
{ name = "marlin-kernels", marker = "python_full_version == '3.11.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl" },
{ name = "marlin-kernels", marker = "python_full_version == '3.12.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl" },
{ name = "moe-kernels", marker = "extra == 'moe'", url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" },
{ name = "mypy-protobuf", marker = "extra == 'gen'", specifier = ">=3.6.0" },
{ name = "numpy", specifier = ">=1.26,<3" },
{ name = "opentelemetry-api", specifier = ">=1.27.0" },
@ -2919,7 +2800,7 @@ name = "triton"
version = "3.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "filelock", marker = "python_full_version < '3.13'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 },