Add FP8 KV cache for ROCm

Mohit Sharma 2024-12-18 14:55:53 +00:00
parent 8f66d323d0
commit fa14d71ac8
4 changed files with 122 additions and 44 deletions

View File

@@ -52,13 +52,22 @@ class KVCache:
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (ATTENTION == "flashinfer" and SYSTEM == "cuda") and not (
+                ATTENTION == "paged" and SYSTEM == "rocm"
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm"
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
+
+        self.kv_cache_dtype_str = "auto"
+        if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
+            self.kv_cache_dtype_str = "fp8"
+            dtype = torch.uint8
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
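A minimal standalone sketch of the dtype-selection logic added above. `resolve_kv_cache_dtype` is a hypothetical helper written for illustration (the actual logic lives inline in the KVCache constructor); ATTENTION and SYSTEM are passed as plain arguments here.

# Sketch only: mirrors the constructor change; not a function from the repository.
import torch

def resolve_kv_cache_dtype(dtype: torch.dtype, attention: str, system: str):
    """Return (storage_dtype, kv_cache_dtype_str) for the KV cache."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        if not (attention == "flashinfer" and system == "cuda") and not (
            attention == "paged" and system == "rocm"
        ):
            raise ValueError("FP8 KV cache not supported for this backend")
        if system == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 is not supported on ROCm")

    kv_cache_dtype_str = "auto"
    if system == "rocm" and dtype == torch.float8_e4m3fn:
        # ROCm paged attention stores FP8 values in a uint8-backed cache and is
        # told via the "fp8" string to reinterpret them inside the kernel.
        kv_cache_dtype_str = "fp8"
        dtype = torch.uint8
    return dtype, kv_cache_dtype_str

# e.g. resolve_kv_cache_dtype(torch.float8_e4m3fn, "paged", "rocm") -> (torch.uint8, "fp8")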
@@ -120,6 +129,16 @@ class KVCache:
                 "Using FP8 KV cache scales",
             )
             return True
+        elif (
+            self.kv_cache_dtype_str == "fp8"
+            and ATTENTION == "paged"
+            and SYSTEM == "rocm"
+        ):
+            log_once(
+                logger.info,
+                "Using FP8 KV cache scales",
+            )
+            return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
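A sketch of the scale-eligibility decision that can_scale() implements after this change, written as a standalone predicate. The function name and the exact dtype check for the CUDA branch are assumptions for illustration; the ROCm branch mirrors the new elif above.

import torch

def fp8_scales_supported(
    cache_dtype: torch.dtype, kv_cache_dtype_str: str, attention: str, system: str
) -> bool:
    # CUDA + flashinfer: the cache itself is stored as float8_e4m3fn.
    if cache_dtype == torch.float8_e4m3fn and attention == "flashinfer" and system == "cuda":
        return True
    # ROCm + paged attention: the cache is uint8-backed but flagged as "fp8".
    if kv_cache_dtype_str == "fp8" and attention == "paged" and system == "rocm":
        return True
    return False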
@@ -158,7 +177,7 @@ class KVCache:
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -188,7 +207,16 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                self.kv_cache_dtype_str,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )
 
 
 def paged_reshape_and_cache(
@@ -197,7 +225,11 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -216,7 +248,7 @@ def paged_reshape_and_cache(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
        )
    elif SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex
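A rough, pure-PyTorch reference for what the fused reshape_and_cache path does when kv_cache_dtype is "fp8": quantize K/V per tensor with a scale and scatter them into the uint8-backed cache at the given slots. This is a simplified sketch (flat cache layout, float8_e4m3fn; ROCm kernels may use the fnuz variant and a blocked layout), not the vLLM kernel's actual code.

import torch

def reference_reshape_and_cache_fp8(key, value, key_cache, value_cache, slots, k_scale, v_scale):
    # key/value: [num_tokens, num_heads, head_size]
    # key_cache/value_cache: uint8 tensors treated here as flat [num_slots, num_heads, head_size]
    k_fp8 = (key.float() / k_scale).to(torch.float8_e4m3fn)
    v_fp8 = (value.float() / v_scale).to(torch.float8_e4m3fn)
    # Reinterpret the uint8 cache as FP8 and write through the view.
    key_cache.view(torch.float8_e4m3fn).view(-1, *key.shape[1:])[slots] = k_fp8
    value_cache.view(torch.float8_e4m3fn).view(-1, *value.shape[1:])[slots] = v_fp8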

View File

@@ -119,9 +119,9 @@ def paged_attention(
            block_size,
            max_s,
            None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache.kv_cache_dtype_str,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
        )
    else:
        # Run PagedAttention V2.
@@ -154,9 +154,9 @@ def paged_attention(
                block_size,
                max_s,
                None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache.kv_cache_dtype_str,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
            )
        else:
            ops.paged_attention_rocm(
@@ -174,9 +174,9 @@ def paged_attention(
                block_size,
                max_s,
                None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache.kv_cache_dtype_str,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                None,
                _PARTITION_SIZE,
            )
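Why the kernel needs these scales: FP8 e4m3 has a small dynamic range, so values are stored as q = x / scale and reconstructed inside the attention kernel as x ≈ q * scale. A toy round trip showing the convention (the scale value is an example; real scales come from the checkpoint or calibration):

import torch

x = torch.tensor([0.0123, -1.7, 3.25], dtype=torch.float32)
scale = 0.05                               # example per-tensor scale
q = (x / scale).to(torch.float8_e4m3fn)    # what goes into the FP8 KV cache
x_hat = q.to(torch.float32) * scale        # what paged attention reconstructs
print(x, x_hat)                            # x_hat matches x up to FP8 rounding error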

View File

@@ -398,10 +398,16 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(
-                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            )
+            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
+            out = torch.empty(
+                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            )
+            ops.silu_and_mul(out, gate_up_states)
+            return self.down_proj(out, adapter_data)
+            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            # return self.down_proj(
+            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            # )
 
 
 class FlashLlamaLayer(nn.Module):
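Assuming ops.silu_and_mul here is the fused vLLM kernel, it writes silu(gate) * up into `out`, where the last dimension of gate_up_states is the concatenation [gate | up]. A pure-PyTorch reference of the same computation, equivalent to the commented-out view-based code above:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(gate_up_states: torch.Tensor) -> torch.Tensor:
    # Split the fused projection output into its gate and up halves.
    d = gate_up_states.shape[-1] // 2
    gate, up = gate_up_states[..., :d], gate_up_states[..., d:]
    return F.silu(gate) * up

The fused kernel computes the activation and the elementwise product in one pass, avoiding the intermediate tensors and extra kernel launches of the unfused version.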

View File

@@ -520,28 +520,68 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        true_max_s = max_s
-        if prefill_cache_indices is not None:
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
-
-        hidden_states = self.model(
-            input_ids,
-            position_ids,
-            cu_seqlen_prefill,
-            kv_cache,
-            block_tables,
-            slots,
-            seqlen,
-            max_s,
-            true_max_s,
-            prefill_cache_indices,
-        )
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
+        if (
+            torch.distributed.get_rank() == 0
+            and input_ids.shape[0] == 262144
+            and cu_seqlen_prefill is not None
+        ):
+            with torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+            ) as prof:
+                true_max_s = max_s
+                if prefill_cache_indices is not None:
+                    # Slots also need to be sliced as it has the same size as the whole kv tensor
+                    slots = slots[prefill_cache_indices]
+                elif self.max_past is not None:
+                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+                    # kernel requires the true values
+                    seqlen = seqlen.clamp(max=self.max_past_tensor)
+
+                hidden_states = self.model(
+                    input_ids,
+                    position_ids,
+                    cu_seqlen_prefill,
+                    kv_cache,
+                    block_tables,
+                    slots,
+                    seqlen,
+                    max_s,
+                    true_max_s,
+                    prefill_cache_indices,
+                )
+                if lm_head_indices is not None:
+                    hidden_states = hidden_states[lm_head_indices]
+                logits = self.lm_head(hidden_states)
+            prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
+        else:
+            true_max_s = max_s
+            if prefill_cache_indices is not None:
+                # Slots also need to be sliced as it has the same size as the whole kv tensor
+                slots = slots[prefill_cache_indices]
+            elif self.max_past is not None:
+                # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+                # kernel requires the true values
+                seqlen = seqlen.clamp(max=self.max_past_tensor)
+
+            hidden_states = self.model(
+                input_ids,
+                position_ids,
+                cu_seqlen_prefill,
+                kv_cache,
+                block_tables,
+                slots,
+                seqlen,
+                max_s,
+                true_max_s,
+                prefill_cache_indices,
+            )
+            if lm_head_indices is not None:
+                hidden_states = hidden_states[lm_head_indices]
+            logits = self.lm_head(hidden_states)
         return logits
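The hunk above wraps one prefill forward pass (rank 0 only, at a hard-coded batch size) in torch.profiler and exports a Chrome trace. A minimal standalone sketch of that pattern; the helper name, callable, and trace path are placeholders for illustration, not TGI code:

import torch

def profile_once(fn, trace_path="/tmp/trace.json"):
    # Profile a single call and dump a trace viewable in chrome://tracing or Perfetto.
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
    ) as prof:
        result = fn()
    prof.export_chrome_trace(trace_path)
    return result

# e.g. logits = profile_once(lambda: model(input_ids))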