add fp8 kv cache for rocm

Author: Mohit Sharma, 2024-12-18 14:55:53 +00:00
parent 8f66d323d0
commit fa14d71ac8
4 changed files with 122 additions and 44 deletions

Changed file 1 of 4:

@@ -52,13 +52,22 @@ class KVCache:
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
+                ATTENTION == "paged" and SYSTEM == "rocm"
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD Rocm"
+                )
+        self.kv_cache_dtype_str = "auto"
+        if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
+            self.kv_cache_dtype_str = "fp8"
+            dtype = torch.uint8

         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
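For orientation, the constructor change above reduces to a small mapping from (system, attention backend, requested dtype) to the storage dtype plus the kv_cache_dtype_str string that is later handed to the paged-attention kernels. The sketch below is a hypothetical free function, not TGI's code, and it writes the support check the way the error message describes the intent (flashinfer on CUDA, or paged attention on ROCm):

import torch

def select_kv_cache_dtype(system: str, attention: str, dtype: torch.dtype):
    """Return (storage_dtype, kv_cache_dtype_str) for the KV cache (illustrative only)."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        supported = (attention == "flashinfer" and system == "cuda") or (
            attention == "paged" and system == "rocm"
        )
        if not supported:
            raise ValueError("FP8 KV cache needs flashinfer on CUDA or paged attention on ROCm")
        if system == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 KV cache is not supported on ROCm")
    if system == "rocm" and dtype == torch.float8_e4m3fn:
        # ROCm's paged kernels take the cache as raw bytes plus the "fp8" marker string,
        # so the tensors are allocated as uint8 and the dtype is recorded separately.
        return torch.uint8, "fp8"
    return dtype, "auto"

# e.g. select_kv_cache_dtype("rocm", "paged", torch.float8_e4m3fn) -> (torch.uint8, "fp8")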
@@ -120,6 +129,16 @@ class KVCache:
                 "Using FP8 KV cache scales",
             )
             return True
+        elif (
+            self.kv_cache_dtype_str == "fp8"
+            and ATTENTION == "paged"
+            and SYSTEM == "rocm"
+        ):
+            log_once(
+                logger.info,
+                "Using FP8 KV cache scales",
+            )
+            return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
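The scale-eligibility check extended above can be read as a plain decision function: scales are only honoured when they are non-trivial and the cache/backend combination can actually use them. This is a hedged sketch with hypothetical parameter names, not the method itself:

def can_apply_kv_scales(
    key_scale: float,
    value_scale: float,
    kv_cache_dtype_str: str,
    attention: str,
    system: str,
    cache_is_fp8_e4m3: bool,
) -> bool:
    """Illustrative version of the scale check; the real method also logs its decision."""
    if key_scale == 1.0 and value_scale == 1.0:
        return False  # neutral scales, nothing to apply
    if cache_is_fp8_e4m3 and attention == "flashinfer" and system == "cuda":
        return True   # existing CUDA/flashinfer path
    if kv_cache_dtype_str == "fp8" and attention == "paged" and system == "rocm":
        return True   # the branch added in this commit
    return False      # scales present, but this cache type cannot use them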
@@ -158,7 +177,7 @@ class KVCache:
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]

-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
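On CUDA the store path keeps quantizing in Python via fp8_quantize before writing to the cache; with the added `and SYSTEM == "cuda"` guard, ROCm instead defers scaling to the cache kernel below. As a rough reference for what a scaled fp8 quantization step does (not TGI's fp8_quantize, and assuming float8_e4m3fn with per-tensor scales):

import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def fp8_quantize_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Divide by the per-tensor scale, clamp to the representable range, cast to fp8."""
    x = x.float() / scale
    x = x.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return x.to(torch.float8_e4m3fn)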
@@ -188,7 +207,16 @@
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                self.kv_cache_dtype_str,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )


 def paged_reshape_and_cache(
@@ -197,7 +225,11 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -216,7 +248,7 @@ def paged_reshape_and_cache(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
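The semantics that paged_reshape_and_cache now has to cover on ROCm can be sketched in plain PyTorch. The reference below uses a simplified [num_slots, num_heads, head_size] cache layout rather than the real blocked layout, and assumes the fp8 cache stores float8_e4m3fn bytes in a uint8 tensor; it is meant to illustrate the new kv_cache_dtype / k_scale / v_scale arguments, not to mirror the kernel:

import torch

def reshape_and_cache_ref(
    key: torch.Tensor,          # [num_tokens, num_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_slots, num_heads, head_size], uint8 when fp8
    value_cache: torch.Tensor,  # same layout as key_cache
    slots: torch.Tensor,        # [num_tokens] flat slot indices
    kv_cache_dtype: str = "auto",
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> None:
    if kv_cache_dtype == "fp8":
        # Scale, cast to float8_e4m3fn, then reinterpret the bytes for the uint8 cache.
        key = (key.float() / k_scale).to(torch.float8_e4m3fn).view(torch.uint8)
        value = (value.float() / v_scale).to(torch.float8_e4m3fn).view(torch.uint8)
    key_cache[slots] = key
    value_cache[slots] = value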

Changed file 2 of 4:

@@ -119,9 +119,9 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache.kv_cache_dtype_str,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.
@@ -154,9 +154,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache.kv_cache_dtype_str,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(
@@ -174,9 +174,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache.kv_cache_dtype_str,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
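The key_scale_cpu / value_scale_cpu values forwarded to the paged-attention kernels are per-tensor dequantization scales that ship with FP8 checkpoints. One common way such scales are produced during calibration, shown here only as an assumption about how a checkpoint might define them, is amax over the observed tensor divided by the fp8 maximum, so that dequantized values are stored_fp8 * scale:

import torch

FP8_E4M3_MAX = 448.0

def calibrate_kv_scale(observed: torch.Tensor) -> float:
    """One plausible calibration rule: scale = amax / fp8_max (assumption, not TGI code)."""
    return (observed.abs().amax() / FP8_E4M3_MAX).item()

# Keys whose magnitudes reach ~32 would get scale ~0.071, mapping the observed
# range onto the float8_e4m3fn range.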

Changed file 3 of 4:

@@ -398,10 +398,16 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(
-                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            )
+            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
+            out = torch.empty(
+                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            )
+            ops.silu_and_mul(out, gate_up_states)
+            return self.down_proj(out, adapter_data)
+            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            # return self.down_proj(
+            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            # )


 class FlashLlamaLayer(nn.Module):
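The LlamaMLP change swaps the Python split-and-multiply for a fused silu_and_mul kernel writing into a pre-allocated out tensor of shape [..., intermediate_size]. Its semantics match the commented-out code; a plain-PyTorch reference, assuming the usual convention of gate in the first half of the last dimension and up in the second half, would be:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(out: torch.Tensor, gate_up: torch.Tensor) -> torch.Tensor:
    """SiLU(gate) * up, where gate_up = [gate | up] along the last dimension."""
    d = gate_up.shape[-1] // 2
    out.copy_(F.silu(gate_up[..., :d]) * gate_up[..., d:])
    return out

# gate_up = torch.randn(4, 2 * 11008); out = torch.empty(4, 11008)
# silu_and_mul_ref(out, gate_up)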

Changed file 4 of 4:

@@ -520,6 +520,46 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if (
+            torch.distributed.get_rank() == 0
+            and input_ids.shape[0] == 262144
+            and cu_seqlen_prefill is not None
+        ):
+            with torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+            ) as prof:
+                true_max_s = max_s
+                if prefill_cache_indices is not None:
+                    # Slots also need to be sliced as it has the same size as the whole kv tensor
+                    slots = slots[prefill_cache_indices]
+                elif self.max_past is not None:
+                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+                    # kernel requires the true values
+                    seqlen = seqlen.clamp(max=self.max_past_tensor)
+                hidden_states = self.model(
+                    input_ids,
+                    position_ids,
+                    cu_seqlen_prefill,
+                    kv_cache,
+                    block_tables,
+                    slots,
+                    seqlen,
+                    max_s,
+                    true_max_s,
+                    prefill_cache_indices,
+                )
+                if lm_head_indices is not None:
+                    hidden_states = hidden_states[lm_head_indices]
+                logits = self.lm_head(hidden_states)
+            prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
+        else:
-        true_max_s = max_s
-        if prefill_cache_indices is not None:
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            true_max_s = max_s
+            if prefill_cache_indices is not None:
+                # Slots also need to be sliced as it has the same size as the whole kv tensor
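The Mixtral forward gains what looks like a temporary profiling branch: on rank 0, for one specific prefill batch size, the forward pass is wrapped in torch.profiler and dumped as a Chrome trace. A minimal, self-contained version of that pattern, with a stand-in workload and an arbitrary output path, is:

import torch

def profile_once(fn, trace_path: str = "trace.json") -> None:
    """Run fn() under the PyTorch profiler and export a Chrome/Perfetto-compatible trace."""
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():  # also covers ROCm builds of PyTorch
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        fn()
    prof.export_chrome_trace(trace_path)

# profile_once(lambda: torch.nn.Linear(16, 16)(torch.randn(4, 16)))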