From 58934c8b61e9f3cb7316c9e61ce528819e354853 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 16 May 2025 11:48:58 -0400 Subject: [PATCH 1/9] fix: count gpu uuids if NVIDIA_VISIBLE_DEVICES env set to all (#3230) --- launcher/src/main.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a82ad12f..ee80eb00 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1263,7 +1263,23 @@ fn num_cuda_devices() -> Option { let devices = match env::var("CUDA_VISIBLE_DEVICES") { Ok(devices) => devices, Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") { - Ok(devices) => devices, + Ok(devices) => { + if devices.trim() == "all" { + // Count the number of all GPUs via nvidia-smi + let output = Command::new("nvidia-smi") + .args(["--query-gpu=uuid", "--format=csv,noheader"]) + .output() + .ok()?; + + String::from_utf8_lossy(&output.stdout) + .lines() + .filter(|line| !line.trim().is_empty()) + .count() + .to_string() + } else { + devices + } + } Err(_) => env::var("ZE_AFFINITY_MASK").ok()?, }, }; From d658b5def3fe6c32b09b4ffe36f770ba2aa959b4 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 19 May 2025 22:36:39 +0800 Subject: [PATCH 2/9] Deepseek R1 for Gaudi backend (#3211) Signed-off-by: Wang, Yi A --- Dockerfile_gaudi | 5 +- .../server/text_generation_server/cli.py | 10 +- .../text_generation_server/layers/__init__.py | 2 + .../layers/attention/__init__.py | 5 +- .../layers/attention/hpu.py | 110 +++++++- .../layers/attention/kv_cache.py | 80 +++++- .../text_generation_server/layers/fp8.py | 243 ++++++++++++++++-- .../layers/gptq/__init__.py | 26 +- .../layers/layernorm.py | 17 +- .../text_generation_server/layers/moe/fp8.py | 143 +++++++++-- .../layers/moe/fused_moe.py | 82 +++++- .../layers/moe/unquantized.py | 28 +- .../text_generation_server/layers/rotary.py | 4 +- .../text_generation_server/models/__init__.py | 10 +- .../custom_modeling/flash_cohere_modeling.py | 8 +- .../custom_modeling/flash_dbrx_modeling.py | 7 +- .../flash_deepseek_v2_modeling.py | 6 + .../flash_deepseek_v3_modeling.py | 157 ++++++++--- .../custom_modeling/flash_gemma2_modeling.py | 7 + .../custom_modeling/flash_gemma_modeling.py | 6 + .../custom_modeling/flash_gpt2_modeling.py | 7 + .../custom_modeling/flash_gptj_modeling.py | 6 + .../custom_modeling/flash_llama_modeling.py | 7 +- .../custom_modeling/flash_mistral_modeling.py | 6 + .../custom_modeling/flash_mixtral_modeling.py | 6 + .../custom_modeling/flash_neox_modeling.py | 6 + .../custom_modeling/flash_phi_modeling.py | 6 + .../custom_modeling/flash_phi_moe_modeling.py | 1 - .../custom_modeling/flash_qwen2_modeling.py | 6 + .../custom_modeling/flash_rw_modeling.py | 6 + .../flash_santacoder_modeling.py | 6 + .../flash_starcoder2_modeling.py | 6 + .../models/flash_causal_lm.py | 157 +++++++---- .../models/flash_vlm_causal_lm.py | 21 +- .../models/mllama_causal_lm.py | 44 +++- .../server/text_generation_server/server.py | 4 + .../text_generation_server/utils/dist.py | 2 +- .../utils/import_utils.py | 18 +- .../utils/quantization.py | 65 +++-- .../text_generation_server/utils/weights.py | 15 ++ backends/v3/src/queue.rs | 20 ++ 41 files changed, 1133 insertions(+), 238 deletions(-) diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 06073fe4..54a0bb7c 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -60,6 +60,8 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor ENV ATTENTION=default ENV PREFIX_CACHING=0 ENV PREFILL_CHUNKING=0 +ENV 
PT_HPU_LAZY_MODE=1 +ENV PT_HPU_WEIGHT_SHARING=0 # Text Generation Inference base env ENV HF_HOME=/data \ @@ -95,7 +97,8 @@ RUN cd server && \ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . --no-cache-dir -RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git +RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794 + # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 53837ef7..b1a41534 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -26,6 +26,11 @@ class Dtype(str, Enum): bloat16 = "bfloat16" +class KVCacheDtype(str, Enum): + fp8_e4m3fn = "fp8_e4m3fn" + fp8_e5m2 = "fp8_e5m2" + + @app.command() def serve( model_id: str, @@ -34,6 +39,7 @@ def serve( quantize: Optional[Quantization] = None, speculate: Optional[int] = None, dtype: Optional[Dtype] = None, + kv_cache_dtype: Optional[KVCacheDtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", logger_level: str = "INFO", @@ -93,7 +99,8 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = "bfloat16" if dtype is None else dtype.value - logger.info(f"quantize={quantize}") + kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value + logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}") if dtype is not None and quantize not in { None, "bitsandbytes", @@ -175,6 +182,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, uds_path, max_input_tokens, diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py index 0000ca91..fd146728 100644 --- a/backends/gaudi/server/text_generation_server/layers/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/__init__.py @@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead # Just to add the `load` methods. from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d +from text_generation_server.layers.fp8 import Fp8Linear from text_generation_server.layers.lora import ( LoraLinear, @@ -27,6 +28,7 @@ __all__ = [ "TensorParallelEmbedding", "SpeculativeHead", "LoraLinear", + "Fp8Linear", "TensorParallelMultiAdapterLinear", "TensorParallelAdapterRowLinear", "load_layer_norm", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 89a43d65..370e05bc 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -10,18 +10,21 @@ from .hpu import ( SUPPORTS_WINDOWING, attention, paged_attention, + paged_attention_mla, ) # KVCache needs `reshape_and_cache`, so ensure that it is defined already. 
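# --- illustrative aside, not part of the patch -----------------------------
# The new `kv_cache_dtype` CLI option added above is only a string enum at
# this point; it is resolved to a torch dtype later in models/__init__.py
# (see that hunk further below). A minimal standalone sketch of the same
# mapping, assuming the enum values defined in cli.py:
import torch
from typing import Optional

def resolve_kv_cache_dtype(kv_cache_dtype: Optional[str], dtype: torch.dtype) -> torch.dtype:
    if kv_cache_dtype == "fp8_e4m3fn":
        return torch.float8_e4m3fn
    if kv_cache_dtype == "fp8_e5m2":
        # kv_cache.py below rejects this dtype on HPU with a ValueError.
        return torch.float8_e5m2
    # Default: keep the model dtype (bfloat16 on Gaudi).
    return dtype
# ---------------------------------------------------------------------------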
-from .kv_cache import KVCache, get_kv_scales +from .kv_cache import KVCache, get_kv_scales, KVCompressCache __all__ = [ "attention", "get_kv_scales", "paged_attention", + "paged_attention_mla", "SUPPORTS_WINDOWING", "KVCache", + "KVCompressCache", "Seqlen", "HPUPagedAttentionMetadata", "trim_seqlen_metadata", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 1d73dcb3..1c2e37c7 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -11,11 +11,61 @@ import os SUPPORTS_WINDOWING = False -def fetch_from_cache(cache, blocks): - if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": - return cache[: blocks.size(0)] - else: - return cache.index_select(0, blocks) +class FP8Matmul(torch.nn.Module): + + def __init__(self, scale_other): + super().__init__() + self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu") + self.scale_other = scale_other + + def quant_input(self, x, scale): + return torch.ops.hpu.cast_to_fp8_v2( + x, scale, False, False, torch.float8_e4m3fn + )[0] + + def matmul_fp8( + self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None + ): + return torch.ops.hpu.fp8_gemm_v2( + A=x, + trans_A=False, + B=other, + trans_B=False, + D=None, + out_dtype=out_dtype, + A_scale_inv=scale_input_inv, + B_scale_inv=scale_other_inv, + bias=None, + accumulate=False, + ) + + def forward(self, input, other): + qinput = self.quant_input(input, self.scale_input) + qother = self.quant_input(other, self.scale_other) + output = self.matmul_fp8( + qinput, + qother, + out_dtype=torch.bfloat16, + scale_input_inv=1.0 / self.scale_input, + scale_other_inv=1.0 / self.scale_other, + ) + return output + + +class FetchFromCache(torch.nn.Module): + + def __init__(self, scale_inv): + super().__init__() + self.scale_inv = scale_inv + + def forward(self, cache, blocks): + if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": + out = cache[: blocks.size(0)] + else: + out = cache.index_select(0, blocks) + if out.dtype == torch.float8_e4m3fn: + out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16) + return out def attention( @@ -67,6 +117,7 @@ def paged_attention( hpu_attention_meta: HPUPagedAttentionMetadata, ): batch_size, head_num, head_size = query.shape + fp8_kv = kv_cache.dtype == torch.float8_e4m3fn output = ops.flat_pa( query=query.view(batch_size, 1, head_num * head_size), key_cache=kv_cache.key, @@ -76,19 +127,50 @@ def paged_attention( block_bias=hpu_attention_meta.attn_bias, block_groups=hpu_attention_meta.block_groups, scale=softmax_scale, - matmul_qk_op=Matmul(), - matmul_av_op=Matmul(), + matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), + matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), batch2block_matmul_op=Matmul(), block2batch_matmul_op=Matmul(), - keys_fetch_func=fetch_from_cache, - values_fetch_func=fetch_from_cache, + keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu), + values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu), ) # Reshape the output tensor. 
return output.view(batch_size, head_num, head_size) -__all__ = [ - "SUPPORTS_WINDOWING", - "attention", - "paged_attention", -] +def paged_attention_mla( + query: torch.Tensor, + kv_cache: KVCache, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + seqlen: Seqlen, + *, + kv_scales: KVScales, + softcap: Optional[float] = None, + hpu_attention_meta: HPUPagedAttentionMetadata, + kv_lora_rank: int = 0, +): + batch_size, head_num, head_size = query.shape + fp8_kv = kv_cache.dtype == torch.float8_e4m3fn + output = ops.flat_pa_mla( + query=query, + key_cache=kv_cache.key, + value_cache=None, + block_list=hpu_attention_meta.block_list, + block_mapping=hpu_attention_meta.block_mapping, + block_bias=hpu_attention_meta.attn_bias, + block_groups=hpu_attention_meta.block_groups, + scale=softmax_scale, + matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), + matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), + batch2block_matmul_op=Matmul(), + block2batch_matmul_op=Matmul(), + keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu), + values_fetch_func=None, + kv_lora_rank=kv_lora_rank, + ) + # Reshape the output tensor. + return output.view(batch_size, head_num, -1) + + +__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py index d238cdb9..cdd1e1d7 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -50,6 +50,8 @@ class KVCache: ): """Construct the key-value cache for a layer.""" ## TODO FP8 kv cache support + if dtype is torch.float8_e5m2: + raise ValueError("torch.float8_e5m2 is not supported in hpu. ") self.kv_cache = ( torch.zeros( @@ -101,22 +103,92 @@ class KVCache: key_cache, value_cache, slots, - kv_scales.key_scale_cpu, - kv_scales.value_scale_cpu, + kv_scales.key_scale, + kv_scales.value_scale, ) +class KVCompressCache(KVCache): + """ + Key-value cache for attention layers. + """ + + kv_cache: torch.Tensor + + def __init__( + self, + *, + num_blocks: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + ## TODO FP8 kv cache support + if dtype is torch.float8_e5m2: + raise ValueError("torch.float8_e5m2 is not supported in hpu. 
") + + self.kv_cache = torch.zeros( + (num_blocks, BLOCK_SIZE, 1, head_size), + dtype=dtype, + device=device, + ) + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache.dtype + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + kv_scales: KVScales, + ): + """Store the key and value at the given slots.""" + ## TODO FP8 kv cache support + + block_idx = slots // BLOCK_SIZE + block_offset = slots % BLOCK_SIZE + if self.kv_cache.dtype == torch.float8_e4m3fn: + key = torch.ops.hpu.cast_to_fp8_v2( + key, kv_scales.key_scale, False, False, torch.float8_e4m3fn + )[0] + cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset) + + def paged_reshape_and_cache( key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ): block_idx = slots // BLOCK_SIZE block_offset = slots % BLOCK_SIZE + if key_cache.dtype == torch.float8_e4m3fn: + key = torch.ops.hpu.cast_to_fp8_v2( + key, k_scale, False, False, torch.float8_e4m3fn + )[0] + value = torch.ops.hpu.cast_to_fp8_v2( + value, v_scale, False, False, torch.float8_e4m3fn + )[0] cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 0dc5cdaf..44d30202 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -12,11 +12,151 @@ from text_generation_server.utils.weights import ( from vllm_hpu_extension.ops import scaled_fp8_quant from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 -import habana_frameworks.torch.utils.experimental as htexp -w8a8_block_fp8_matmul = None -per_token_group_quant_fp8 = None quant_dtype: torch.dtype = torch.float8_e4m3fn +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max +if is_hpu_gaudi2(): + FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max + + +def pad_weight(weight, block_size): + """Pads a matrix to make its dimensions multiples of block_size.""" + M, N = weight.shape[-2:] + block_size_m, block_size_n = block_size + pad_M = (block_size_m - M % block_size_m) % block_size_m + pad_N = (block_size_n - N % block_size_n) % block_size_n + + if pad_M == 0 and pad_N == 0: + return weight, M, N # No padding needed + padded_weight = torch.nn.functional.pad( + weight, (0, pad_N, 0, pad_M), mode="constant", value=0 + ) + return padded_weight, M, N # Return original dimensions for unpadding + + +def unpad_weight(weight, original_M, original_N, keep_first_dim=False): + """Removes padding from the matrix to restore its original shape.""" + if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N): + return weight + if keep_first_dim: + return weight[:, :original_M, :original_N] + else: + return weight[:original_M, :original_N] + + +def pad_block_fp8_weight_naive(weight, weight_scale, block_size): + + assert len(block_size) == 2 + + block_size_m, block_size_n = block_size + weight_scale_m, weight_scale_n = weight_scale.shape[-2:] + + weight, orig_M, orig_N = pad_weight(weight, 
block_size) + M, N = weight.shape[-2:] + + assert weight_scale_m == M // block_size_m + assert weight_scale_n == N // block_size_n + + return weight, orig_M, orig_N + + +def dynamic_quant(data, single_scale=False): + if single_scale: + scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX + else: + scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX + scale = scale.unsqueeze(-1) + data_fp8 = torch.ops.hpu.cast_to_fp8_v2( + data, 1.0 / scale, False, False, torch.float8_e4m3fn + )[0] + return data_fp8, scale.float() + + +def dequant_block_fp8_weight_naive( + weight, + weight_scale, + block_size, + dtype=torch.bfloat16, + original_M=None, + original_N=None, + do_unpad=False, +): + if weight_scale is None: + return weight + assert len(block_size) == 2 + + weight_shape_len = len(weight.shape) + + block_size_m, block_size_n = block_size + + # mul scale + if weight_shape_len == 2: + weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = False + elif weight_shape_len == 3: + fd, weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view( + fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n + ) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = True + else: + raise ValueError("Only support original weight shape is either 2 or 3") + + if do_unpad: + dequant_weight = unpad_weight( + dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim + ) + + return dequant_weight + + +def apply_block_fp8_linear_hpu_dynamic( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + x_fp8, x_scale = dynamic_quant(input_2d) + + output = torch.ops.hpu.fp8_gemm_v2( + x_fp8, + False, + weight, + True, + None, + torch.bfloat16, + x_scale, + weight_scale, + None, + False, + ) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: @@ -42,7 +182,7 @@ def per_tensor_dequantize( ) -> torch.Tensor: device = tensor.device dtype = torch.bfloat16 - if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + if is_hpu_gaudi2(): # dequant on cpu to avoid nan on gaudi2 tensor = tensor.to("cpu") @@ -269,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader): return UnquantizedWeight(w) + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w 
= [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + weights.get_tensor(f"{p}.weight_scale_inv", to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + + scale = [ + weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1) + for p in prefixes + ] + scale = torch.cat(scale, dim=0).reshape(-1) + + input_scale = [ + weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1) + for p in prefixes + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.to(weights.device), logical_widths, weights.dtype + ) + + return Fp8Weight( + weight=w, + weight_scale=scale, + input_scale=input_scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + def get_weights_row(self, weights: "Weights", prefix: str): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch @@ -389,6 +589,22 @@ class Fp8Linear(torch.nn.Module): scale_upper_bound = kwargs.get("scale_upper_bound", None) weight_block_size = kwargs.get("weight_block_size", None) + if weight_block_size is not None: + weight, orig_M, orig_N = pad_block_fp8_weight_naive( + weight, scale, weight_block_size + ) + weight, scale = dynamic_quant( + dequant_block_fp8_weight_naive( + weight, + scale, + weight_block_size, + original_M=orig_M, + original_N=orig_N, + do_unpad=True, + ) + ) + scale = scale.squeeze(-1) + return cls( qweight=weight, scale=scale, @@ -409,25 +625,10 @@ class Fp8Linear(torch.nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: if self.weight_block_size is not None: - # https://arxiv.org/pdf/2412.19437 - # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and - # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we - # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output - # channels). 
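# --- illustrative aside, not part of the patch -----------------------------
# Shape sketch of what dequant_block_fp8_weight_naive above does for a 2-D
# weight: one scale per (block_size_m x block_size_n) block, broadcast back
# over the block via a view. bfloat16 stand-ins are used for the fp8 tensors
# so the sketch runs on CPU; the sizes are made up.
import torch

M, N, bm, bn = 256, 512, 128, 128
w_q = torch.randn(M, N, dtype=torch.bfloat16)              # stand-in for the fp8 weight
scale = torch.rand(M // bm, N // bn, dtype=torch.float32)  # weight_scale_inv from the checkpoint

w = w_q.view(M // bm, bm, N // bn, bn).to(torch.float32) \
    * scale.view(M // bm, 1, N // bn, 1)
w = w.view(M, N)
# Fp8Linear.from_fp8 then re-quantizes `w` with one scale per output row
# (dynamic_quant) so the matmul can go through torch.ops.hpu.fp8_gemm_v2.
# ---------------------------------------------------------------------------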
- qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) - output = w8a8_block_fp8_matmul( - qinput, - self.qweight, - scale, - self.scale, - self.weight_block_size, - output_dtype=input.dtype, + return apply_block_fp8_linear_hpu_dynamic( + input, self.qweight, self.scale, self.input_scale, self.bias ) - if self.bias is not None: - output = output + self.bias - return output.to(dtype=input.dtype) - qinput, scale = fp8_quantize( input, self.input_scale, diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 90b8f692..babf3d4b 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -4,7 +4,12 @@ from typing import List, Optional, Union import torch from loguru import logger from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +from text_generation_server.utils.weights import ( + Weight, + Weights, + WeightsLoader, + DefaultWeightsLoader, +) from .hpu import QuantLinear @@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader): quant_method: str, quantize: str, sym: bool, + modules_to_not_convert: List[str], ): self.bits = bits self.desc_act = desc_act @@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader): self.quant_method = quant_method self.quantize = quantize self.sym = sym + self.modules_to_not_convert = modules_to_not_convert + + def is_layer_skipped_quantization( + self, prefix: str, modules_to_not_convert: List[str] + ): + return any(module_name in prefix for module_name in modules_to_not_convert) def get_weights(self, weights: Weights, prefix: str): self._get_gptq_params(weights) @@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader): log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights(weights, prefix) + try: qweight = weights.get_tensor(f"{prefix}.qweight") except RuntimeError: @@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_col_packed( + weights, prefix, block_sizes + ) try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader): ) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim) try: qweight = torch.cat( [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader): if self.bits != 4: use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_row(weights, prefix) + if self.desc_act: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py index 84878791..4bbb6c1f 100644 --- 
a/backends/gaudi/server/text_generation_server/layers/layernorm.py +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - from vllm_hpu_extension.kernels import rms_norm - - orig_shape = hidden_states.shape if residual is not None: - residual += hidden_states.view(residual.shape) - else: - residual = hidden_states - # Note: HPUFusedRMSNorm requires 3D tensors as inputs - if len(orig_shape) == 2: - residual = residual.unsqueeze(0) - x = rms_norm().apply(residual, self.weight, self.variance_epsilon) - return x.view(orig_shape), residual.view(orig_shape) + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(self.weight.dtype), residual diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py index 071b2abe..5365f24f 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.nn as nn +import os from text_generation_server.utils.weights import Weights from text_generation_server.layers.fp8 import ( @@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import ( fp8_quantize, quant_dtype, normalize_e4m3fn_to_native_float8, + dynamic_quant, + dequant_block_fp8_weight_naive, ) - -try: - from .unquantized import fused_moe -except Exception: - fused_moe = None +from text_generation_server.layers.moe.fused_moe import select_experts +import habana_frameworks.torch as htorch class FP8SparseMoELayer(nn.Module): @@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module): self.weight_block_size = weights.weights_loader.weight_block_size self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.world_size = weights.process_group.size() + self.rank = weights.process_group.rank() + self.ep_rank = self.rank + self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true" + + if self.use_ep: + n_experts = (n_experts + self.world_size - 1) // self.world_size + self.ep_offset = self.ep_rank * n_experts + else: + self.ep_offset = 0 ( self.gate_up_proj, @@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module): gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, weights=weights, + use_ep=self.use_ep, + ep_offset=self.ep_offset, ) self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = ( @@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module): n_experts=n_experts, name=down_proj_name, weights=weights, + use_ep=self.use_ep, + ep_offset=self.ep_offset, ) ) + if self.weight_block_size is not None: + self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.weight_block_size, + ) + ) + self.down_proj, self.down_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.down_proj, self.down_proj_weight_scale, self.weight_block_size + ) + ) + self.gate_up_proj_weight_scale, self.down_proj_weight_scale = ( + self.gate_up_proj_weight_scale.squeeze(-1), + self.down_proj_weight_scale.squeeze(-1), + ) def forward(self, x: torch.Tensor, *, 
gating_output: torch.Tensor) -> torch.Tensor: - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=gating_output, use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, + top_k=self.topk, + renormalize=self.renormalize, topk_group=self.topk_group, + num_expert_group=self.n_expert_group, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, - use_fp8_w8a8=True, - w1_scale=self.gate_up_proj_weight_scale, - w2_scale=self.down_proj_weight_scale, - a1_scale=self.gate_up_proj_input_scale, - a2_scale=self.down_proj_input_scale, ) + total_num_experts = gating_output.size(-1) + x_fp8, x_scale = dynamic_quant(x, single_scale=True) + + if self.use_ep: + moe_n_slice = 1 + n_expert_slice = ( + total_num_experts + self.world_size - 1 + ) // self.world_size + else: + moe_n_slice = 1 + n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice + for i in range(moe_n_slice): + min_expert = i * n_expert_slice + max_expert = min((i + 1) * n_expert_slice, total_num_experts) + w13_list_slice = [ + self.gate_up_proj[j, ...] for j in range(min_expert, max_expert) + ] + w2_list_slice = [ + self.down_proj[j, ...] for j in range(min_expert, max_expert) + ] + w13_weight_scale = [ + self.gate_up_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + w2_weight_scale = [ + self.down_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + + current_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=x_fp8, + expert_routing_table=topk_ids.to(torch.int64), + router_weights=topk_weights.to(x.dtype), + w12=w13_list_slice, + w3=w2_list_slice, + d_scale_hidden_states=x_scale, + d_scale_w12=w13_weight_scale, + d_scale_w3=w2_weight_scale, + permuted_weights=True, + activation="silu", + experts_min=min_expert + self.ep_offset, + experts_max=max_expert + self.ep_offset - 1, + ) + htorch.core.mark_step() + if i == 0: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + return final_hidden_states def _load_expert_weights( @@ -98,13 +170,14 @@ def _load_expert_weights( n_experts: int, name: str, weights: Weights, + ep_offset: int = 0, ) -> torch.Tensor: all_weight = None all_weight_scales = None max_input_scale = None for i in range(n_experts): - weight = get_weight_fn(prefix, i, name, weights) + weight = get_weight_fn(prefix, i + ep_offset, name, weights) assert isinstance(weight, Fp8Weight) @@ -147,14 +220,26 @@ def _load_expert_multi_weights_col( gate_proj_name: str, up_proj_name: str, weights: Weights, + use_ep: bool = False, + ep_offset: int = 0, ) -> torch.Tensor: - def get_weight_fn(prefix, i, name, weights): + def get_weight_fn_sharded(prefix, i, name, weights): return weights.get_multi_weights_col( [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 ) + def get_weight_fn(prefix, i, name, weights): + return weights.get_multi_weights( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + return _load_expert_weights( - get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights + get_weight_fn if use_ep else get_weight_fn_sharded, + prefix=prefix, + n_experts=n_experts, + name=None, + weights=weights, + ep_offset=ep_offset if use_ep else 0, ) @@ -164,10 +249,20 @@ def _load_expert_weights_row( 
n_experts: int, name: str, weights: Weights, + use_ep: bool = False, + ep_offset: int = 0, ) -> torch.Tensor: - def get_weight_fn(prefix, i, name, weights): + def get_weight_fn_sharded(prefix, i, name, weights): return weights.get_weights_row(f"{prefix}.{i}.{name}") + def get_weight_fn(prefix, i, name, weights): + return weights.get_weights(f"{prefix}.{i}.{name}") + return _load_expert_weights( - get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights + get_weight_fn if use_ep else get_weight_fn_sharded, + prefix=prefix, + n_experts=n_experts, + name=name, + weights=weights, + ep_offset=ep_offset if use_ep else 0, ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py index e26ff877..1987f0ed 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Tuple, Optional import torch @@ -25,12 +25,36 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - scores = torch.softmax(gating_output, dim=-1) + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + gating_output = gating_output.float() + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.float() + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. 
We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ 1 ] # [n, top_k_group] @@ -41,13 +65,19 @@ def grouped_topk( .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) .reshape(num_token, -1) ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def fused_topk( @@ -63,3 +93,39 @@ def fused_topk( if renormalize: topk_weights /= topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +): + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index ec158398..58709ec3 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn from text_generation_server.utils.weights import UnquantizedWeight, Weights -from vllm_hpu_extension.ops import DynamicFusedMOE +from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp +import habana_frameworks.torch as htorch +import torch.nn.functional as F class UnquantizedSparseMoELayer(nn.Module): @@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module): weights=weights, ) - self.hpu_fused_moe = DynamicFusedMOE(n_experts) + self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1) for i in range(n_experts): - self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) - self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i]) + self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) + 
self.MoeOp.w2_list[i].set_weight(self.down_proj[i]) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return self.hpu_fused_moe(x, gating_output, self.topk) + htorch.core.mark_step() + routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk( + routing_weights, self.topk, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(x.dtype) + + final_hidden_states = self.MoeOp( + hidden_states=x, + expert_routing_table=selected_experts, + router_weights=routing_weights, + permuted_weights=True, + activation="silu", + ) + + return final_hidden_states.view(-1, x.shape[1]) def _load_expert_multi_weights_col( diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 6a83d6a5..7e740e5f 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__( - inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor - ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): / get_mscale(self.scaling_factor, mscale_all_dim) * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation + super().__init__(inv_freq, scaling_factor, max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 6ca7b567..a9a1d0b7 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -343,6 +343,7 @@ def get_model( quantize: Optional[str], speculate: Optional[int], dtype: Optional[torch.dtype], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, ) -> Model: @@ -468,7 +469,12 @@ def get_model( model_type = config_dict["model_type"] - kv_cache_dtype = dtype + if kv_cache_dtype == "fp8_e4m3fn": + kv_cache_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + kv_cache_dtype = torch.float8_e5m2 + else: + kv_cache_dtype = dtype if FLASH_ATTENTION: if model_type == DEEPSEEK_V2: @@ -934,6 +940,7 @@ def get_model_with_lora_adapters( quantize: Optional[str], speculate: Optional[int], dtype: Optional[torch.dtype], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, adapter_to_index: Dict[str, int], @@ -947,6 +954,7 @@ def get_model_with_lora_adapters( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, max_input_tokens, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 3bcc689d..801ae09e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import ( 
apply_rotary_pos_emb, ) +import habana_frameworks.torch as htorch + class CohereRotary(PositionRotaryEmbedding): def forward( @@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None - + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 15c243c9..76972d38 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import ( FastLayerNorm, ) from vllm_hpu_extension.ops import DynamicFusedMOE +import habana_frameworks.torch as htorch class DbrxAttentionConfig(PretrainedConfig): @@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module): # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) - residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 9d61c694..6ac7fc1a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.weights import Weights +import habana_frameworks.torch as htorch class DeepseekV2Config(PretrainedConfig): @@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 1a7ce5cf..e0481691 100644 --- 
a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -28,11 +28,12 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, get_linear, + Fp8Linear, ) from text_generation_server.layers.attention import ( Seqlen, attention, - paged_attention, + paged_attention_mla, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -40,6 +41,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.weights import Weights +import habana_frameworks.torch as htorch + + +def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor: + if isinstance(layer, Fp8Linear): + eye = torch.eye( + layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device + ) + dequant_weights = layer(eye) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight class DeepseekV3Config(PretrainedConfig): @@ -249,6 +263,44 @@ class DeepseekV3Attention(torch.nn.Module): 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.value_head_size, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _q_proj_and_k_up_proj(self, x): + q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj + q_nope, q_pe = ( + q_proj(x) + .view(-1, self.num_heads, self.head_size) + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def _v_up_proj_and_o_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size) + return self.o_proj(x) + def forward( self, hidden_states: torch.Tensor, @@ -261,14 +313,9 @@ class DeepseekV3Attention(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: - query = self.q_proj(hidden_states) + hidden_states_or_q_c = hidden_states else: - query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) - query = query.view(-1, self.num_heads, self.head_size) - - _, query_pe = torch.split( - query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, key_pe = torch.split( @@ -276,13 +323,18 @@ class 
DeepseekV3Attention(torch.nn.Module): ) key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( - -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size - ) + kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0] - key_nope, value = torch.split( - kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 - ) + # Prefill + if cu_seqlen_prefill is not None: + q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj + query = q_proj(hidden_states_or_q_c) + query = query.view(-1, self.num_heads, self.head_size) + query_nope, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + else: + query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c) batch_size, heads, head_dim = query_pe.shape query_pe = ( @@ -297,33 +349,47 @@ class DeepseekV3Attention(torch.nn.Module): .reshape(batch_size, heads, head_dim) ) self.rotary_emb(query_pe, key_pe, cos, sin) + latent_vec_k = torch.concat( + (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1 + ) + latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank) - query[..., self.qk_nope_head_dim :] = query_pe - key = torch.empty_like(query) - key[..., : self.qk_nope_head_dim] = key_nope - key[..., self.qk_nope_head_dim :] = key_pe - - # We need to pad the heads because Flash Attention does not support - # qk and v with different head sizes. - query = torch.nn.functional.pad( - query, (0, self.head_pad_size - self.head_size), value=0 - ) - key = torch.nn.functional.pad( - key, (0, self.head_pad_size - self.head_size), value=0 - ) - value = torch.nn.functional.pad( - value, (0, self.head_pad_size - self.value_head_size), value=0 - ) + latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1)) kv_cache.store( - key=key, - value=value, + key=latent_vec_k, + value=None, slots=slots, kv_scales=self.kv_scales, ) - # Prefill if cu_seqlen_prefill is not None: + kv = self.kv_b_proj(kv_c_normed).view( + -1, + self.num_key_value_heads, + self.qk_nope_head_dim + self.value_head_size, + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + # flash attention attn_output = attention( query=query, @@ -334,9 +400,15 @@ class DeepseekV3Attention(torch.nn.Module): seqlen=seqlen, softmax_scale=self.softmax_scale, ) - # Decode + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) else: - attn_output = paged_attention( + # Decode + query = torch.cat([query_nope, query_pe], dim=-1) + attn_output = paged_attention_mla( query, kv_cache, self.kv_head_mapping, @@ -344,14 +416,10 @@ class DeepseekV3Attention(torch.nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + kv_lora_rank=self.kv_lora_rank, ) - - # Remove padding. 
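# --- illustrative aside, not part of the patch -----------------------------
# Shape sketch of the decode-time MLA "weight absorption" used above: the
# query is projected into the kv_lora_rank latent space with W_UK_T so
# attention can run directly against the compressed latent cache, and the
# attended context is lifted back to value space with W_UV. Sizes are made up.
import torch

B, N, P, L, V = 4, 16, 128, 512, 128   # tokens, heads, nope dim, kv_lora_rank, value dim
q_nope = torch.randn(B, N, P)
W_UK_T = torch.randn(N, P, L)          # per-head key up-projection, pre-transposed
W_UV = torch.randn(N, L, V)            # per-head value up-projection

# _q_proj_and_k_up_proj: (B, N, P) -> (N, B, P) @ (N, P, L) -> back to (B, N, L)
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)

# _v_up_proj_and_o_proj, up to the final o_proj:
# (B, N, L) -> (N, B, L) @ (N, L, V) -> (B, N * V)
ctx = torch.randn(B, N, L)             # output of paged_attention_mla
out = torch.bmm(ctx.transpose(0, 1), W_UV).transpose(0, 1).reshape(B, N * V)
# ---------------------------------------------------------------------------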
- attn_output = attn_output[..., : self.value_head_size] - - return self.o_proj( - attn_output.reshape(-1, self.num_heads * self.value_head_size) - ) + attn_output = self._v_up_proj_and_o_proj(attn_output) + return attn_output class DeepseekV3MLP(nn.Module): @@ -584,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -596,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 79f21b0f..a5860823 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class Gemma2Config(PretrainedConfig): @@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 609f03ac..3d678df1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class GemmaConfig(PretrainedConfig): @@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 10024a6d..ed413662 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ 
-38,6 +38,7 @@ from text_generation_server.layers import ( get_linear, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales +import habana_frameworks.torch as htorch def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module): hidden_states = inputs_embeds residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 41eeab78..cde03a00 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb, ) +import habana_frameworks.torch as htorch def load_attention(config, prefix: str, weights): @@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 81af5560..0edea03a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -26,7 +26,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN - +import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( KVCache, get_kv_scales, @@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module): cross_attention_states, hpu_attention_meta=hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d23d4f67..75d9d360 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.rotary 
import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +import habana_frameworks.torch as htorch class MistralConfig(PretrainedConfig): @@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 1ef6be48..f47986d8 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class MixtralConfig(PretrainedConfig): @@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 33f63333..29620826 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class GPTNeoXConfig(TransformersGPTNeoXConfig): @@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.final_layer_norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 0c777912..12830991 100644 --- 
a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +import habana_frameworks.torch as htorch class PhiConfig(PretrainedConfig): @@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py index bb585cc4..c28f3aee 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py @@ -18,7 +18,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index af4b404d..7c7ac03e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +import habana_frameworks.torch as htorch def load_attention(config, prefix, weights): @@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module): ) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states = layer( hidden_states, @@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 141e13a6..76a2cd01 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -21,6 +21,7 @@ from text_generation_server.layers.attention import ( Seqlen, HPUPagedAttentionMetadata, ) +import habana_frameworks.torch as htorch def load_row(config, prefix: str, weights, bias: bool): @@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel): cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.h): hidden_states, residual = layer( hidden_states, @@ 
-646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index b68f4784..c64b2ff7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +import habana_frameworks.torch as htorch def load_multi_mqa( @@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module): torch.distributed.all_reduce(hidden_states, group=self.process_group) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 76f6f473..94c60eb6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class Starcoder2Config(PretrainedConfig): @@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index ad585172..eb0f7454 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -53,6 +53,7 @@ from text_generation_server.models.globals import ( ) from text_generation_server.layers.attention import ( KVCache, + KVCompressCache, Seqlen, HPUPagedAttentionMetadata, trim_attn_metadata, @@ -68,11 +69,13 @@ from text_generation_server.utils.import_utils import ( synchronize, get_free_memory, ) - +from text_generation_server.utils.prefill_chunking import ( + get_max_prefill_tokens, +) import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools -from vllm_hpu_extension.bucketing import HPUBucketingContext +from vllm_hpu_extension.bucketing.common import 
get_bucketing_context tracer = trace.get_tracer(__name__) @@ -153,7 +156,7 @@ def prepare_for_decode( block_groups_device, num_classes=batch_size ) mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage.unsqueeze(-1) + mask = mask >= block_usage_device.unsqueeze(-1) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) return trim_attn_metadata( HPUPagedAttentionMetadata( @@ -425,7 +428,9 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device - all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) @@ -1438,15 +1443,17 @@ class FlashCausalLM(Model): self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype self.bucketing_ctx = None + htorch.core.hpu_set_env() if htorch.utils.internal.is_lazy(): htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) environment.set_model_config(self.config) self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" ) - self.limit_hpu_graphs = ( - os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true" + self.limit_hpu_graph = ( + os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true" ) + self.max_seq_len_to_capture = 8192 super().__init__( model_id=model_id, model=model, @@ -1478,16 +1485,27 @@ class FlashCausalLM(Model): ): self.kv_cache = [] empty_cache() - self.kv_cache = [ - KVCache( - num_blocks=num_blocks, - num_heads=num_heads, - head_size=head_size, - dtype=dtype, - device=device, - ) - for _ in range(num_layers) - ] + if self.config.model_type == "deepseek_v3": + self.kv_cache = [ + KVCompressCache( + num_blocks=num_blocks, + head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + KVCache( + num_blocks=num_blocks, + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] def warmup( self, @@ -1495,6 +1513,11 @@ class FlashCausalLM(Model): max_input_tokens: Optional[int], max_total_tokens: Optional[int], ): + if os.environ.get("MAX_BATCH_SIZE") is None: + raise RuntimeError( + "MAX_BATCH_SIZE is not set, it should be set in the launcher " + "using `--max-batch-size xxx`" + ) # The warmup batch is the biggest batch we could ever receive self.kv_cache = [] empty_cache() @@ -1502,8 +1525,14 @@ class FlashCausalLM(Model): # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() - cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size - total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + if self.config.model_type == "deepseek_v3": + cache_block_size = BLOCK_SIZE * ( + self.config.kv_lora_rank + self.config.qk_rope_head_dim + ) + else: + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + cache_block_size = cache_block_size * 2 + total_cache_size = self.num_layers * cache_block_size * dtype_size try: self.init_kv_cache( @@ -1563,25 +1592,33 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) - - max_num_seqs = 
int(os.getenv("MAX_BATCH_SIZE", 128)) - if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None: - os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens) - if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None: - max_total_blocks = ( - math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1 - ) - os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks) - + self.max_batch_prefill_tokens = get_max_prefill_tokens() + max_num_seqs = int(os.getenv("MAX_BATCH_SIZE")) + HPUBucketingContext = get_bucketing_context() + max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE + model_max_length = self.tokenizer.model_max_length + max_position_embeddings = getattr( + self.config, "max_position_embeddings", model_max_length + ) self.bucketing_ctx = HPUBucketingContext( max_num_seqs, - os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO + max_num_seqs, # self.max_num_prefill_seqs, #TODO BLOCK_SIZE, - num_blocks * BLOCK_SIZE, + max_num_seqs * max_total_tokens_aligned, False, + min(model_max_length, max_position_embeddings), + max_input_tokens, + max_total_tokens_aligned, ) - self.bucketing_ctx.num_hpu_blocks = num_blocks + max_blocks = ( + max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1 + ) + self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks) if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": + self.bucketing_ctx.generate_prompt_buckets() + self.bucketing_ctx.generate_decode_buckets( + self.bucketing_ctx.num_hpu_blocks + ) logger.info("skip warmup hpu graph, not recommmended") del _batch, batch return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens @@ -1591,28 +1628,55 @@ class FlashCausalLM(Model): return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens + def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture): + if self.limit_hpu_graph: + return prefill + else: + return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture + def warmup_hpu_graph(self, batch): + start_time = time.time() + warmup_shape_count = 0 warmup_times = 3 self.bucketing_ctx.generate_prompt_buckets() - for i, (batch_size, seq_len) in enumerate( - reversed(self.bucketing_ctx.prompt_buckets) - ): + + def ordering_function_min_tokens(b): + return (b[0] * b[1], b[1], b[0]) + + buckets = list( + sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens) + ) + + for i, (batch_size, seq_len) in enumerate(buckets): + if batch_size * seq_len > self.max_batch_prefill_tokens: + continue + warmup_shape_count += 1 log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") for index in range(warmup_times): self.warmup_prefill(seq_len, batch_size, batch) + synchronize(self.device) + + def ordering_function_max_bs(b): + return (-b[0], b[1]) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) - for i, (batch_size, block_num) in enumerate( - reversed(self.bucketing_ctx.decode_buckets) - ): + buckets = list( + sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) + ) + for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue + warmup_shape_count += 1 log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) for index in range(warmup_times): self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + synchronize(self.device) + log_master( + logger.info, + f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count 
{warmup_shape_count}", + ) def warmup_prefill( self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch @@ -1643,7 +1707,9 @@ class FlashCausalLM(Model): lm_head_indices = input_lengths - 1 kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs + kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( + True, input_ids.shape[0] + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1792,8 +1858,8 @@ class FlashCausalLM(Model): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = ( - batch.prefilling if self.limit_hpu_graphs else False + kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( + batch.prefilling, input_ids.shape[0] ) logits, speculative_logits = self.model.forward( @@ -1836,9 +1902,7 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - _async_h2d_tensor_copy( - batch.all_input_ids_tensor[:, : batch.max_current_length] - ), + batch.all_input_ids_tensor[:, : batch.max_current_length], batch.next_token_logits, speculate, batch.speculative_ids, @@ -1852,7 +1916,6 @@ class FlashCausalLM(Model): accepted_ids, ) if batch.valid_indices is not None: - next_input_ids = next_input_ids.cpu() next_token_logprobs = next_token_logprobs.cpu() accepted_ids = accepted_ids.cpu() batch.all_input_ids_tensor = batch.all_input_ids_tensor[ @@ -1902,7 +1965,6 @@ class FlashCausalLM(Model): accepted_ids = accepted_ids.cpu() cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) - next_input_ids = next_input_ids.cpu() if batch.speculative_logits is not None: for i in range(len(batch)): batch.all_input_ids_tensor[ @@ -1914,7 +1976,7 @@ class FlashCausalLM(Model): ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] else: index = batch.cache_lengths_tensor + batch.input_lengths_tensor - index = index.to(batch.all_input_ids_tensor) + index = index.to(batch.all_input_ids_tensor.device) batch_idx = torch.arange( 0, batch.all_input_ids_tensor.shape[0], @@ -1924,6 +1986,7 @@ class FlashCausalLM(Model): batch.all_input_ids_tensor.index_put_( (batch_idx, index.long()), next_input_ids ) + next_input_ids = next_input_ids.cpu() batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.speculative_ids = speculative_ids if batch.position_ids.dim() == 2: diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index b44ae03f..d9c57f20 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -23,6 +23,7 @@ from text_generation_server.layers.attention import ( _async_h2d_tensor_copy, ) import habana_frameworks.torch as htorch +import time from text_generation_server.utils.import_utils import ( synchronize, ) @@ -486,20 +487,32 @@ class FlashVlmCausalLM(FlashCausalLM): ) def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): + start_time = time.time() + warmup_shape_count = 0 warmup_times = 3 + # only warmup decode, for prefill, image pixal size may change, make the warmup useless + def ordering_function_max_bs(b): + return (-b[0], b[1]) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) - for i, (batch_size, block_num) in enumerate( - reversed(self.bucketing_ctx.decode_buckets) 
- ): + buckets = list( + sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) + ) + for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue + warmup_shape_count += 1 log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) for index in range(warmup_times): self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + synchronize(self.device) + log_master( + logger.info, + f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", + ) def forward( self, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 5de9bca8..0e5544f2 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import ( ) import torch.nn.functional as F from text_generation_server.utils.log import log_master +import time tracer = trace.get_tracer(__name__) @@ -325,7 +326,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs + kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( + True, input_ids.shape[0] + ) self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), position_ids=_async_h2d_tensor_copy(position_ids), @@ -343,26 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): + start_time = time.time() + warmup_shape_count = 0 warmup_times = 3 self.bucketing_ctx.generate_prompt_buckets() - for i, (batch_size, seq_len) in enumerate( - reversed(self.bucketing_ctx.prompt_buckets) - ): + + def ordering_function_min_tokens(b): + return (b[0] * b[1], b[1], b[0]) + + buckets = list( + sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens) + ) + for i, (batch_size, seq_len) in enumerate(buckets): + if batch_size * seq_len > self.max_batch_prefill_tokens: + continue + warmup_shape_count += 1 log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") for index in range(warmup_times): self.warmup_prefill(seq_len, batch_size, batch) + synchronize(self.device) + + def ordering_function_max_bs(b): + return (-b[0], b[1]) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) - for i, (batch_size, block_num) in enumerate( - reversed(self.bucketing_ctx.decode_buckets) - ): + buckets = list( + sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) + ) + for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue + warmup_shape_count += 1 log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) for index in range(warmup_times): self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + synchronize(self.device) + log_master( + logger.info, + f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", + ) def forward( self, @@ -438,8 +462,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = ( - batch.prefilling if self.limit_hpu_graphs else False + kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( + batch.prefilling, input_ids.shape[0] ) if batch.prefill_cache_indices is not None: slots_pad = 
torch.zeros_like(input_ids) diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index f9250115..f5080ec3 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -206,6 +206,7 @@ def serve( quantize: Optional[str], speculate: Optional[int], dtype: Optional[str], + kv_cache_dtype: Optional[str], trust_remote_code: bool, uds_path: Path, max_input_tokens: int, @@ -218,6 +219,7 @@ def serve( quantize: Optional[str] = None, speculate: Optional[int] = None, dtype: Optional[str] = None, + kv_cache_dtype: Optional[str] = None, trust_remote_code: bool = False, ): if not is_driver_compatible(): @@ -261,6 +263,7 @@ def serve( quantize, speculate, data_type, + kv_cache_dtype, trust_remote_code, max_input_tokens, adapter_to_index, @@ -308,6 +311,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, ) ) diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py index 1c45713e..9866710b 100644 --- a/backends/gaudi/server/text_generation_server/utils/dist.py +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -7,7 +7,7 @@ from loguru import logger # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) -MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8")) +MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9")) class FakeBarrier: diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py index 22560dd7..d25484d6 100644 --- a/backends/gaudi/server/text_generation_server/utils/import_utils.py +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -1,17 +1,19 @@ import torch from loguru import logger +import habana_frameworks.torch as htorch +import os def get_hpu_free_memory(device, memory_fraction): - from habana_frameworks.torch.hpu import memory_stats - - device_id = device.index - mem_stats = memory_stats(device_id) - logger.info(f"mem_stats: {mem_stats}") - total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"] - free_memory = max( - 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"]) + graph_reserved_mem = ( + float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1")) + if htorch.utils.internal.is_lazy() + else 0 ) + free_memory = int( + torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem) + ) + logger.info(f"Free memory on device {device}: {free_memory} bytes.") return free_memory diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index a8faf4a5..022a4897 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -1,7 +1,7 @@ import json import os from dataclasses import dataclass -from typing import Optional +from typing import Optional, List from huggingface_hub import hf_hub_download from text_generation_server.utils.weights import ( @@ -18,6 +18,8 @@ class _QuantizerConfig: groupsize: int quant_method: str sym: bool + weight_block_size: Optional[List[int]] + modules_to_not_convert: List[str] @dataclass @@ -25,7 +27,20 @@ class _FP8QuantizerConfig: activation_scale_ub: float -# We should probably do this with Pytantic 
JSON deserialization, +def _get_config_json(model_id: str, revision: Optional[str], filename: str): + if os.path.exists( + os.path.join( + model_id, + ) + ): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + return json.load(f) + + +# We should probably do this with Pydantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. def _get_quantizer_config(model_id, revision): bits = 4 @@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision): checkpoint_format = None sym = False desc_act = False + weight_block_size = None + modules_to_not_convert = [] filename = "config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download(model_id, filename=filename, revision=revision) - with open(filename, "r") as f: - data = json.load(f) - + data = _get_config_json(model_id, revision, filename) # FP8 config if data["quantization_config"]["quant_method"] == "fbgemm_fp8": return _FP8QuantizerConfig( activation_scale_ub=data["quantization_config"]["activation_scale_ub"] ) + weight_block_size = data["quantization_config"].get("weight_block_size", None) if "zero_point" in data["quantization_config"]: sym = not data["quantization_config"]["zero_point"] @@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision): # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] checkpoint_format = data["quantization_config"].get("checkpoint_format") - desc_act = data["quantization_config"]["desc_act"] + desc_act = data["quantization_config"].get("desc_act", False) + modules_to_not_convert = data["quantization_config"].get( + "modules_to_not_convert", [] + ) + if modules_to_not_convert is None: + modules_to_not_convert = [] except Exception: filename = "quantize_config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) + data = _get_config_json(model_id, revision, filename) bits = data["bits"] groupsize = data["group_size"] @@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision): except Exception: filename = "quant_config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) + data = _get_config_json(model_id, revision, filename) bits = data["w_bit"] groupsize = data["q_group_size"] desc_act = data["desc_act"] @@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision): checkpoint_format=checkpoint_format, sym=sym, desc_act=desc_act, + weight_block_size=weight_block_size, + modules_to_not_convert=modules_to_not_convert, ) @@ -134,6 +139,7 @@ def get_loader( quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, + modules_to_not_convert=quantizer_config.modules_to_not_convert, ) elif quantize == "fp8" or quantize is None: from text_generation_server.layers.fp8 import HybridFP8UnquantLoader @@ -141,9 +147,14 @@ def get_loader( # Since the default for the quantize config is _QuantizerConfig, # we need to add this check to not get an attribute error 
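        # weight_block_size parsed from the checkpoint's quantization_config is forwarded to
        # HybridFP8UnquantLoader below so that block-quantized FP8 checkpoints (for example
        # DeepSeek-V3 style [128, 128] blocks) keep their per-block scales.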
activation_scale_ub = None + weight_block_size = quantizer_config.weight_block_size if isinstance(quantizer_config, _FP8QuantizerConfig): activation_scale_ub = quantizer_config.activation_scale_ub - return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") + return HybridFP8UnquantLoader( + activation_scale_ub, + to_fp8=quantize == "fp8", + weight_block_size=weight_block_size, + ) else: raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py index dec22942..4edae0d4 100644 --- a/backends/gaudi/server/text_generation_server/utils/weights.py +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -62,6 +62,14 @@ class WeightsLoader(ABC): """ ... + @abstractmethod + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + """ + Get the weights at the given prefixes, column-split them for tensor + parallelim, and then concatenate the weights along the given dimension. + """ + ... + @abstractmethod def get_weights_row(self, weights: "Weights", prefix: str): """ @@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader): weights.get_sharded(f"{prefix}.weight", dim=1), ) + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_tensor(f"{p}.weight") for p in prefixes] + return self.weight_class(torch.cat(w, dim=dim)) + class Weights: def __init__( @@ -393,6 +405,9 @@ class Weights: def get_weights_row(self, prefix: str): return self.weights_loader.get_weights_row(self, prefix) + def get_multi_weights(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights(self, prefixes, dim) + @contextmanager def use_loader(self, weights_loader: WeightsLoader): """ diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index d3bf4b9c..8cfee3a5 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -8,6 +8,7 @@ use std::cmp::max; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; +use text_generation_router::usage_stats::Env; use text_generation_router::validation::{ Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, @@ -185,6 +186,9 @@ struct State { /// Paged Attention Block Allocation block_allocator: Option, + + /// indicate if it's hpu device, the hpu device needs padding to generate first token. 
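+    /// On HPU the prefill batch is padded to the longest prompt, so the budget check
+    /// added below multiplies the padded length by the batch size
+    /// ((batch.len() + 1) * max_input_length) instead of summing the actual prompt lengths.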
+ is_hpu_device: bool, } impl State { @@ -214,6 +218,7 @@ impl State { speculate, support_chunking, block_allocator, + is_hpu_device: Env::new().is_hpu_device(), } } @@ -368,6 +373,21 @@ impl State { } } + if self.is_hpu_device { + //HPU needs to pad for the prefill + max_input_length = max_input_length.max(entry.request.input_length); + let actual_prefill_tokens_for_hpu = + (batch.len() + 1) as u32 * max_input_length; + + if actual_prefill_tokens_for_hpu > prefill_token_budget { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + } + prefill_tokens += postfix_len; Some(block_allocation) From 000e313a92d1ccd0bab326729a551245e0079c9f Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 20 May 2025 16:22:43 +0800 Subject: [PATCH 3/9] Refine warmup and upgrade to synapse AI 1.21.0 (#3234) Signed-off-by: Wang, Yi A --- Dockerfile_gaudi | 3 +- backends/gaudi/Makefile | 2 +- .../models/flash_causal_lm.py | 182 ++++++++++++++---- .../models/flash_vlm_causal_lm.py | 53 ++++- .../models/mllama_causal_lm.py | 97 ++++++++-- .../text_generation_server/utils/dist.py | 2 +- .../utils/import_utils.py | 15 +- 7 files changed, 278 insertions(+), 76 deletions(-) diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 54a0bb7c..bd6c58b4 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -1,5 +1,5 @@ # Those arguments are required to build the image -ARG HABANA_VERSION=1.20.0 +ARG HABANA_VERSION=1.21.0 ARG PYTORCH_VERSION=2.6.0 # Rust builder @@ -62,6 +62,7 @@ ENV PREFIX_CACHING=0 ENV PREFILL_CHUNKING=0 ENV PT_HPU_LAZY_MODE=1 ENV PT_HPU_WEIGHT_SHARING=0 +ENV VLLM_EXPONENTIAL_BUCKETING=true # Text Generation Inference base env ENV HF_HOME=/data \ diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile index c153a5ff..77581517 100644 --- a/backends/gaudi/Makefile +++ b/backends/gaudi/Makefile @@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST))) mkfile_dir := $(dir $(mkfile_path)) root_dir := ${mkfile_dir}/../.. 
-HABANA_VERSION := 1.20.0 +HABANA_VERSION := 1.21.0 PYTORCH_VERSION := 2.6.0 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index eb0f7454..bc0d240e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -76,6 +76,7 @@ import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools from vllm_hpu_extension.bucketing.common import get_bucketing_context +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) @@ -1357,6 +1358,8 @@ class FlashCausalLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() + if world_size > 1: + self.process_group_cpu = torch.distributed.new_group(backend="gloo") device = torch.device("hpu") dtype = torch.bfloat16 if dtype is None else dtype @@ -1453,6 +1456,7 @@ class FlashCausalLM(Model): self.limit_hpu_graph = ( os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true" ) + self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true" self.max_seq_len_to_capture = 8192 super().__init__( model_id=model_id, @@ -1521,7 +1525,7 @@ class FlashCausalLM(Model): # The warmup batch is the biggest batch we could ever receive self.kv_cache = [] empty_cache() - + self.graphed_buckets = set() # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() @@ -1533,7 +1537,20 @@ class FlashCausalLM(Model): cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size cache_block_size = cache_block_size * 2 total_cache_size = self.num_layers * cache_block_size * dtype_size - + free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM) + self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION)) + graph_reserved_mem = ( + float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1")) + if htorch.utils.internal.is_lazy() + else 0 + ) + mem_used_from_graph = int( + (free_memory - self.mem_reserved) * graph_reserved_mem + ) + log_master( + logger.info, + f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}", + ) try: self.init_kv_cache( batch.num_blocks, @@ -1548,15 +1565,6 @@ class FlashCausalLM(Model): num_tokens = batch.to_pb().current_tokens synchronize(self.device) - free_memory = get_free_memory( - self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM - ) - real_free_memory = get_free_memory(self.device, MEMORY_FRACTION) - log_master( - logger.debug, - f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB", - ) - _, _batch, _ = self.generate_token([batch]) except Exception: raise RuntimeError( @@ -1565,8 +1573,9 @@ class FlashCausalLM(Model): ) synchronize(self.device) - free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM) - kv_memory = free_memory + free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM) + + kv_memory = free_memory - self.mem_reserved - mem_used_from_graph num_blocks = ( # Leave 5% for 
some wiggle room int(kv_memory // total_cache_size) @@ -1583,7 +1592,6 @@ class FlashCausalLM(Model): self.kv_cache = [] empty_cache() - self.init_kv_cache( num_blocks, self.num_layers, @@ -1595,11 +1603,16 @@ class FlashCausalLM(Model): self.max_batch_prefill_tokens = get_max_prefill_tokens() max_num_seqs = int(os.getenv("MAX_BATCH_SIZE")) HPUBucketingContext = get_bucketing_context() - max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE + # need to warmup one more step since block is allocated from 1 + block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE) + max_total_tokens_aligned = math.ceil( + max_total_tokens / BLOCK_SIZE + ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs) model_max_length = self.tokenizer.model_max_length max_position_embeddings = getattr( self.config, "max_position_embeddings", model_max_length ) + self.bucketing_ctx = HPUBucketingContext( max_num_seqs, max_num_seqs, # self.max_num_prefill_seqs, #TODO @@ -1610,31 +1623,75 @@ class FlashCausalLM(Model): max_input_tokens, max_total_tokens_aligned, ) - max_blocks = ( - max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1 + max_blocks = max( + BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE ) self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks) - if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": + synchronize(self.device) + if self.skip_warmup: self.bucketing_ctx.generate_prompt_buckets() self.bucketing_ctx.generate_decode_buckets( self.bucketing_ctx.num_hpu_blocks ) - logger.info("skip warmup hpu graph, not recommmended") + log_master( + logger.info, "skip warmup hpu graph, not recommmended, may cause OOM" + ) del _batch, batch return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - self.warmup_hpu_graph(batch) del _batch, batch return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture): - if self.limit_hpu_graph: - return prefill - else: - return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture + def log_warmup(self, prefilling, i, max_i, batch_size, seq_len): + free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory()) + phase = "Prompt" if prefilling else "Decode" + dim = "seq_len" if prefilling else "num_blocks" + graphed_bucket = (batch_size, seq_len, prefilling) + bypass = graphed_bucket not in self.graphed_buckets + msg = ( + f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"bypass:{bypass} " + f"free_mem:{free_mem}" + ) + log_master(logger.info, msg) + + def use_graphs(self, prefill, seq_len, batch_size): + if self.limit_hpu_graph and prefill: + return False + + if self.skip_warmup: + return True + + return (batch_size, seq_len, prefill) in self.graphed_buckets + + def align_workers(self, value, op): + if self.world_size <= 1: + return value + value_t = torch.tensor(value, device="cpu") + torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu) + return value_t.item() def warmup_hpu_graph(self, batch): + prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3")) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_reserved + graph_free_mem = self.align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) + prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem + decode_available_memory = graph_free_mem - 
prompt_available_memory + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})" + ) + log_master(logger.info, msg) start_time = time.time() warmup_shape_count = 0 warmup_times = 3 @@ -1646,15 +1703,34 @@ class FlashCausalLM(Model): buckets = list( sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens) ) - + total_batch_seq = 0.001 + total_mem = 0 + available_mem = prompt_available_memory for i, (batch_size, seq_len) in enumerate(buckets): if batch_size * seq_len > self.max_batch_prefill_tokens: continue + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size * seq_len + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, seq_len, True) + if not ( + mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture + ): + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) warmup_shape_count += 1 - log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") - for index in range(warmup_times): - self.warmup_prefill(seq_len, batch_size, batch) - synchronize(self.device) + self.log_warmup(True, i, len(buckets), batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX + ) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq def ordering_function_max_bs(b): return (-b[0], b[1]) @@ -1663,16 +1739,34 @@ class FlashCausalLM(Model): buckets = list( sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) ) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + total_batch_seq = 0.001 + total_mem = 0 + available_mem = free_mem - self.mem_reserved for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, block_num, False) + if not mem_estimate >= available_mem: + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) warmup_shape_count += 1 - log_master( - logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + self.log_warmup(False, i, len(buckets), batch_size, block_num) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX ) - for index in range(warmup_times): - self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + log_master( logger.info, f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", @@ -1707,8 +1801,8 @@ class FlashCausalLM(Model): lm_head_indices = input_lengths - 1 kwargs = {} if htorch.utils.internal.is_lazy(): - 
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( - True, input_ids.shape[0] + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + True, prompt_len, batch_size ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. @@ -1762,7 +1856,9 @@ class FlashCausalLM(Model): slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + False, hpu_attention_meta.block_list.shape[0], batch_size + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), @@ -1858,8 +1954,14 @@ class FlashCausalLM(Model): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( - batch.prefilling, input_ids.shape[0] + batch_size = input_lengths.shape[0] + prompt_len = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] + ) + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, prompt_len, batch_size ) logits, speculative_logits = self.model.forward( diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index d9c57f20..fd239b3e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -27,6 +27,7 @@ import time from text_generation_server.utils.import_utils import ( synchronize, ) +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) @@ -487,6 +488,19 @@ class FlashVlmCausalLM(FlashCausalLM): ) def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_reserved + graph_free_mem = self.align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) + decode_available_memory = graph_free_mem + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(decode_available_memory)} for decode " + ) + log_master(logger.info, msg) start_time = time.time() warmup_shape_count = 0 warmup_times = 3 @@ -499,16 +513,34 @@ class FlashVlmCausalLM(FlashCausalLM): buckets = list( sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) ) + total_batch_seq = 0.001 + total_mem = 0 + available_mem = decode_available_memory for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, block_num, False) + if not mem_estimate >= available_mem: + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) warmup_shape_count += 1 - log_master( - logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + self.log_warmup(False, i, len(buckets), batch_size, block_num) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, 
torch.distributed.ReduceOp.MAX ) - for index in range(warmup_times): - self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + if graphed_bucket in self.graphed_buckets: + + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + log_master( logger.info, f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", @@ -585,8 +617,15 @@ class FlashVlmCausalLM(FlashCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = batch.prefilling - + batch_size = input_lengths.shape[0] + seqlen = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] + ) + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, seqlen, batch_size + ) if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 0e5544f2..db3904a2 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -33,6 +33,8 @@ from text_generation_server.utils.import_utils import ( import torch.nn.functional as F from text_generation_server.utils.log import log_master import time +import os +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) @@ -268,6 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): cross_attention_states, image_indices, input_lengths, 1, False ) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + False, hpu_attention_meta.block_list.shape[0], batch_size + ) self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), position_ids=_async_h2d_tensor_copy(position_ids), @@ -281,6 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): cross_attention_states=cross_attention_states, indices=_async_h2d_tensor_copy(indices), cross_attention_len=_async_h2d_tensor_copy(cross_attention_len), + **kwargs, ) def warmup_prefill( @@ -326,8 +334,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( - True, input_ids.shape[0] + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + True, prompt_len, batch_size ) self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), @@ -346,6 +354,23 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): + prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3")) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_reserved + graph_free_mem = self.align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) + prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem + decode_available_memory = graph_free_mem - prompt_available_memory + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})" + ) + log_master(logger.info, 
msg) start_time = time.time() warmup_shape_count = 0 warmup_times = 3 @@ -357,14 +382,35 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): buckets = list( sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens) ) + graph_free_mem + total_batch_seq = 0.001 + total_mem = 0 + available_mem = prompt_available_memory for i, (batch_size, seq_len) in enumerate(buckets): if batch_size * seq_len > self.max_batch_prefill_tokens: continue + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size * seq_len + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, seq_len, True) + if not ( + mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture + ): + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) warmup_shape_count += 1 - log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") - for index in range(warmup_times): - self.warmup_prefill(seq_len, batch_size, batch) - synchronize(self.device) + self.log_warmup(True, i, len(buckets), batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX + ) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq def ordering_function_max_bs(b): return (-b[0], b[1]) @@ -373,16 +419,34 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): buckets = list( sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) ) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + total_batch_seq = 0.001 + total_mem = 0 + available_mem = free_mem - self.mem_reserved for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, block_num, False) + if not mem_estimate >= available_mem: + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) warmup_shape_count += 1 - log_master( - logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + self.log_warmup(False, i, len(buckets), batch_size, block_num) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX ) - for index in range(warmup_times): - self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + log_master( logger.info, f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", @@ -462,9 +526,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs( - batch.prefilling, input_ids.shape[0] + batch_size = input_lengths.shape[0] + seqlen = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] ) + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, seqlen, batch_size + ) + if 
batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py index 9866710b..1c45713e 100644 --- a/backends/gaudi/server/text_generation_server/utils/dist.py +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -7,7 +7,7 @@ from loguru import logger # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) -MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9")) +MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8")) class FakeBarrier: diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py index d25484d6..bdcfc9fa 100644 --- a/backends/gaudi/server/text_generation_server/utils/import_utils.py +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -1,20 +1,9 @@ import torch -from loguru import logger -import habana_frameworks.torch as htorch -import os def get_hpu_free_memory(device, memory_fraction): - graph_reserved_mem = ( - float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1")) - if htorch.utils.internal.is_lazy() - else 0 - ) - free_memory = int( - torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem) - ) - logger.info(f"Free memory on device {device}: {free_memory} bytes.") - return free_memory + free_hpu_memory, _ = torch.hpu.mem_get_info() + return free_hpu_memory def synchronize_hpu(device): From 43b1b07fb96322f02b9c6af76c6aad721ab1729a Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 20 May 2025 20:02:32 +0800 Subject: [PATCH 4/9] Fix the crash in default ATTENTION path for Gaudi backend (#3235) Signed-off-by: Wang, Yi A --- backends/gaudi/server/text_generation_server/tgi_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/gaudi/server/text_generation_server/tgi_service.py b/backends/gaudi/server/text_generation_server/tgi_service.py index 18e88a7e..12317127 100644 --- a/backends/gaudi/server/text_generation_server/tgi_service.py +++ b/backends/gaudi/server/text_generation_server/tgi_service.py @@ -31,6 +31,7 @@ def main(args): trust_remote_code=args.trust_remote_code, uds_path=args.uds_path, max_input_tokens=args.max_input_tokens, + kv_cache_dtype="auto", ) From e32528792cc7ccfdc5dd4b10fecedeb907422261 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 21 May 2025 15:44:15 +0200 Subject: [PATCH 5/9] Switch to punica-sgmv kernel from the Hub (#3236) * Switch to punica-sgmv kernel from the Hub This also switches (temporarily) to the tgi-nix/kernel-builder merge branch, bumping up to CUDA 12.8 (same as non-Nix Torch). * nix: client depends on aiohttp This probably worked before the nixpkgs bump because a dependency propagated aiohttp. 
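
For context, a condensed sketch of the loading pattern these hunks introduce (mirroring the changes to `server/text_generation_server/adapters/lora.py` and `layers/lora.py` below): the prebuilt punica-sgmv kernel pinned in `server/kernels.lock` is loaded from the Hub on CUDA systems, and callers now gate on the module being non-`None` instead of the removed `has_sgmv()` helper.

```python
# Condensed from the hunks below: the SGMV kernels are fetched from the
# kernels-community Hub repo instead of being built from the vendored
# Makefile-lorax-punica / utils/sgmv.py shims.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    # Loads the prebuilt kernel variant pinned in server/kernels.lock and
    # exposes it as a regular Python module (use_cutlass_shrink,
    # orient_for_rank, lora_a_sgmv_cutlass, ... are attributes on it).
    punica_sgmv = load_kernel(
        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
    )
else:
    # Non-CUDA systems fall back to the loop-based LoRA path; call sites
    # check `punica_sgmv is not None` where they previously called has_sgmv().
    punica_sgmv = None
```
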
--- Dockerfile | 9 - flake.lock | 16 +- flake.nix | 2 +- nix/client.nix | 2 + nix/server.nix | 4 +- server/Makefile | 1 - server/Makefile-lorax-punica | 12 - server/kernels.lock | 58 ++++ server/pyproject.toml | 1 + .../text_generation_server/adapters/lora.py | 41 +-- server/text_generation_server/layers/lora.py | 34 ++- server/text_generation_server/utils/sgmv.py | 252 ------------------ 12 files changed, 115 insertions(+), 317 deletions(-) delete mode 100644 server/Makefile-lorax-punica delete mode 100644 server/text_generation_server/utils/sgmv.py diff --git a/Dockerfile b/Dockerfile index e72d9b70..869596d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile # Build specific version of transformers RUN . .venv/bin/activate && make build-awq -# Build Lorax Punica kernels -FROM kernel-builder AS lorax-punica-builder -WORKDIR /usr/src -COPY server/Makefile-lorax-punica Makefile -# Build specific version of transformers -RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica - # Build Transformers CUDA kernels FROM kernel-builder AS custom-kernels-builder WORKDIR /usr/src @@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages # Copy build artifacts from awq kernels builder COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages -# Copy build artifacts from lorax punica kernels builder -COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages diff --git a/flake.lock b/flake.lock index 4540a736..2c6e8063 100644 --- a/flake.lock +++ b/flake.lock @@ -718,16 +718,16 @@ }, "nixpkgs_6": { "locked": { - "lastModified": 1737453259, - "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", + "lastModified": 1746711195, + "narHash": "sha256-bSpM2ySq12PBOVN7jZdzXsc99iRoYOyolh5wz43+CjQ=", "owner": "danieldk", "repo": "nixpkgs", - "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", + "rev": "6b7a66b06ccb09ac95872ac6ddf952e0660672ab", "type": "github" }, "original": { "owner": "danieldk", - "ref": "outlines-v0.1.4-tgi", + "ref": "kernel-builder-cuda-12.9.0", "repo": "nixpkgs", "type": "github" } @@ -978,16 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1746795305, - "narHash": "sha256-4fpUT4j4w0NDKF22KvG7iGmwQTBPM5SrPEqt+N3fqF0=", + "lastModified": 1747733488, + "narHash": "sha256-LYov4H9zvqXXlFKdytcVcDioH416c+LWfyw/HWta0qw=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "359cd25f31f0f2ad2cadfbf4e180780a7a06e3c5", + "rev": "61c730990efa58e64c652bf15253aae47dd0f7dd", "type": "github" }, "original": { "owner": "huggingface", - "ref": "torch-2.7", + "ref": "merge-with-kernel-builder", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index e405b60d..13f40054 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = 
"github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/merge-with-kernel-builder"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/nix/client.nix b/nix/client.nix index 351fd08a..be8e2fc7 100644 --- a/nix/client.nix +++ b/nix/client.nix @@ -1,6 +1,7 @@ { buildPythonPackage, poetry-core, + aiohttp, huggingface-hub, pydantic, }: @@ -15,6 +16,7 @@ buildPythonPackage { build-system = [ poetry-core ]; dependencies = [ + aiohttp huggingface-hub pydantic ]; diff --git a/nix/server.nix b/nix/server.nix index e6493531..a45f39cc 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -31,7 +31,7 @@ peft, pillow, prometheus-client, - punica-kernels, + punica-sgmv, py-cpuinfo, pydantic, quantization, @@ -107,7 +107,7 @@ buildPythonPackage { peft pillow prometheus-client - punica-kernels + punica-sgmv py-cpuinfo pydantic quantization diff --git a/server/Makefile b/server/Makefile index f4855392..a95a4ae5 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,7 +3,6 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-selective-scan -include Makefile-lorax-punica include Makefile-exllamav2 include Makefile-flashinfer diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica deleted file mode 100644 index 72f06f76..00000000 --- a/server/Makefile-lorax-punica +++ /dev/null @@ -1,12 +0,0 @@ -lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc - -build-lorax-punica: - if [ ! -d 'lorax-punica' ]; then \ - git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ - fi - cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) - cd lorax-punica && git submodule update --init --recursive - cd lorax-punica/server/punica_kernels && python setup.py build - -install-lorax-punica: build-lorax-punica - cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/server/kernels.lock b/server/kernels.lock index 1bce05c6..a06cbff3 100644 --- a/server/kernels.lock +++ b/server/kernels.lock @@ -163,6 +163,64 @@ } } }, + { + "repo_id": "kernels-community/punica-sgmv", + "sha": "9ae1b469cb39c33df9ddd61657c6359acc423714", + "variants": { + "torch26-cxx11-cu118-x86_64-linux": { + "hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu124-x86_64-linux": { + "hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-x86_64-linux": { + "hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu118-x86_64-linux": { + "hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu124-x86_64-linux": { + "hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-x86_64-linux": { + 
"hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1", + "hash_type": "git_lfs_concat" + } + } + }, { "repo_id": "kernels-community/quantization", "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00", diff --git a/server/pyproject.toml b/server/pyproject.toml index 5489b19d..7f2addb6 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta" [tool.kernels.dependencies] "kernels-community/paged-attention" = ">=0.0.2" "kernels-community/moe" = ">=0.1.1" +"kernels-community/punica-sgmv" = ">=0.0.1" "kernels-community/quantization" = ">=0.0.3" "kernels-community/quantization-eetq" = ">=0.0.1" "kernels-community/rotary" = ">=0.0.1" diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 782d66e4..c8eb48a2 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -13,21 +13,20 @@ from torch.distributed import ProcessGroup from text_generation_server.utils.log import log_master from text_generation_server.adapters.config import AdapterConfig, ModuleMap - +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.kernels import load_kernel from text_generation_server.adapters.weights import ( AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, ) -from text_generation_server.utils.sgmv import ( - BGMV_MAX_RANK, - MAX_RANK_CUSTOM, - get_tmp_tensors, - orient_for_rank, - pad_rank, - use_cutlass_shrink, - has_sgmv, -) + +if SYSTEM == "cuda": + punica_sgmv = load_kernel( + module="punica_sgmv", repo_id="kernels-community/punica-sgmv" + ) +else: + punica_sgmv = None def get_start_stop_idxs_for_rank(offset, size, rank, world_size): @@ -129,11 +128,13 @@ class LoraWeights(AdapterWeights): self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 - self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r) self._is_transposed = False # [num_layers, hidden_size, r] - weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + weights_a = [ + punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a + ] self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] @@ -244,8 +245,12 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] - lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] + 
lora_a_list = [ + punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list + ] + lora_b_list = [ + punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list + ] if lora_a_list: # update rank if it was padded @@ -293,7 +298,7 @@ class BatchLoraWeights(BatchAdapterWeights): def can_vectorize(self, pg: ProcessGroup) -> bool: return all( - rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM for rank_data in self.rank_data.values() ) @@ -337,8 +342,8 @@ class BatchLoraWeights(BatchAdapterWeights): ) use_sgmv = False - if prefill or max_rank > BGMV_MAX_RANK: - if has_sgmv(): + if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK: + if punica_sgmv is not None: use_sgmv = True lora_a_ptr = torch.tensor( [ @@ -425,7 +430,7 @@ class BatchLoraWeights(BatchAdapterWeights): if use_sgmv: lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors( + tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors( lora_a_ptr_indices.size(0), rank, device ) segment_starts = meta.adapter_segments[indices] diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index a4537b55..abfb097d 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -5,14 +5,16 @@ import torch.distributed from torch import nn from torch.distributed import ProcessGroup -from text_generation_server.utils.sgmv import ( - add_lora_a_bgmv, - add_lora_b_bgmv, - has_sgmv, - lora_a_sgmv_cutlass, - lora_b_sgmv_cutlass, - orient_for_rank, -) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.kernels import load_kernel + +if SYSTEM == "cuda": + punica_sgmv = load_kernel( + module="punica_sgmv", repo_id="kernels-community/punica-sgmv" + ) +else: + punica_sgmv = None + if TYPE_CHECKING: from text_generation_server.adapters import AdapterBatchData @@ -41,7 +43,11 @@ class LoraLinear(nn.Module): return result data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) - if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + if ( + punica_sgmv is not None + and data is not None + and data.can_vectorize(self.process_group) + ): # In tensor-parallel configurations, each GPU processes a specific segment of the output. # The 'result' tensor represents the full output, which can vary in size based on # the layer type (e.g., attention vs. feed-forward layers). 
We define the current @@ -68,7 +74,7 @@ class LoraLinear(nn.Module): if data.use_sgmv: # Use SGMV for prefill - v = lora_a_sgmv_cutlass( + v = punica_sgmv.lora_a_sgmv_cutlass( input, rank_segments.tmp_shrink, lora_a_ptr, @@ -81,7 +87,7 @@ class LoraLinear(nn.Module): if self.process_group.size() > 1: v = self.collect_lora_a(v) - lora_b_sgmv_cutlass( + punica_sgmv.lora_b_sgmv_cutlass( proj, v, rank_segments.tmp_expand, @@ -96,7 +102,7 @@ class LoraLinear(nn.Module): (input.size(0), r), dtype=input.dtype, device=input.device ) # TODO: error with [-1, 0], but not [0, -1] - add_lora_a_bgmv( + punica_sgmv.add_lora_a_bgmv( v, input, lora_a_ptr, @@ -107,7 +113,7 @@ class LoraLinear(nn.Module): if self.process_group.size() > 1: v = self.collect_lora_a(v) - add_lora_b_bgmv( + punica_sgmv.add_lora_b_bgmv( proj, v, lora_b_ptr, @@ -142,7 +148,7 @@ class LoraLinear(nn.Module): lora_a = data.lora_a[adapter_index][self.layer_id, :, :] lora_b = data.lora_b[adapter_index][self.layer_id, :, :] - lora_a = orient_for_rank(lora_a, lora_b.size(0)) + lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0)) a_out = input @ lora_a if self.process_group.size() > 1: diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py deleted file mode 100644 index 2d0a73a5..00000000 --- a/server/text_generation_server/utils/sgmv.py +++ /dev/null @@ -1,252 +0,0 @@ -# Origin: https://github.com/predibase/lorax -# Path: lorax/server/lorax_server/utils/sgmv.py -# License: Apache License Version 2.0, January 2004 - -import os -import warnings -from functools import lru_cache -from typing import List, Tuple - -import torch -import torch.nn.functional as F - -try: - import punica_kernels as _kernels - - HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) -except ImportError: - warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") - _kernels = None - HAS_SGMV = False - - -MIN_SGMV_RANK = 8 -MIN_RANK_CUSTOM = 16 -MAX_RANK_CUSTOM = 128 -SGMV_BLOCK_SIZE = 16 -BGMV_MAX_RANK = 64 - - -def has_sgmv() -> bool: - return HAS_SGMV - - -def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: - """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" - if not has_sgmv(): - return t - - # tensor parallelism will result in effective rank being divided by world_size, - # so we need to scale the min rank to offset that effect - min_rank = MIN_SGMV_RANK * world_size - - # if we're at or below the min rank, pad up to the min rank - # otherwise, pad to the nearest multiple of the block size - current_rank = t.size(dim) - target_rank = ( - min_rank - if current_rank <= min_rank - else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE - ) - if current_rank == target_rank: - return t - - pad_size = target_rank - current_rank - - # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - pad = [0, 0] * t.dim() - pad[(t.dim() - dim - 1) * 2 + 1] = pad_size - pad = tuple(pad) - - return F.pad(t, pad, mode="constant", value=0.0) - - -def use_cutlass_shrink(lora_rank: int) -> bool: - return lora_rank < MIN_RANK_CUSTOM - - -def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: - if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: - return t.transpose(0, 1) - return t - - -# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py -def add_lora_sgmv_cutlass( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - 
wb_ptr: torch.Tensor, - s_start: torch.Tensor, - s_end: torch.Tensor, - layer_idx: int, - lora_rank: int, -): - """ - Semantics: - y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H1]`. - wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H2]`. - s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. - s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. - layer_idx: Layer index of the weight matrices. - """ - if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: - # Custom SGMV shrink only supports rank 16, 32, 64, 128 - _add_lora_sgmv_cutlass_legacy( - y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank - ) - return - - tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) - tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) - - -def _add_lora_sgmv_cutlass_legacy( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -): - tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -@lru_cache(maxsize=1) -def get_tmp_tensor(device: torch.device) -> torch.Tensor: - return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) - - -@lru_cache(maxsize=32) -def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: - tmp_size = _kernels.sgmv_cutlass_tmp_size(size) - return torch.empty((tmp_size,), dtype=torch.uint8, device=device) - - -def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: - return torch.empty((size,), dtype=torch.uint8, device=device) - - -def get_tmp_expand_size(size: int) -> int: - return _kernels.sgmv_cutlass_tmp_size(size) - - -def get_tmp_tensors( - nsegments: int, lora_rank: int, device: torch.device -) -> Tuple[torch.Tensor, torch.Tensor]: - use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv() - has_sgmv_available = has_sgmv() - - if use_cutlass: - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - elif has_sgmv_available: - return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device) - else: - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - - -def lora_a_sgmv_cutlass( - x: torch.Tensor, - tmp: torch.Tensor, - wa_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -) -> torch.Tensor: - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - else: - 
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - return v - - -def lora_b_sgmv_cutlass( - y: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, -): - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -""" -Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - -Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - v: Shape: `[B, R]`. Temporary vector. - x: Shape: `[B, H1]`. Input vectors. - wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. - wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. -""" - - -def add_lora_a_bgmv( - v: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) - - -def add_lora_b_bgmv( - y: torch.Tensor, - v: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) - - -def segmented_matmul( - y: torch.Tensor, - x: torch.Tensor, - w: List[torch.Tensor], - b: List[torch.Tensor], - s_start: torch.IntTensor, - s_end: torch.IntTensor, -): - for i in range(len(w)): - if s_end[i] - s_start[i] <= 0: - continue - - xi = x[s_start[i] : s_end[i]] - wi = w[i] - bi = b[i] - y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) From 9e7e546923abf89b11adcd0c34c98a8e123d12a8 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Thu, 22 May 2025 15:21:31 +0800 Subject: [PATCH 6/9] Move input_ids to hpu and remove disposal of adapter_meta (#3237) Signed-off-by: Wang, Yi A --- .../layers/attention/common.py | 2 + .../models/flash_causal_lm.py | 227 ++++++++++-------- .../models/flash_vlm_causal_lm.py | 6 +- .../models/mllama_causal_lm.py | 8 +- 4 files changed, 136 insertions(+), 107 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index 9bd738fc..5e03cd44 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -90,6 +90,8 @@ class Seqlen: def _async_h2d_tensor_copy(source, device="hpu"): if source is None: return None + if source.device.type == "hpu": + return source assert source.device.type == "cpu", "Source tensor is not present in host memory!" 
target = torch.empty(source.shape, dtype=source.dtype, device=device) target.copy_(source, non_blocking=True) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index bc0d240e..f8abe5ad 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -634,21 +634,25 @@ class FlashCausalLMBatch(Batch): # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] input_lengths_tensor = self.input_lengths_tensor[indices] cache_lengths_tensor = self.cache_lengths_tensor[indices] # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) - adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ) + if self.adapter_meta is not None: + adapter_indices = self.adapter_meta.adapter_indices[indices] + adapter_segments, adapter_segment_indices = find_segments( + adapter_indices + ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) + else: + adapter_meta = None htorch.core.mark_step() return type(self)( batch_id=self.batch_id, @@ -710,6 +714,7 @@ class FlashCausalLMBatch(Batch): max_length = 0 max_input_length = 0 max_current_length = 0 + ADAPTER_TO_INDEX = get_adapter_to_index() for b in batches: total_batch_size += len(b) max_blocks = max(max_blocks, b.max_blocks) @@ -763,14 +768,15 @@ class FlashCausalLMBatch(Batch): cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( total_batch_size ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_segment_builder = SegmentConcatBuilder() - adapter_set = set() + if ADAPTER_TO_INDEX: + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size @@ -821,9 +827,7 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + valid_bsize - index = torch.tensor( - list(range(start_index, end_index)), device=batch.input_ids.device - ) + index = torch.tensor(list(range(start_index, end_index)), device="cpu") top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] @@ -847,7 +851,9 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize]) + input_ids.index_copy_( + 0, index.to(input_ids.device), batch.input_ids[:valid_bsize] + ) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) slot_indices.index_copy_( 0, index, 
batch.slot_indices + cumulative_slots @@ -858,20 +864,21 @@ class FlashCausalLMBatch(Batch): cache_lengths_tensor.index_copy_( 0, index, batch.cache_lengths_tensor[:valid_bsize] ) - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, - batch.adapter_meta.segment_indices, - ) + if ADAPTER_TO_INDEX: + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() @@ -914,7 +921,7 @@ class FlashCausalLMBatch(Batch): else: speculative_ids = None - if adapter_segment_builder is not None: + if ADAPTER_TO_INDEX and adapter_segment_builder is not None: adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, @@ -961,7 +968,7 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=adapter_meta, + adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None, hpu_attn_meta=None, next_token_logits=None, speculative_logits=None, @@ -1037,6 +1044,7 @@ class FlashCausalLMBatch(Batch): # need extra pad to match warmup seq extra_pad = max_padded_input_len - self.max_input_length extra_pad_bs = max_padded_bs - len(self) + device = self.all_input_ids_tensor.device if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1047,12 +1055,12 @@ class FlashCausalLMBatch(Batch): input_ids.append(input_id) input_ids_padded_length.append(padded) input_ids = np.concatenate(input_ids, dtype=np.int64) - self.input_ids = torch.tensor(input_ids, dtype=torch.int64) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) elif isinstance(self.input_ids, list): input_ids = self.input_ids[0] input_ids_padded_length.append(extra_pad) input_ids = [0] * extra_pad + input_ids - self.input_ids = torch.tensor(input_ids, dtype=torch.int64) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) else: self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) input_ids_padded_length.extend([extra_pad] * len(self)) @@ -1245,7 +1253,9 @@ class FlashCausalLMBatch(Batch): self.slot_indices = slot_indices self.prefill_cu_outlens = prefill_cu_outlens - self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) + self.prefill_cache_indices = torch.zeros_like( + self.input_ids, dtype=torch.bool, device="cpu" + ) self.prefill_cache_indices[prefill_cache_indices] = True if all_prefill_logprobs: @@ -1301,21 +1311,24 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states, ) - if adapter_set: - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) - 
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - else: - adapter_indices = torch.zeros_like(self.input_ids) - adapter_segments = [0, len(adapter_indices)] - adapter_segment_indices = [len(adapter_indices) - 1] + if ADAPTER_TO_INDEX: + if adapter_set: + adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) + adapter_segments, adapter_segment_indices = find_segments( + adapter_indices + ) + else: + adapter_indices = torch.zeros_like(self.input_ids) + adapter_segments = [0, len(adapter_indices)] + adapter_segment_indices = [len(adapter_indices) - 1] - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) - self.adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) def __len__(self): return len(self.requests) @@ -1941,11 +1954,11 @@ class FlashCausalLM(Model): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad else: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[: slots.shape[0]] = slots slots = slots_pad seqlen = Seqlen( @@ -1965,7 +1978,7 @@ class FlashCausalLM(Model): ) logits, speculative_logits = self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + input_ids=input_ids, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, @@ -2059,15 +2072,16 @@ class FlashCausalLM(Model): batch.position_ids = batch.position_ids[indices] batch.slot_indices = batch.slot_indices[indices[: len(batch)]] - batch.adapter_meta.adapter_indices = ( - batch.adapter_meta.adapter_indices[indices] - ) + if batch.adapter_meta is not None: + batch.adapter_meta.adapter_indices = ( + batch.adapter_meta.adapter_indices[indices] + ) # For each member of the batch # Cumulative length - accepted_ids = accepted_ids.cpu() - cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) - torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + if batch.speculative_logits is not None: + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) for i in range(len(batch)): batch.all_input_ids_tensor[ i, @@ -2076,6 +2090,20 @@ class FlashCausalLM(Model): + batch.input_lengths[i] + accepted_ids[i], ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + accepted_ids = accepted_ids.cpu() + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += ( + batch.input_lengths_tensor + accepted_ids - 1 + ) + batch.input_lengths_tensor = torch.ones_like( + batch.input_lengths_tensor + ) + batch.slot_indices += accepted_ids[: len(batch)] else: index = batch.cache_lengths_tensor + batch.input_lengths_tensor index = 
index.to(batch.all_input_ids_tensor.device) @@ -2088,22 +2116,18 @@ class FlashCausalLM(Model): batch.all_input_ids_tensor.index_put_( (batch_idx, index.long()), next_input_ids ) - next_input_ids = next_input_ids.cpu() - batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.input_ids = next_input_ids + batch.position_ids += 1 + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = torch.ones_like( + batch.input_lengths_tensor + ) + batch.slot_indices += 1 + batch.speculative_ids = speculative_ids - if batch.position_ids.dim() == 2: - # Qwen2_vl case: - batch.position_ids += accepted_ids.unsqueeze(-1) - else: - batch.position_ids += accepted_ids - batch.cache_lengths_tensor += ( - batch.input_lengths_tensor + accepted_ids - 1 - ) - batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) - batch.slot_indices += accepted_ids[: len(batch)] # Does a HPU <-> CPU sync internally - if prefill: + if prefill and batch.adapter_meta is not None: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments( batch.adapter_meta.adapter_indices @@ -2194,30 +2218,33 @@ class FlashCausalLM(Model): prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta - if batch.speculative_ids is not None: - B, speculative_length = batch.speculative_ids.shape - new_length = speculative_length + 1 - adapter_indices = ( - adapter_meta.adapter_indices.unsqueeze(-1) - .expand(B, new_length) - .reshape(-1) - ) - adapter_segments = adapter_meta.adapter_segments * new_length - adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_meta.adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_meta.segment_indices, - ) + if adapter_meta is not None: + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1) + .expand(B, new_length) + .reshape(-1) + ) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) - # Assign pointers to adapter weights - # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta( - adapter_meta, - self.layer_to_adapter_weights, - prefill, - batch.prefill_head_indices, - ) + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + prefill, + batch.prefill_head_indices, + ) + else: + adapter_data = None out, speculative_logits = self.forward(batch, adapter_data) diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index fd239b3e..e604fd3c 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -627,11 +627,11 @@ class FlashVlmCausalLM(FlashCausalLM): batch.prefilling, seqlen, batch_size ) if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + 
slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad else: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[: slots.shape[0]] = slots slots = slots_pad @@ -639,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM): input_lengths=_async_h2d_tensor_copy(input_lengths), ) logits, speculative_logits = self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + input_ids=input_ids, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index db3904a2..771cc0a8 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -190,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): input_ids = np.concatenate(batch.input_ids, dtype=np.int64) else: input_ids = batch.input_ids[0] - batch.input_ids = torch.tensor(input_ids, dtype=torch.int64) + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) @@ -537,11 +537,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad else: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[: slots.shape[0]] = slots slots = slots_pad orig_bs = len(batch) @@ -570,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): input_lengths=_async_h2d_tensor_copy(input_lengths), ) logits, speculative_logits = self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + input_ids=input_ids, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, From 674c514d448b4ca7346e1fe18731aa70dc78acb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 22 May 2025 09:43:55 +0200 Subject: [PATCH 7/9] Prepare for 3.3.1 (#3238) --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 2 +- README.md | 6 +++--- docs/openapi.json | 2 +- docs/source/backends/gaudi.mdx | 10 +++++----- docs/source/backends/neuron.md | 2 +- .../source/basic_tutorials/gated_model_access.md | 2 +- docs/source/conceptual/quantization.md | 6 +++--- docs/source/installation_amd.md | 2 +- docs/source/installation_intel.md | 4 ++-- docs/source/installation_nvidia.md | 2 +- docs/source/quicktour.md | 4 ++-- docs/source/reference/api_reference.md | 2 +- .../test_flash_gemma3_image_base64_rgb_jpg.json | 2 +- .../test_flash_gemma3_image_base64_rgb_png.json | 2 +- .../test_flash_gemma3_image_base64_rgba.json | 2 +- .../test_flash_gemma3_image_cow.json | 2 +- .../test_flash_gemma3_image_cow_dog.json | 2 +- .../test_json_schema_basic.json | 2 +- .../test_json_schema_complex.json | 2 +- .../test_mllama/test_mllama_load.json | 4 ++-- .../test_mllama/test_mllama_simpl.json | 2 +- 22 files changed, 40 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c757f885..b09f1c3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4650,7 +4650,7 @@ dependencies = [ 
[[package]] name = "text-generation-backends-trtllm" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "async-trait", "clap 4.5.32", @@ -4671,7 +4671,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "average", "clap 4.5.32", @@ -4691,7 +4691,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -4709,7 +4709,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "clap 4.5.32", "ctrlc", @@ -4730,7 +4730,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "anyhow", "async-stream", @@ -4782,7 +4782,7 @@ dependencies = [ [[package]] name = "text-generation-router-llamacpp" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "async-trait", "bindgen 0.71.1", @@ -4800,7 +4800,7 @@ dependencies = [ [[package]] name = "text-generation-router-v2" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "async-stream", "async-trait", @@ -4849,7 +4849,7 @@ dependencies = [ [[package]] name = "text-generation-router-v3" -version = "3.3.0-dev0" +version = "3.3.1-dev0" dependencies = [ "async-stream", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index df40d8d5..f7b1e3b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ default-members = [ resolver = "2" [workspace.package] -version = "3.3.0-dev0" +version = "3.3.1-dev0" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/README.md b/README.md index 0d8fedbd..79991590 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta volume=$PWD/data docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model ``` And then you can make requests like @@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \ **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. -**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0-rocm --model-id $model` instead of the command above. +**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). 
To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1-rocm --model-id $model` instead of the command above. To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): ``` @@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading token= docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model ``` ### A note on Shared Memory (shm) diff --git a/docs/openapi.json b/docs/openapi.json index 5486413e..9249acad 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "3.3.0-dev0" + "version": "3.3.1-dev0" }, "paths": { "/": { diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx index 33686966..ab882fc2 100644 --- a/docs/source/backends/gaudi.mdx +++ b/docs/source/backends/gaudi.mdx @@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN docker run --runtime=habana --cap-add=sys_nice --ipc=host \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ --model-id $model ``` @@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN docker run --runtime=habana --cap-add=sys_nice --ipc=host \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ --model-id $model ``` @@ -115,7 +115,7 @@ docker run -p 8080:80 \ -e BATCH_BUCKET_SIZE=256 \ -e PREFILL_BATCH_BUCKET_SIZE=4 \ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ --model-id $model \ --sharded true --num-shard 8 \ --max-input-tokens 1024 --max-total-tokens 2048 \ @@ -141,7 +141,7 @@ docker run -p 8080:80 \ -v $volume:/data \ -e PREFILL_BATCH_BUCKET_SIZE=1 \ -e BATCH_BUCKET_SIZE=1 \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ --model-id $model \ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \ --max-total-tokens 8192 --max-batch-size 4 @@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \ -e PROF_PATH=/tmp/hpu_profile \ -e PROF_RANKS=0 \ -e PROF_RECORD_SHAPES=True \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \ --model-id $model ``` diff --git a/docs/source/backends/neuron.md b/docs/source/backends/neuron.md index 5c4829bc..a1fa3a9e 100644 --- a/docs/source/backends/neuron.md +++ b/docs/source/backends/neuron.md @@ -31,7 +31,7 @@ deployment instructions in the model card: The service is launched simply by running the text-generation-inference container with two sets of parameters: ``` -docker run ghcr.io/huggingface/text-generation-inference:3.3.0-neuron +docker run ghcr.io/huggingface/text-generation-inference:3.3.1-neuron ``` - system parameters are used to map ports, volumes and devices between the host and the service, diff --git 
a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index 35be7bab..dfed553e 100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -19,6 +19,6 @@ docker run --gpus all \ --shm-size 1g \ -e HF_TOKEN=$token \ -p 8080:80 \ - -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 \ + -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 \ --model-id $model ``` diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index 73c77d4b..c215f4c3 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes ``` 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. @@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes-nf4 +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes-nf4 ``` You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). @@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$ TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize gptq +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize gptq ``` Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. 
diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 38e67aac..9f92859c 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0-rocm \ + ghcr.io/huggingface/text-generation-inference:3.3.1-rocm \ --model-id $model ``` diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md index e29285c3..71c8a2de 100644 --- a/docs/source/installation_intel.md +++ b/docs/source/installation_intel.md @@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0-intel-xpu \ + ghcr.io/huggingface/text-generation-inference:3.3.1-intel-xpu \ --model-id $model --cuda-graphs 0 ``` @@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0-intel-cpu \ + ghcr.io/huggingface/text-generation-inference:3.3.1-intel-cpu \ --model-id $model --cuda-graphs 0 ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 56619bce..40ae145b 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 \ + ghcr.io/huggingface/text-generation-inference:3.3.1 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 6a2d73c1..76832317 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 \ + ghcr.io/huggingface/text-generation-inference:3.3.1 \ --model-id $model ``` @@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. 
```bash -docker run ghcr.io/huggingface/text-generation-inference:3.3.0 --help +docker run ghcr.io/huggingface/text-generation-inference:3.3.1 --help ``` diff --git a/docs/source/reference/api_reference.md b/docs/source/reference/api_reference.md index 0fc8714d..8dbe977a 100644 --- a/docs/source/reference/api_reference.md +++ b/docs/source/reference/api_reference.md @@ -163,7 +163,7 @@ hub = { # create Hugging Face Model Class huggingface_model = HuggingFaceModel( - image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.0"), + image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.1"), env=hub, role=role, ) diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json index 41eb19fd..df9daac8 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 42, "prompt_tokens": 277, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json index 1f3e2b91..328105ca 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 62, "prompt_tokens": 277, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json index 42a2be01..b7918d48 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 67, "prompt_tokens": 277, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json index fbe95016..43d01863 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 72, "prompt_tokens": 275, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json 
b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json index 40f317cc..9d80a763 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json @@ -17,7 +17,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 80, "prompt_tokens": 279, diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json index 23fb8dda..30241eb9 100644 --- a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json @@ -14,7 +14,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 35, "prompt_tokens": 32, diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json index e344a226..008ae5b0 100644 --- a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json @@ -14,7 +14,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 44, "prompt_tokens": 37, diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json index 37c8ef8e..50e75361 100644 --- a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json @@ -18,7 +18,7 @@ "id": "", "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, @@ -44,7 +44,7 @@ "id": "", "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json index 75dc0ddf..91297113 100644 --- a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json @@ -17,7 +17,7 @@ "id": "", "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.1-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, From f08b44ade5c64ce87aff7ff4d74f766282f579a3 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Thu, 22 May 2025 21:29:16 +0800 Subject: [PATCH 8/9] Upgrade to new vllm 
extension ops for Gaudi backend (fix issue in exponential bucketing) (#3239) Signed-off-by: Wang, Yi A --- Dockerfile_gaudi | 2 +- .../layers/attention/hpu.py | 3 +++ .../layers/attention/kv_cache.py | 18 ++++++------------ 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index bd6c58b4..c4164556 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -98,7 +98,7 @@ RUN cd server && \ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . --no-cache-dir -RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794 +RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 1c2e37c7..8cca7a29 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -7,6 +7,7 @@ from vllm_hpu_extension.utils import Matmul from habana_frameworks.torch.hpex.kernels import FusedSDPA from vllm_hpu_extension.utils import ModuleFusedSDPA import os +from text_generation_server.models.globals import BLOCK_SIZE SUPPORTS_WINDOWING = False @@ -126,6 +127,7 @@ def paged_attention( block_mapping=hpu_attention_meta.block_mapping, block_bias=hpu_attention_meta.attn_bias, block_groups=hpu_attention_meta.block_groups, + block_size=BLOCK_SIZE, scale=softmax_scale, matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), @@ -160,6 +162,7 @@ def paged_attention_mla( block_mapping=hpu_attention_meta.block_mapping, block_bias=hpu_attention_meta.attn_bias, block_groups=hpu_attention_meta.block_groups, + block_size=BLOCK_SIZE, scale=softmax_scale, matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py index cdd1e1d7..723c1ec0 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -5,7 +5,6 @@ import torch from text_generation_server.models.globals import BLOCK_SIZE from text_generation_server.utils.weights import Weights -from vllm_hpu_extension import cache_ops @dataclass @@ -55,12 +54,12 @@ class KVCache: self.kv_cache = ( torch.zeros( - (num_blocks, BLOCK_SIZE, num_heads, head_size), + (num_blocks * BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device, ), torch.zeros( - (num_blocks, BLOCK_SIZE, num_heads, head_size), + (num_blocks * BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device, ), @@ -129,7 +128,7 @@ class KVCompressCache(KVCache): raise ValueError("torch.float8_e5m2 is not supported in hpu. 
") self.kv_cache = torch.zeros( - (num_blocks, BLOCK_SIZE, 1, head_size), + (num_blocks * BLOCK_SIZE, 1, head_size), dtype=dtype, device=device, ) @@ -161,14 +160,11 @@ class KVCompressCache(KVCache): ): """Store the key and value at the given slots.""" ## TODO FP8 kv cache support - - block_idx = slots // BLOCK_SIZE - block_offset = slots % BLOCK_SIZE if self.kv_cache.dtype == torch.float8_e4m3fn: key = torch.ops.hpu.cast_to_fp8_v2( key, kv_scales.key_scale, False, False, torch.float8_e4m3fn )[0] - cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset) + self.kv_cache.index_copy_(0, slots, key) def paged_reshape_and_cache( @@ -180,8 +176,6 @@ def paged_reshape_and_cache( k_scale: torch.Tensor, v_scale: torch.Tensor, ): - block_idx = slots // BLOCK_SIZE - block_offset = slots % BLOCK_SIZE if key_cache.dtype == torch.float8_e4m3fn: key = torch.ops.hpu.cast_to_fp8_v2( key, k_scale, False, False, torch.float8_e4m3fn @@ -189,8 +183,8 @@ def paged_reshape_and_cache( value = torch.ops.hpu.cast_to_fp8_v2( value, v_scale, False, False, torch.float8_e4m3fn )[0] - cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) - cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) + key_cache.index_copy_(0, slots, key) + value_cache.index_copy_(0, slots, value) def get_kv_scales(weights: Weights, prefix: str) -> KVScales: From f58d7cf50e78f3430e1efa0608189b960d23db74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 22 May 2025 17:09:15 +0200 Subject: [PATCH 9/9] Nix: switch to hf-nix (#3240) * Nix: switch to hf-nix * Remove outdated local overrides --- .github/workflows/nix_build.yaml | 2 +- .github/workflows/nix_cache.yaml | 2 +- .github/workflows/nix_tests.yaml | 2 +- Dockerfile.nix | 2 +- README.md | 2 +- flake.lock | 59 ++++++++++++++++---------------- flake.nix | 14 ++++---- nix/overlay.nix | 40 +++++++++++----------- 8 files changed, 61 insertions(+), 62 deletions(-) diff --git a/.github/workflows/nix_build.yaml b/.github/workflows/nix_build.yaml index 71ad59d0..e0076af6 100644 --- a/.github/workflows/nix_build.yaml +++ b/.github/workflows/nix_build.yaml @@ -21,7 +21,7 @@ jobs: nix_path: nixpkgs=channel:nixos-unstable - uses: cachix/cachix-action@v14 with: - name: text-generation-inference + name: huggingface # If you chose signing key for write access authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' env: diff --git a/.github/workflows/nix_cache.yaml b/.github/workflows/nix_cache.yaml index 967a5982..7c73e584 100644 --- a/.github/workflows/nix_cache.yaml +++ b/.github/workflows/nix_cache.yaml @@ -20,7 +20,7 @@ jobs: nix_path: nixpkgs=channel:nixos-unstable - uses: cachix/cachix-action@v14 with: - name: text-generation-inference + name: huggingface # If you chose signing key for write access authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}" env: diff --git a/.github/workflows/nix_tests.yaml b/.github/workflows/nix_tests.yaml index d9b91048..4f68ff60 100644 --- a/.github/workflows/nix_tests.yaml +++ b/.github/workflows/nix_tests.yaml @@ -25,7 +25,7 @@ jobs: nix_path: nixpkgs=channel:nixos-unstable - uses: cachix/cachix-action@v14 with: - name: text-generation-inference + name: huggingface # If you chose signing key for write access authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' env: diff --git a/Dockerfile.nix b/Dockerfile.nix index f1e7e0f5..90390de6 100644 --- a/Dockerfile.nix +++ b/Dockerfile.nix @@ -6,7 +6,7 @@ FROM nixos/nix:2.18.8 AS builder RUN echo "experimental-features = nix-command flakes" >> 
/etc/nix/nix.conf RUN nix profile install nixpkgs#cachix -RUN cachix use text-generation-inference +RUN cachix use huggingface WORKDIR /root ADD . . RUN nix build . diff --git a/README.md b/README.md index 79991590..f4c6c562 100644 --- a/README.md +++ b/README.md @@ -256,7 +256,7 @@ Another option is to install `text-generation-inference` locally using [Nix](htt we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can be pulled from a binary cache, removing the need to build them locally. -First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference). +First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface). Setting up the cache is important, otherwise Nix will build many of the dependencies locally, which can take hours. diff --git a/flake.lock b/flake.lock index 2c6e8063..e57990c8 100644 --- a/flake.lock +++ b/flake.lock @@ -102,7 +102,7 @@ "flake-parts": "flake-parts_3", "nix-test-runner": "nix-test-runner_3", "nixpkgs": [ - "tgi-nix", + "hf-nix", "nixpkgs" ], "pre-commit-hooks": "pre-commit-hooks_3" @@ -579,6 +579,26 @@ "type": "github" } }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_4", + "flake-utils": "flake-utils_7", + "nixpkgs": "nixpkgs_6" + }, + "locked": { + "lastModified": 1747919133, + "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, "nix-filter": { "locked": { "lastModified": 1731533336, @@ -718,16 +738,16 @@ }, "nixpkgs_6": { "locked": { - "lastModified": 1746711195, - "narHash": "sha256-bSpM2ySq12PBOVN7jZdzXsc99iRoYOyolh5wz43+CjQ=", + "lastModified": 1747820358, + "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", "owner": "danieldk", "repo": "nixpkgs", - "rev": "6b7a66b06ccb09ac95872ac6ddf952e0660672ab", + "rev": "d3c1681180717528068082103bf323147de6ab0b", "type": "github" }, "original": { "owner": "danieldk", - "ref": "kernel-builder-cuda-12.9.0", + "ref": "cudatoolkit-12.9-kernel-builder", "repo": "nixpkgs", "type": "github" } @@ -836,19 +856,19 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", + "hf-nix": "hf-nix", "nix-filter": "nix-filter", "nixpkgs": [ - "tgi-nix", + "hf-nix", "nixpkgs" ], - "rust-overlay": "rust-overlay", - "tgi-nix": "tgi-nix" + "rust-overlay": "rust-overlay" } }, "rust-overlay": { "inputs": { "nixpkgs": [ - "tgi-nix", + "hf-nix", "nixpkgs" ] }, @@ -970,27 +990,6 @@ "repo": "default", "type": "github" } - }, - "tgi-nix": { - "inputs": { - "flake-compat": "flake-compat_4", - "flake-utils": "flake-utils_7", - "nixpkgs": "nixpkgs_6" - }, - "locked": { - "lastModified": 1747733488, - "narHash": "sha256-LYov4H9zvqXXlFKdytcVcDioH416c+LWfyw/HWta0qw=", - "owner": "huggingface", - "repo": "text-generation-inference-nix", - "rev": "61c730990efa58e64c652bf15253aae47dd0f7dd", - "type": "github" - }, - "original": { - "owner": "huggingface", - "ref": "merge-with-kernel-builder", - "repo": "text-generation-inference-nix", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 13f40054..b5b13cad 100644 --- a/flake.nix +++ b/flake.nix @@ -2,15 +2,15 @@ inputs = { crate2nix = { url = "github:nix-community/crate2nix"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + inputs.nixpkgs.follows 
= "hf-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix/merge-with-kernel-builder"; - nixpkgs.follows = "tgi-nix/nixpkgs"; + hf-nix.url = "github:huggingface/hf-nix"; + nixpkgs.follows = "hf-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { url = "github:oxalica/rust-overlay"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + inputs.nixpkgs.follows = "hf-nix/nixpkgs"; }; }; outputs = @@ -21,7 +21,7 @@ nixpkgs, flake-utils, rust-overlay, - tgi-nix, + hf-nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -33,10 +33,10 @@ }; pkgs = import nixpkgs { inherit system; - inherit (tgi-nix.lib) config; + inherit (hf-nix.lib) config; overlays = [ rust-overlay.overlays.default - tgi-nix.overlays.default + hf-nix.overlays.default (import nix/overlay.nix) ]; }; diff --git a/nix/overlay.nix b/nix/overlay.nix index 069fdd80..0eb07c2a 100644 --- a/nix/overlay.nix +++ b/nix/overlay.nix @@ -13,26 +13,26 @@ final: prev: { ( python-self: python-super: with python-self; { # Python package override example: - transformers = python-super.transformers.overrideAttrs ( - _: _: { - src = final.fetchFromGitHub { - owner = "huggingface"; - repo = "transformers"; - rev = "v4.51.0"; - hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8="; - }; - } - ); - huggingface-hub = python-super.huggingface-hub.overrideAttrs ( - _: _: { - src = final.fetchFromGitHub { - owner = "huggingface"; - repo = "huggingface_hub"; - rev = "v0.30.0"; - hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o="; - }; - } - ); + #transformers = python-super.transformers.overrideAttrs ( + # _: _: { + # src = final.fetchFromGitHub { + # owner = "huggingface"; + # repo = "transformers"; + # rev = "v4.51.0"; + # hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8="; + # }; + # } + #); + #huggingface-hub = python-super.huggingface-hub.overrideAttrs ( + # _: _: { + # src = final.fetchFromGitHub { + # owner = "huggingface"; + # repo = "huggingface_hub"; + # rev = "v0.30.0"; + # hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o="; + # }; + # } + #); } ) ];