diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 180f5933..1950ea98 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -140,7 +140,7 @@ class FlashLlamaAttention(torch.nn.Module):

         if self.kv_cache_dtype == "fp8":
             self.kv_scale = weights.get_kv_cache_scaling_factor(
-                prefix, self.kv_cache_dtype
+                prefix, self.kv_cache_dtype, config.kv_cache_torch_dtype
             )
         else:
             self.kv_scale = 1.0
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 5b1c051d..85d8ad10 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -64,6 +64,8 @@ class FlashLlama(FlashCausalLM):
         config.quantize = quantize
         config.speculator = speculator
         config.kv_cache_dtype = kv_cache_dtype
+        if not hasattr(config, "kv_cache_torch_dtype"):
+            config.kv_cache_torch_dtype = None

         torch.distributed.barrier(group=self.process_group)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 7db32538..68f89f64 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -767,7 +767,9 @@ class Weights:
         except Exception:
             pass

-    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str):
+    def get_kv_cache_scaling_factor(
+        self, prefix: str, kv_cache_dtype: str, kv_cache_torch_dtype: str
+    ):
         try:
             kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
         except RuntimeError:
@@ -791,13 +793,21 @@
                 "Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache"
             )

+        if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
+            raise RuntimeError(
+                f"Found `kv_scale` in the checkpoint; the config must specify the `kv_cache_torch_dtype` "
+                f"used to generate the kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'."
+            )
+
         # ROCm uses the FP8 format float8_e4m3fnuz, whereas Nvidia GPUs use float8_e4m3fn.
         # The multiplication by 2 compensates for the different numeric representation
         # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
         # After this adjustment, the overall effect is equivalent to the scaling applied without
         # it on Nvidia GPUs.
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
             kv_scale *= 2.0
+        elif SYSTEM == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
+            kv_scale /= 2.0

         return kv_scale
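For context on where the factor of 2 comes from: `float8_e4m3fn` (the OCP format used on Nvidia GPUs) has exponent bias 7, while `float8_e4m3fnuz` (used on AMD MI300/ROCm) has exponent bias 8 and no infinities, so the same bit pattern decodes to exactly half the value under `fnuz`. A minimal illustration, assuming a PyTorch build recent enough to ship both FP8 dtypes:

```python
import torch

# The two e4m3 flavors cover different dynamic ranges:
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0

# Reinterpreting the same byte under each flavor shows the 2x gap:
# fnuz's exponent bias is one larger (8 vs. 7), so every bit pattern
# decodes to half the fn value.
bits = torch.tensor([0x40], dtype=torch.uint8)
print(bits.view(torch.float8_e4m3fn).float())    # tensor([2.])
print(bits.view(torch.float8_e4m3fnuz).float())  # tensor([1.])
```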
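The new branching is easiest to audit as a pure function. The sketch below is a restatement for review purposes only (`adjust_kv_scale` is a hypothetical name, not part of this PR); it mirrors the logic at the end of `get_kv_cache_scaling_factor`:

```python
def adjust_kv_scale(kv_scale: float, system: str, kv_cache_torch_dtype: str) -> float:
    """Compensate when the FP8 flavor the scales were calibrated for
    differs from the platform's native flavor."""
    if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
        raise RuntimeError(f"unsupported kv_cache_torch_dtype: {kv_cache_torch_dtype!r}")
    if system == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
        return kv_scale * 2.0  # fn-calibrated scales running on fnuz hardware
    if system == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
        return kv_scale / 2.0  # fnuz-calibrated scales running on fn hardware
    return kv_scale  # calibration dtype matches the native flavor

# All four combinations:
assert adjust_kv_scale(1.0, "rocm", "float8_e4m3fn") == 2.0
assert adjust_kv_scale(1.0, "rocm", "float8_e4m3fnuz") == 1.0
assert adjust_kv_scale(1.0, "cuda", "float8_e4m3fn") == 1.0
assert adjust_kv_scale(1.0, "cuda", "float8_e4m3fnuz") == 0.5
```

Together with the `None` default added in `flash_llama.py`, a checkpoint that ships `kv_scale` without declaring `kv_cache_torch_dtype` now fails fast with the new `RuntimeError` instead of silently applying a possibly wrong platform adjustment.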
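End to end, the adjustment should make the dequantized KV values platform-independent. A small round-trip check, under the assumption that the cache kernel dequantizes per-tensor as `fp8_value * kv_scale` (the usual scheme; not verified against every kernel here):

```python
import torch

bits = torch.tensor([0x40], dtype=torch.uint8)  # an FP8 payload as cached on either platform
scale_fn = 0.5  # per-tensor scale calibrated against float8_e4m3fn

# CUDA decodes the bits as e4m3fn; ROCm decodes the same bits as e4m3fnuz
# (half the value) but applies the 2x-adjusted scale, so both recover 1.0.
cuda_value = bits.view(torch.float8_e4m3fn).float() * scale_fn          # 2.0 * 0.5
rocm_value = bits.view(torch.float8_e4m3fnuz).float() * (scale_fn * 2)  # 1.0 * 1.0
assert torch.equal(cuda_value, rocm_value)
```

The two formats only diverge at the extreme top of the range, since `fnuz` saturates at 240 rather than exactly half of `fn`'s 448.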