diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index 53837ef7..b1a41534 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -26,6 +26,11 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"
 
 
+class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -34,6 +39,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = "bfloat16" if dtype is None else dtype.value
-    logger.info(f"quantize={quantize}")
+    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+    logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -175,6 +182,7 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         uds_path,
         max_input_tokens,
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index 1d73dcb3..092fe138 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -11,11 +11,61 @@ import os
 SUPPORTS_WINDOWING = False
 
 
-def fetch_from_cache(cache, blocks):
-    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
-        return cache[: blocks.size(0)]
-    else:
-        return cache.index_select(0, blocks)
+class FP8Matmul(torch.nn.Module):
+
+    def __init__(self, scale_other):
+        super().__init__()
+        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
+        self.scale_other = scale_other
+
+    def quant_input(self, x, scale):
+        return torch.ops.hpu.cast_to_fp8_v2(
+            x, scale, False, False, torch.float8_e4m3fn
+        )[0]
+
+    def matmul_fp8(
+        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
+    ):
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x,
+            trans_A=False,
+            B=other,
+            trans_B=False,
+            D=None,
+            out_dtype=out_dtype,
+            A_scale_inv=scale_input_inv,
+            B_scale_inv=scale_other_inv,
+            bias=None,
+            accumulate=False,
+        )
+
+    def forward(self, input, other):
+        qinput = self.quant_input(input, self.scale_input)
+        qother = self.quant_input(other, self.scale_other)
+        output = self.matmul_fp8(
+            qinput,
+            qother,
+            out_dtype=torch.bfloat16,
+            scale_input_inv=1.0 / self.scale_input,
+            scale_other_inv=1.0 / self.scale_other,
+        )
+        return output
+
+
+class FetchFromCache(torch.nn.Module):
+
+    def __init__(self, scale_inv):
+        super().__init__()
+        self.scale_inv = scale_inv
+
+    def forward(self, cache, blocks):
+        if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
+            out = cache[: blocks.size(0)]
+        else:
+            out = cache.index_select(0, blocks)
+        if out.dtype == torch.float8_e4m3fn:
+            out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
+        return out
 
 
 def attention(
@@ -67,6 +117,7 @@ def paged_attention(
     hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
     batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
@@ -76,12 +127,12 @@ def paged_attention(
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
-        matmul_qk_op=Matmul(),
-        matmul_av_op=Matmul(),
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
         batch2block_matmul_op=Matmul(),
         block2batch_matmul_op=Matmul(),
-        keys_fetch_func=fetch_from_cache,
-        values_fetch_func=fetch_from_cache,
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
     )
     # Reshape the output tensor.
     return output.view(batch_size, head_num, head_size)
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
index d238cdb9..e6c5f67d 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -50,6 +50,8 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
         ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
 
         self.kv_cache = (
             torch.zeros(
@@ -101,8 +103,8 @@ class KVCache:
             key_cache,
             value_cache,
             slots,
-            kv_scales.key_scale_cpu,
-            kv_scales.value_scale_cpu,
+            kv_scales.key_scale,
+            kv_scales.value_scale,
         )
 
 
@@ -112,11 +114,18 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
-    k_scale: float = 1.0,
-    v_scale: float = 1.0,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
 ):
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
+    if key_cache.dtype == torch.float8_e4m3fn:
+        key = torch.ops.hpu.cast_to_fp8_v2(
+            key, k_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        value = torch.ops.hpu.cast_to_fp8_v2(
+            value, v_scale, False, False, torch.float8_e4m3fn
+        )[0]
     cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
     cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
 
diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
index 778b14a1..7de01ce5 100644
--- a/backends/gaudi/server/text_generation_server/models/__init__.py
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -324,6 +324,7 @@ def get_model(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[torch.dtype],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
@@ -449,7 +450,12 @@ def get_model(
 
     model_type = config_dict["model_type"]
 
-    kv_cache_dtype = dtype
+    if kv_cache_dtype == "fp8_e4m3fn":
+        kv_cache_dtype = torch.float8_e4m3fn
+    elif kv_cache_dtype == "fp8_e5m2":
+        kv_cache_dtype = torch.float8_e5m2
+    else:
+        kv_cache_dtype = dtype
 
     if FLASH_ATTENTION:
         if model_type == DEEPSEEK_V2:
@@ -890,6 +896,7 @@ def get_model_with_lora_adapters(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[torch.dtype],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
     adapter_to_index: Dict[str, int],
@@ -903,6 +910,7 @@ def get_model_with_lora_adapters(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         max_input_tokens,
     )
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index ad585172..79626233 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1438,6 +1438,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
         self.bucketing_ctx = None
+        htorch.core.hpu_set_env()
         if htorch.utils.internal.is_lazy():
             htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
         environment.set_model_config(self.config)
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
index 5a7d2117..70bcdb6d 100644
--- a/backends/gaudi/server/text_generation_server/server.py
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -224,6 +224,7 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -236,6 +237,7 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         if not is_driver_compatible():
@@ -279,6 +281,7 @@ def serve(
                 quantize,
                 speculate,
                 data_type,
+                kv_cache_dtype,
                 trust_remote_code,
                 max_input_tokens,
                 adapter_to_index,
@@ -326,6 +329,7 @@ def serve(
             quantize,
            speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
         )
     )
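For context, the snippet below is a rough CPU-side sketch of the round trip this patch performs on Gaudi: paged_reshape_and_cache quantizes keys/values into an fp8_e4m3fn cache on store, and FetchFromCache dequantizes them on read. The torch.ops.hpu.cast_to_fp8_v2 / cast_from_fp8 ops are HPU-only, so plain torch float8 casts stand in for them here, and the scale handling (multiply by scale on store, by 1/scale on fetch) is an assumption about those ops' semantics, not part of the diff.

# Rough CPU-side sketch (not the HPU path): quantize a KV tensor into an
# fp8_e4m3fn cache and dequantize it on fetch. "scale" mirrors kv_scales.key_scale;
# 1.0 / scale mirrors the scale_inv handed to FetchFromCache.
import torch

scale = torch.tensor(2.0)                      # assumed per-tensor KV scale
key = torch.randn(4, 8, dtype=torch.bfloat16)  # one block of keys

# store path: apply the scale, then cast down to fp8 (stand-in for cast_to_fp8_v2)
key_fp8 = (key.float() * scale).to(torch.float8_e4m3fn)

# fetch path: cast back up and apply the inverse scale (stand-in for cast_from_fp8)
key_restored = key_fp8.to(torch.bfloat16) * (1.0 / scale)

print((key.float() - key_restored.float()).abs().max())  # small quantization error

With the CLI change, the cache dtype is selected through the new kv_cache_dtype option (typer's default naming would expose it as --kv-cache-dtype), and get_model maps the "fp8_e4m3fn" / "fp8_e5m2" strings onto the corresponding torch float8 dtypes before building the KV cache.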