Use kernels from the kernel hub (#2988)

* Use Hub kernels for Marlin and cutlass quantization kernels

* Use hub kernels for MoE/GPTQ-Marlin MoE

* Use attention kernels from the Hub

* Cache the kernels in the Docker image

* Update moe kernels

* Support loading local kernels for development

* Support latest moe kernels

* Update to moe 0.1.1

* CI: download locked kernels for server tests

* Fixup some imports

* CI: activate venv

* Fix unused imports

* Nix: add attention/moe/quantization kernels

* Update hf-kernels to 0.1.5

* Update kernels

* Update tgi-nix flake for hf-kernels

* Fix EOF

* Take `load_kernel` out of a frequently-called function

* Hoist another case of kernel loading out of a somewhat hot function

* marlin-kernels -> quantization

* attention -> paged-attention

* EOF fix

* Update hf-kernels, fixup Docker

* ipex fix

* Remove outdated TODO
Daniël de Kok 2025-02-10 19:19:25 +01:00 committed by GitHub
parent 4b8cda684b
commit 571ac9b507
22 changed files with 7206 additions and 257 deletions
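
Before the per-file changes, a sketch of the runtime end of this pipeline: kernels are pinned in `server/pyproject.toml` under `[tool.kernels.dependencies]`, prefetched with `hf-kernels download` (cached under `HF_KERNELS_CACHE`, `/kernels` in the Docker image), and resolved at import time. A minimal sketch, assuming only the `hf-kernels` API that appears in this diff:

```python
# Minimal runtime sketch, assuming the locked kernels were already fetched
# (CI: `hf-kernels download server`, Docker: `hf-kernels download .` with
# HF_KERNELS_CACHE=/kernels) and a CUDA environment.
from hf_kernels import load_kernel

# Repo ID matches an entry in [tool.kernels.dependencies] below.
quantization = load_kernel(repo_id="kernels-community/quantization")

# The loaded module exposes the former marlin-kernels entry points, e.g.
# quantization.scaled_fp8_quant(...) and quantization.cutlass_scaled_mm(...).
```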

View File

@ -48,6 +48,10 @@ jobs:
uv venv uv venv
source ./.venv/bin/activate source ./.venv/bin/activate
make install-cpu make install-cpu
- name: Download locked kernels
run: |
source ./.venv/bin/activate
hf-kernels download server
- name: Run server tests - name: Run server tests
run: | run: |
source ./.venv/bin/activate source ./.venv/bin/activate

View File

@ -206,11 +206,13 @@ COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1 ENV UV_SYSTEM_PYTHON=1
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \ RUN cd server && \
pip install -U pip uv && \ pip install -U pip uv && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project && \ uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project && \
. ./.venv/bin/activate && \ . ./.venv/bin/activate && \
make gen-server-raw make gen-server-raw && \
hf-kernels download .
RUN cd server && \ RUN cd server && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines && \ uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines && \

View File

@ -853,11 +853,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1737685583, "lastModified": 1738549608,
"narHash": "sha256-p+NVABRpGi+pT+xxf9HcLcFVxG6L+vEEy+NwzB9T0f8=", "narHash": "sha256-GdyT9QEUSx5k/n8kILuNy83vxxdyUfJ8jL5mMpQZWfw=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "eb64cbcc8eee0fa87ebded92805280d2ec97415a", "rev": "35c6f8c4352f995ecd53896200769f80a3e8f22d",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1738323634, "lastModified": 1738769628,
"narHash": "sha256-lKPzgEm7pEuQJVhacsxFHqg1MOtrUMZvr+9IuJzC5J4=", "narHash": "sha256-hgHf1mscFbH9XtT3dYtFQcxRfict9N+Vi6QSW1c+FjU=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "eb5fede2756f544f75e01f55a4097f9c9a8c5005", "rev": "9a5a58219dead9704d83d9d32f105b6b90bd31f2",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -90,7 +90,7 @@ mkShell {
postVenvCreation = '' postVenvCreation = ''
unset SOURCE_DATE_EPOCH unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . ) ( cd server ; python -m pip install --no-build-isolation --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . ) ( cd clients/python ; python -m pip install --no-dependencies -e . )
''; '';

View File

@ -3,7 +3,6 @@
buildPythonPackage, buildPythonPackage,
poetry-core, poetry-core,
mypy-protobuf, mypy-protobuf,
attention-kernels,
awq-inference-engine, awq-inference-engine,
causal-conv1d, causal-conv1d,
compressed-tensors, compressed-tensors,
@ -19,22 +18,24 @@
grpcio-reflection, grpcio-reflection,
grpcio-status, grpcio-status,
grpcio-tools, grpcio-tools,
hf-kernels,
hf-transfer, hf-transfer,
loguru, loguru,
mamba-ssm, mamba-ssm,
marlin-kernels, moe,
moe-kernels,
opentelemetry-api, opentelemetry-api,
opentelemetry-exporter-otlp, opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc, opentelemetry-instrumentation-grpc,
opentelemetry-semantic-conventions, opentelemetry-semantic-conventions,
outlines, outlines,
paged-attention,
peft, peft,
pillow, pillow,
prometheus-client, prometheus-client,
punica-kernels, punica-kernels,
py-cpuinfo, py-cpuinfo,
pydantic, pydantic,
quantization,
safetensors, safetensors,
tokenizers, tokenizers,
torch, torch,
@ -78,7 +79,6 @@ buildPythonPackage {
pythonRemoveDeps = [ "scipy" ]; pythonRemoveDeps = [ "scipy" ];
dependencies = [ dependencies = [
attention-kernels
awq-inference-engine awq-inference-engine
eetq eetq
causal-conv1d causal-conv1d
@ -93,22 +93,24 @@ buildPythonPackage {
grpcio-reflection grpcio-reflection
grpcio-status grpcio-status
grpcio-tools grpcio-tools
hf-kernels
hf-transfer hf-transfer
loguru loguru
mamba-ssm mamba-ssm
marlin-kernels moe
moe-kernels
opentelemetry-api opentelemetry-api
opentelemetry-exporter-otlp opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc opentelemetry-instrumentation-grpc
opentelemetry-semantic-conventions opentelemetry-semantic-conventions
outlines outlines
paged-attention
peft peft
pillow pillow
prometheus-client prometheus-client
punica-kernels punica-kernels
py-cpuinfo py-cpuinfo
pydantic pydantic
quantization
safetensors safetensors
sentencepiece sentencepiece
tokenizers tokenizers

server/hf-kernels.lock (new file, 6740 lines): diff suppressed because it is too large.

View File

@ -14,6 +14,7 @@ dependencies = [
"grpcio>=1.67.0", "grpcio>=1.67.0",
"grpcio-reflection>=1.67.0", "grpcio-reflection>=1.67.0",
"grpcio-status>=1.67.0", "grpcio-status>=1.67.0",
"hf-kernels>=0.1.5",
"hf-transfer>=0.1.8", "hf-transfer>=0.1.8",
"loguru>=0.7.3", "loguru>=0.7.3",
"numpy>=1.26,<3", "numpy>=1.26,<3",
@ -33,6 +34,15 @@ dependencies = [
"transformers>=4.48.0" "transformers>=4.48.0"
] ]
[build-system]
requires = ["hf-kernels>=0.1.2", "setuptools"]
build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]
"kernels-community/paged-attention" = ">=0.0.2"
"kernels-community/moe" = ">=0.1.1"
"kernels-community/quantization" = ">=0.0.3"
[project.scripts] [project.scripts]
text-generation-server = "text_generation_server.cli:app" text-generation-server = "text_generation_server.cli:app"
@ -60,24 +70,11 @@ quantize = [
"texttable>=1.6.7,<2", "texttable>=1.6.7,<2",
"datasets>=2.21,<3", "datasets>=2.21,<3",
] ]
moe = [ "moe-kernels" ]
attention = [ "attention-kernels" ]
marlin = [ "marlin-kernels" ]
gen = [ gen = [
"grpcio-tools>=1.69.0", "grpcio-tools>=1.69.0",
"mypy-protobuf>=3.6.0", "mypy-protobuf>=3.6.0",
] ]
[tool.uv.sources]
attention-kernels.url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
]
moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

View File

@ -1,6 +1,7 @@
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.models.globals import ( from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
@ -13,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
if SYSTEM == "cuda":
try:
paged_attention_kernels = load_kernel(
module="paged_attention", repo_id="kernels-community/paged-attention"
)
except Exception as e:
raise ImportError(
f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
)
else:
paged_attention_kernels = None
def paged_attention( def paged_attention(
query: torch.Tensor, query: torch.Tensor,
@ -107,7 +120,6 @@ def paged_attention(
if softcap is not None: if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping") raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths + seqlen.cache_lengths input_lengths = seqlen.input_lengths + seqlen.cache_lengths
import attention_kernels
out = torch.empty_like(query) out = torch.empty_like(query)
@ -117,7 +129,7 @@ def paged_attention(
max_num_partitions == 1 or num_seqs * num_heads > 512 max_num_partitions == 1 or num_seqs * num_heads > 512
) )
if use_v1: if use_v1:
attention_kernels.paged_attention_v1( paged_attention_kernels.paged_attention_v1(
out, out,
query, query,
kv_cache.key, kv_cache.key,
@ -130,8 +142,8 @@ def paged_attention(
max_s, max_s,
None, None,
kv_cache_dtype, kv_cache_dtype,
kv_scales.key_scale_cpu, torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
kv_scales.value_scale_cpu, torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
@ -148,7 +160,7 @@ def paged_attention(
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
attention_kernels.paged_attention_v2( paged_attention_kernels.paged_attention_v2(
out, out,
exp_sums, exp_sums,
max_logits, max_logits,
@ -164,8 +176,8 @@ def paged_attention(
max_s, max_s,
None, None,
kv_cache_dtype, kv_cache_dtype,
kv_scales.key_scale_cpu, torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
kv_scales.value_scale_cpu, torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
) )
return out return out
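
Per the commit messages above ("Take `load_kernel` out of a frequently-called function"), kernel resolution now happens once at import time rather than inside `paged_attention`. A minimal sketch of that pattern, assuming a CUDA system and the `utils/kernels.py` helper added later in this diff:

```python
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

# Resolve the kernel once at import time; previously `import attention_kernels`
# ran inside paged_attention() on every call.
if SYSTEM == "cuda":
    paged_attention_kernels = load_kernel(
        module="paged_attention", repo_id="kernels-community/paged-attention"
    )
else:
    paged_attention_kernels = None

# The hot path then only dispatches to the preloaded module, e.g.
# paged_attention_kernels.paged_attention_v1(...) or .paged_attention_v2(...),
# without re-importing or re-resolving the kernel per request.
```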

View File

@ -7,9 +7,22 @@ import torch
from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
if SYSTEM == "cuda":
try:
paged_attention = load_kernel(
module="paged_attention", repo_id="kernels-community/paged-attention"
)
except Exception as e:
raise ImportError(
f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
)
else:
paged_attention = None
@dataclass @dataclass
class KVScales: class KVScales:
@ -119,7 +132,7 @@ class KVCache:
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0: if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False return False
elif self.dtype == torch.float8_e4m3fn and ( elif self.dtype == torch.float8_e4m3fn and (
(ATTENTION == "flashinfer" and SYSTEM == "cuda") (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM == "rocm") or (ATTENTION == "paged" and SYSTEM == "rocm")
): ):
log_once(logger.info, "Using FP8 KV cache scales") log_once(logger.info, "Using FP8 KV cache scales")
@ -220,19 +233,19 @@ def paged_reshape_and_cache(
): ):
if SYSTEM == "cuda": if SYSTEM == "cuda":
try:
import attention_kernels
except Exception as e:
raise ImportError(
f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
)
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn: if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8" kv_cache_dtype = "fp8"
attention_kernels.reshape_and_cache( paged_attention.reshape_and_cache(
key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale key,
value,
key_cache,
value_cache,
slots,
kv_cache_dtype,
torch.tensor(k_scale),
torch.tensor(v_scale),
) )
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
try: try:

View File

@ -6,13 +6,17 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: if SYSTEM == "cuda":
import marlin_kernels quantization = load_kernel(
except ImportError: module="quantization", repo_id="kernels-community/quantization"
marlin_kernels = None )
else:
quantization = None
class W8A8IntLoader(WeightsLoader): class W8A8IntLoader(WeightsLoader):
@ -159,8 +163,8 @@ class Int8Weight(Weight):
def get_linear(self, bias: torch.Tensor): def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None: if self.weight_scale is None:
assert marlin_kernels is not None assert quantization is not None
qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight) qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
return W8A8IntLinear( return W8A8IntLinear(
bias=bias, bias=bias,
input_symmetric=self.input_symmetric, input_symmetric=self.input_symmetric,
@ -204,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
) )
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None assert quantization is not None
qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant( qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
input=input, input=input,
scale=None, scale=None,
azp=None, azp=None,
@ -214,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
) )
if self.input_symmetric: if self.input_symmetric:
return marlin_kernels.cutlass_scaled_mm( return quantization.cutlass_scaled_mm(
a=qinput, a=qinput,
b=self.weight, b=self.weight,
scale_a=input_scale, scale_a=input_scale,
@ -229,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
and (self.input_symmetric or input_zero_point is not None) and (self.input_symmetric or input_zero_point is not None)
) )
return marlin_kernels.cutlass_scaled_mm_azp( return quantization.cutlass_scaled_mm_azp(
a=qinput, a=qinput,
b=self.weight, b=self.weight,
scale_a=input_scale, scale_a=input_scale,

View File

@ -6,6 +6,7 @@ import torch
from loguru import logger from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
Weight, Weight,
WeightsLoader, WeightsLoader,
@ -14,10 +15,12 @@ from text_generation_server.utils.weights import (
) )
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
try: if SYSTEM == "cuda":
import marlin_kernels quantization = load_kernel(
except ImportError: module="quantization", repo_id="kernels-community/quantization"
marlin_kernels = None )
else:
quantization = None
try: try:
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8 from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
@ -29,9 +32,9 @@ quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
) )
if SYSTEM == "cuda" and marlin_kernels is not None: if SYSTEM == "cuda" and quantization is not None:
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8( CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
major * 10 + minor major * 10 + minor
) )
else: else:
@ -143,11 +146,10 @@ def fp8_quantize(
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification). be used without modification).
""" """
if marlin_kernels is not None: if quantization is not None:
shape = weight.shape shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant( qweight, scale = quantization.scaled_fp8_quant(
weight.reshape(-1, shape[-1]), weight.reshape(-1, shape[-1]),
dtype=quant_dtype,
scale=scale, scale=scale,
scale_ub=scale_upper_bound, scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel. # TODO: don't do this when we have to use the Torch kernel.
@ -527,7 +529,7 @@ class Fp8Linear(torch.nn.Module):
qinput, scale = fp8_quantize( qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound, scalar=False input, scale_upper_bound=self.scale_upper_bound, scalar=False
) )
return marlin_kernels.cutlass_scaled_mm( return quantization.cutlass_scaled_mm(
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
) )
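
The CUTLASS path above is a two-step flow: quantize the activation to FP8 with a dynamic scale, then call the scaled matmul from the Hub quantization kernel. A rough sketch under assumptions (a CUDA GPU that passes the `cutlass_scaled_mm_supports_fp8` check above, the locked kernel available, illustrative shapes, and a per-tensor weight scale for simplicity):

```python
import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.kernels import load_kernel

quantization = load_kernel(
    module="quantization", repo_id="kernels-community/quantization"
)

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")           # activations
weight = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # [out, in]

# Weight is quantized once (per-tensor scale for simplicity), activations per call.
qweight, w_scale = fp8_quantize(weight, scalar=True)
qinput, x_scale = fp8_quantize(x, scalar=False)

# Same positional arguments as Fp8Linear.forward above:
# (a, b, scale_a, scale_b, out_dtype, bias).
out = quantization.cutlass_scaled_mm(
    qinput, qweight.t(), x_scale, w_scale, x.dtype, None
)
```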

View File

@ -8,11 +8,15 @@ from text_generation_server.layers.marlin.util import (
_check_marlin_kernels, _check_marlin_kernels,
permute_scales, permute_scales,
) )
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
try: if SYSTEM == "cuda":
import marlin_kernels quantization = load_kernel(
except ImportError: module="quantization", repo_id="kernels-community/quantization"
marlin_kernels = None )
else:
quantization = None
MARLIN_TILE_SIZE = 16 MARLIN_TILE_SIZE = 16
@ -32,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
super().__init__() super().__init__()
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert quantization is not None
scales = scales.unsqueeze(0) scales = scales.unsqueeze(0)
if scales.shape[1] == 1: if scales.shape[1] == 1:
@ -69,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias) return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor: def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None assert quantization is not None
A_flat = A.view(-1, A.shape[-1]) A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm( C = quantization.fp8_marlin_gemm(
A_flat, A_flat,
self.qweight, self.qweight,
self.scales, self.scales,
@ -134,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
qweight = pack_fp8_as_int32(weight.t()) qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device) perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack( repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8 qweight, perm, in_features, out_features, 8
) )

View File

@ -12,13 +12,17 @@ from text_generation_server.layers.marlin.util import (
unpack_cols, unpack_cols,
) )
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: if SYSTEM == "cuda":
import marlin_kernels quantization = load_kernel(
except ImportError: module="quantization", repo_id="kernels-community/quantization"
marlin_kernels = None )
else:
quantization = None
try: try:
major, _minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
@ -37,7 +41,7 @@ def can_use_gptq_marlin(
) -> bool: ) -> bool:
return ( return (
SYSTEM == "cuda" SYSTEM == "cuda"
and marlin_kernels is not None and quantization is not None
and has_sm_8_0 and has_sm_8_0
and quantize in {"awq", "gptq"} and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"} and quant_method in {"awq", "gptq"}
@ -287,7 +291,7 @@ def repack_gptq_for_marlin(
) -> GPTQMarlinWeight: ) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert quantization is not None
if bits not in GPTQ_MARLIN_BITS: if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
@ -330,7 +334,7 @@ def repack_gptq_for_marlin(
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
if quant_method == "awq": if quant_method == "awq":
repacked = marlin_kernels.awq_marlin_repack( repacked = quantization.awq_marlin_repack(
qweight, in_features, out_features, bits qweight, in_features, out_features, bits
) )
if qzeros is not None: if qzeros is not None:
@ -342,7 +346,7 @@ def repack_gptq_for_marlin(
) )
else: else:
repacked = marlin_kernels.gptq_marlin_repack( repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits qweight, perm, in_features, out_features, bits
) )
@ -379,13 +383,26 @@ class GPTQMarlinLinear(nn.Module):
super().__init__() super().__init__()
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert quantization is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1] out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features) _check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits if weight.bits not in (4, 8):
raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")
if weight.qzeros.numel() > 0:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4
else:
self.quant_type = quantization.scalar_types.uint8
else:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
self.is_full_k = weight.is_full_k self.is_full_k = weight.is_full_k
self.qweight = weight.qweight self.qweight = weight.qweight
@ -403,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
) )
def forward(self, A: torch.Tensor) -> torch.Tensor: def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None assert quantization is not None
A_flat = A.view(-1, A.shape[-1]) A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm( C = quantization.gptq_marlin_gemm(
A_flat, A_flat,
self.qweight, self.qweight,
self.scales, self.scales,
@ -414,7 +431,7 @@ class GPTQMarlinLinear(nn.Module):
self.g_idx, self.g_idx,
self.perm, self.perm,
self.workspace, self.workspace,
self.bits, self.quant_type,
A_flat.shape[0], A_flat.shape[0],
self.scales.shape[1], self.scales.shape[1],
A_flat.shape[1], A_flat.shape[1],

View File

@ -3,13 +3,18 @@ from typing import List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: if SYSTEM == "cuda":
import marlin_kernels quantization = load_kernel(
except ImportError: module="quantization", repo_id="kernels-community/quantization"
marlin_kernels = None )
else:
quantization = None
class MarlinWeightsLoader(WeightsLoader): class MarlinWeightsLoader(WeightsLoader):
@ -187,7 +192,7 @@ class MarlinLinear(nn.Module):
super().__init__() super().__init__()
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert quantization is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1] out_features = weight.s.shape[1]
@ -216,9 +221,9 @@ class MarlinLinear(nn.Module):
) )
def forward(self, A: torch.Tensor) -> torch.Tensor: def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None assert quantization is not None
C = marlin_kernels.marlin_gemm( C = quantization.marlin_gemm(
A.view(-1, A.shape[-1]), A.view(-1, A.shape[-1]),
self.B, self.B,
self.s, self.s,
@ -277,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
super().__init__() super().__init__()
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert quantization is not None
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join( supported_bits = ", ".join(
@ -303,8 +308,11 @@ class GPTQMarlin24Linear(nn.Module):
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
) )
self.bits = weight.bits if weight.bits == 4:
weights_per_int32 = 32 // self.bits self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
weights_per_int32 = 32 // weight.bits
assert ( assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
@ -336,15 +344,15 @@ class GPTQMarlin24Linear(nn.Module):
) )
def forward(self, A: torch.Tensor) -> torch.Tensor: def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None assert quantization is not None
C = marlin_kernels.gptq_marlin_24_gemm( C = quantization.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]), A.view(-1, A.shape[-1]),
self.weight_packed, self.weight_packed,
self.meta, self.meta,
self.scale_packed, self.scale_packed,
self.workspace, self.workspace,
self.bits, self.quant_type,
A.shape[0], A.shape[0],
self.scale_packed.shape[1], self.scale_packed.shape[1],
A.shape[1], A.shape[1],

View File

@ -4,11 +4,14 @@ from typing import List, Tuple
import numpy import numpy
import torch import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
try: if SYSTEM == "cuda":
import marlin_kernels quantization = load_kernel(
except ImportError: module="quantization", repo_id="kernels-community/quantization"
marlin_kernels = None )
else:
quantization = None
try: try:
major, _minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
@ -23,7 +26,7 @@ def _check_marlin_kernels():
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
) )
if marlin_kernels is None: if quantization is None:
raise NotImplementedError( raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin" "marlin is not installed, install it with: pip install server/marlin"
) )

View File

@ -18,6 +18,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
DefaultWeightsLoader, DefaultWeightsLoader,
@ -27,6 +28,10 @@ from text_generation_server.utils.weights import (
if SYSTEM == "ipex": if SYSTEM == "ipex":
from .fused_moe_ipex import fused_topk, grouped_topk from .fused_moe_ipex import fused_topk, grouped_topk
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
fused_topk = moe_kernels.fused_topk
grouped_topk = moe_kernels.grouped_topk
else: else:
from moe_kernels.fused_moe import fused_topk, grouped_topk from moe_kernels.fused_moe import fused_topk, grouped_topk
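
With the Hub `moe` kernel in place, routing helpers such as `fused_topk` come from the loaded module on CUDA. A small illustrative call, assuming a CUDA device with the locked kernel available; the shapes are illustrative ([tokens, hidden] activations, [tokens, experts] router logits):

```python
import torch
from text_generation_server.utils.kernels import load_kernel

moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")

num_tokens, hidden_size, num_experts, topk = 4, 1024, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

# Same keyword arguments as the calls in the MoE layers below.
topk_weights, topk_ids = moe_kernels.fused_topk(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=topk,
    renormalize=True,
)
# topk_weights: [num_tokens, topk] routing weights; topk_ids: selected expert indices.
```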

View File

@ -12,7 +12,7 @@ from text_generation_server.layers.fp8 import (
) )
try: try:
from moe_kernels.fused_moe import fused_moe from .unquantized import fused_moe
except Exception: except Exception:
fused_moe = None fused_moe = None

View File

@ -1,10 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Callable, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.layers import moe
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import ( from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight, GPTQMarlinWeight,
@ -12,9 +14,9 @@ from text_generation_server.layers.marlin.gptq import (
) )
if SYSTEM == "cuda": if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else: else:
fused_marlin_moe = None moe_kernels = None
try: try:
@ -32,7 +34,7 @@ def can_use_marlin_moe_gemm(
): ):
return ( return (
SYSTEM == "cuda" SYSTEM == "cuda"
and fused_marlin_moe is not None and moe is not None
and has_sm_8_0 and has_sm_8_0
and quantize in {"awq", "gptq"} and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"} and quant_method in {"awq", "gptq"}
@ -230,3 +232,111 @@ def _pack_weight(
moe_weight.perm[expert] = weight.perm moe_weight.perm[expert] = weight.perm
return moe_weight return moe_weight
def fused_marlin_moe(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
sort_indices1: torch.Tensor,
sort_indices2: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
is_k_full: bool,
topk: int,
renormalize: bool,
num_bits: int = 8,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- sort_indices1 (torch.Tensor): The first act_order input permutation.
- sort_indices2 (torch.Tensor): The second act_order input permutation.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2
), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = moe_kernels.grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = moe_kernels.fused_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
return moe_kernels.fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
gating_output=gating_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=w1_zeros,
w2_zeros=w2_zeros,
num_bits=num_bits,
is_k_full=is_k_full,
)

View File

@ -1,15 +1,18 @@
from typing import Optional from typing import Callable, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import UnquantizedWeight, Weights from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "ipex": if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else: else:
from moe_kernels.fused_moe import fused_moe import moe_kernels
class UnquantizedSparseMoELayer(nn.Module): class UnquantizedSparseMoELayer(nn.Module):
@ -63,7 +66,17 @@ class UnquantizedSparseMoELayer(nn.Module):
) )
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "ipex": if SYSTEM == "rocm":
return moe_kernels.fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe( return self.ipex_fused_moe(
hidden_states=x, hidden_states=x,
router_logits=gating_output, router_logits=gating_output,
@ -146,3 +159,110 @@ def _load_expert_weights_row(
assert all_weight is not None assert all_weight is not None
return all_weight return all_weight
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
from loguru import logger
import inspect
logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
topk_weights, topk_ids = moe_kernels.grouped_topk(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
elif custom_routing_function is None:
topk_weights, topk_ids = moe_kernels.fused_topk(
hidden_states, gating_output, topk, renormalize
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize
)
return moe_kernels.fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
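
A rough usage sketch for the `fused_moe` wrapper above, assuming a CUDA device with the locked `moe` kernel available. The weight layouts below are an assumption for illustration (the usual fused-MoE convention: `w1` holds the fused gate/up projections, `w2` the down projection), not something stated in this diff:

```python
import torch
from text_generation_server.layers.moe.unquantized import fused_moe

n_tokens, hidden, inter, n_experts, topk = 4, 1024, 2816, 8, 2

x = torch.randn(n_tokens, hidden, dtype=torch.float16, device="cuda")
gating_output = torch.randn(n_tokens, n_experts, dtype=torch.float32, device="cuda")
# Assumed layouts: w1 = [n_experts, 2 * inter, hidden], w2 = [n_experts, hidden, inter].
w1 = torch.randn(n_experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")
w2 = torch.randn(n_experts, hidden, inter, dtype=torch.float16, device="cuda")

out = fused_moe(x, w1, w2, gating_output, topk, renormalize=True)
# out: [n_tokens, hidden] MoE output (unquantized fp16 path).
```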

View File

@ -22,11 +22,14 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
if SYSTEM == "ipex": if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else: else:
from moe_kernels.fused_moe import fused_moe import moe_kernels
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -510,7 +513,7 @@ class BlockSparseMoE(nn.Module):
topk_group=None, topk_group=None,
) )
else: else:
out = fused_moe( out = moe_kernels.fused_moe(
x, x,
self.wv1, self.wv1,
self.w2, self.w2,

View File

@ -0,0 +1,22 @@
import importlib
from loguru import logger
from hf_kernels import load_kernel as hf_load_kernel
from text_generation_server.utils.log import log_once
def load_kernel(*, module: str, repo_id: str):
"""
Load a kernel. First try to load it as the given module (e.g. for
local development), falling back to a locked Hub kernel.
"""
try:
m = importlib.import_module(module)
log_once(logger.info, f"Using local module for `{module}`")
return m
except ModuleNotFoundError:
return hf_load_kernel(repo_id=repo_id)
__all__ = ["load_kernel"]
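
One usage note for this helper: call sites pass both an importable module name and the Hub repository it falls back to, so a locally installed development build of a kernel takes precedence over the locked Hub version. For example (mirroring the call sites elsewhere in this diff):

```python
from text_generation_server.utils.kernels import load_kernel

# If a local `quantization` package is importable (e.g. installed with
# `pip install -e .` from a development checkout), it is used; otherwise the
# locked kernels-community/quantization kernel is loaded from the Hub cache.
quantization = load_kernel(
    module="quantization", repo_id="kernels-community/quantization"
)
```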

View File

@ -168,20 +168,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 },
] ]
[[package]]
name = "attention-kernels"
version = "0.2.0.post2"
source = { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" }
dependencies = [
{ name = "torch" },
]
wheels = [
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl", hash = "sha256:863e02dda4b30e9d04ef6cf4d17d16c154f54bdcb8a8b87b8b46075eabf62d25" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]] [[package]]
name = "attrs" name = "attrs"
version = "24.3.0" version = "24.3.0"
@ -676,6 +662,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/f6b198152399d17247d962340947728fb1b06da6bc0c0a542446b2ffee49/grpcio_tools-1.69.0-cp39-cp39-win_amd64.whl", hash = "sha256:5d47abf7e0662dd5dbb9cc252c3616e5fbc5f71d34e3f6332cd24bcdf2940abd", size = 1114931 }, { url = "https://files.pythonhosted.org/packages/64/51/f6b198152399d17247d962340947728fb1b06da6bc0c0a542446b2ffee49/grpcio_tools-1.69.0-cp39-cp39-win_amd64.whl", hash = "sha256:5d47abf7e0662dd5dbb9cc252c3616e5fbc5f71d34e3f6332cd24bcdf2940abd", size = 1114931 },
] ]
[[package]]
name = "hf-kernels"
version = "0.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "packaging" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/01/fe/5aa3ea1b66bcc7d81aff19683ea04d4a9cd414c8d4ff05b150fc1f196ccd/hf_kernels-0.1.6.tar.gz", hash = "sha256:5effee5046552ce226ff86d3870a799f4ecae399bcb2beb4046c28c2dd736d2f", size = 8704 }
[[package]] [[package]]
name = "hf-transfer" name = "hf-transfer"
version = "0.1.9" version = "0.1.9"
@ -906,86 +904,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 }, { url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 },
] ]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.13'",
]
dependencies = [
{ name = "torch", marker = "python_full_version >= '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b2/82/886d1eece474ef23668c4780f5053ea654999704a0195aadc651631b740d/marlin-kernels-0.3.7.tar.gz", hash = "sha256:8be8a65fd9ae21b2406afba9e460e3922582479b85a1372096e87e3a15684a77", size = 15662 }
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl" }
resolution-markers = [
"python_full_version == '3.10.*'",
]
dependencies = [
{ name = "torch", marker = "python_full_version == '3.10.*'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", hash = "sha256:dd91a4e2c3b5e954833c5c34b0322e4c02cd92a967eb94654b6bbcece131340b" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl" }
resolution-markers = [
"python_full_version == '3.11.*'",
]
dependencies = [
{ name = "torch", marker = "python_full_version == '3.11.*'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", hash = "sha256:b24d92135fbd156c55ce43158ab4a90fa880ba0df965528895cf1870b03a64bf" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl" }
resolution-markers = [
"python_full_version == '3.12.*'",
]
dependencies = [
{ name = "torch", marker = "python_full_version == '3.12.*'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", hash = "sha256:8a407f1435a571a8d4ca3b9f533da83fde323043a9836b739cf8018c77782d49" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]]
name = "marlin-kernels"
version = "0.3.7"
source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl" }
resolution-markers = [
"python_full_version < '3.10'",
]
dependencies = [
{ name = "torch", marker = "python_full_version < '3.10'" },
]
wheels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", hash = "sha256:bf7003753c364c504b3998fffdfcf619a42ab04f908903dbad8d54347b6b142b" },
]
[package.metadata]
requires-dist = [{ name = "torch" }]
[[package]] [[package]]
name = "mdurl" name = "mdurl"
version = "0.1.2" version = "0.1.2"
@ -995,26 +913,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
] ]
[[package]]
name = "moe-kernels"
version = "0.8.2"
source = { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" }
dependencies = [
{ name = "nvidia-ml-py" },
{ name = "torch" },
{ name = "triton" },
]
wheels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl", hash = "sha256:1ed5b26f52339d25ea2513e99e8b6239cf1921af3eac54e03a46bb8f8efb380b" },
]
[package.metadata]
requires-dist = [
{ name = "nvidia-ml-py" },
{ name = "torch" },
{ name = "triton" },
]
[[package]] [[package]]
name = "mpmath" name = "mpmath"
version = "1.3.0" version = "1.3.0"
@ -1308,6 +1206,7 @@ name = "nvidia-cublas-cu12"
version = "12.4.5.8" version = "12.4.5.8"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 },
{ url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 },
] ]
@ -1316,6 +1215,7 @@ name = "nvidia-cuda-cupti-cu12"
version = "12.4.127" version = "12.4.127"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 },
{ url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 },
] ]
@ -1324,6 +1224,7 @@ name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127" version = "12.4.127"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 },
{ url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 },
] ]
@ -1332,6 +1233,7 @@ name = "nvidia-cuda-runtime-cu12"
version = "12.4.127" version = "12.4.127"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 },
{ url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 },
] ]
@ -1354,6 +1256,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 },
{ url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
] ]
@ -1362,6 +1265,7 @@ name = "nvidia-curand-cu12"
version = "10.3.5.147" version = "10.3.5.147"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 },
{ url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 },
] ]
@ -1375,6 +1279,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 },
{ url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
] ]
@ -1386,18 +1291,10 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 },
{ url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
] ]
[[package]]
name = "nvidia-ml-py"
version = "12.560.30"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/53/10/5f34de4a71db8b2b7ec4269f4a33287f24c23e2857ea3187c977b7bc3604/nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97", size = 39194 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/f3/a69ce0b1a1e12fbf6b2ad9f4c14c9999fdbdf15f2478d210f0fd501ddc98/nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73", size = 40526 },
]
[[package]]
name = "nvidia-nccl-cu12"
version = "2.21.5"
@ -1411,6 +1308,7 @@ name = "nvidia-nvjitlink-cu12"
version = "12.4.127" version = "12.4.127"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 },
{ url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 },
] ]
@ -1419,6 +1317,7 @@ name = "nvidia-nvtx-cu12"
version = "12.4.127" version = "12.4.127"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 },
{ url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
] ]
@ -2653,6 +2552,7 @@ dependencies = [
{ name = "grpcio" }, { name = "grpcio" },
{ name = "grpcio-reflection" }, { name = "grpcio-reflection" },
{ name = "grpcio-status" }, { name = "grpcio-status" },
{ name = "hf-kernels" },
{ name = "hf-transfer" }, { name = "hf-transfer" },
{ name = "loguru" }, { name = "loguru" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
@ -2678,9 +2578,6 @@ dependencies = [
accelerate = [
{ name = "accelerate" },
]
attention = [
{ name = "attention-kernels" },
]
bnb = [
{ name = "bitsandbytes" },
]
@ -2695,16 +2592,6 @@ gen = [
{ name = "grpcio-tools" }, { name = "grpcio-tools" },
{ name = "mypy-protobuf" }, { name = "mypy-protobuf" },
] ]
marlin = [
{ name = "marlin-kernels", version = "0.3.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl" }, marker = "python_full_version == '3.10.*'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl" }, marker = "python_full_version == '3.11.*'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl" }, marker = "python_full_version == '3.12.*'" },
{ name = "marlin-kernels", version = "0.3.7", source = { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl" }, marker = "python_full_version < '3.10'" },
]
moe = [
{ name = "moe-kernels" },
]
outlines = [
{ name = "outlines" },
]
@ -2719,7 +2606,6 @@ quantize = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'accelerate'", specifier = ">=1.2.1,<2" },
{ name = "attention-kernels", marker = "extra == 'attention'", url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" },
{ name = "bitsandbytes", marker = "extra == 'bnb'", specifier = ">=0.45.0" }, { name = "bitsandbytes", marker = "extra == 'bnb'", specifier = ">=0.45.0" },
{ name = "compressed-tensors", marker = "extra == 'compressed-tensors'", specifier = ">=0.9.0" }, { name = "compressed-tensors", marker = "extra == 'compressed-tensors'", specifier = ">=0.9.0" },
{ name = "datasets", marker = "extra == 'quantize'", specifier = ">=2.21,<3" }, { name = "datasets", marker = "extra == 'quantize'", specifier = ">=2.21,<3" },
@ -2730,14 +2616,9 @@ requires-dist = [
{ name = "grpcio-status", specifier = ">=1.67.0" }, { name = "grpcio-status", specifier = ">=1.67.0" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.51.1,<2.0" }, { name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.51.1,<2.0" },
{ name = "grpcio-tools", marker = "extra == 'gen'", specifier = ">=1.69.0" }, { name = "grpcio-tools", marker = "extra == 'gen'", specifier = ">=1.69.0" },
{ name = "hf-kernels", specifier = ">=0.1.5" },
{ name = "hf-transfer", specifier = ">=0.1.8" }, { name = "hf-transfer", specifier = ">=0.1.8" },
{ name = "loguru", specifier = ">=0.7.3" }, { name = "loguru", specifier = ">=0.7.3" },
{ name = "marlin-kernels", marker = "python_full_version == '3.9.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl" },
{ name = "marlin-kernels", marker = "(python_full_version < '3.9' and extra == 'marlin') or (python_full_version >= '3.13' and extra == 'marlin')" },
{ name = "marlin-kernels", marker = "python_full_version == '3.10.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl" },
{ name = "marlin-kernels", marker = "python_full_version == '3.11.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl" },
{ name = "marlin-kernels", marker = "python_full_version == '3.12.*' and extra == 'marlin'", url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl" },
{ name = "moe-kernels", marker = "extra == 'moe'", url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl" },
{ name = "mypy-protobuf", marker = "extra == 'gen'", specifier = ">=3.6.0" }, { name = "mypy-protobuf", marker = "extra == 'gen'", specifier = ">=3.6.0" },
{ name = "numpy", specifier = ">=1.26,<3" }, { name = "numpy", specifier = ">=1.26,<3" },
{ name = "opentelemetry-api", specifier = ">=1.27.0" }, { name = "opentelemetry-api", specifier = ">=1.27.0" },
@ -2919,7 +2800,7 @@ name = "triton"
version = "3.1.0" version = "3.1.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "filelock" }, { name = "filelock", marker = "python_full_version < '3.13'" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 }, { url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 },