improvements

Mohit Sharma 2025-01-03 11:58:14 +00:00
parent fa14d71ac8
commit 43370a1f82
4 changed files with 41 additions and 97 deletions

@@ -53,20 +53,21 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
-            if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
-                ATTENTION == "paged" and SYSTEM == "rocm"
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
 
-        self.kv_cache_dtype_str = "auto"
+        self.kv_cache_dtype = "auto"
         if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-            self.kv_cache_dtype_str = "fp8"
+            self.kv_cache_dtype = "fp8"
             dtype = torch.uint8
 
         element_size = torch.tensor([], dtype=dtype).element_size()
@@ -123,27 +124,16 @@ class KVCache:
         if (
             self.dtype == torch.float8_e4m3fn
             and ATTENTION == "flashinfer"
             and SYSTEM == "cuda"
+        ) or (
+            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
-            return True
-        elif (
-            self.kv_cache_dtype_str == "fp8"
-            and ATTENTION == "paged"
-            and SYSTEM == "rocm"
-        ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for flashinfer on CUDA and paged attention on ROCm",
             )
             return False
@@ -213,7 +203,7 @@ class KVCache:
                 key_cache,
                 value_cache,
                 slots,
-                self.kv_cache_dtype_str,
+                self.kv_cache_dtype,
                 kv_scales.key_scale_cpu,
                 kv_scales.value_scale_cpu,
             )
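Not part of the commit: the three hunks above rename the attribute and tighten the FP8 support check, so a minimal standalone sketch of the logic they converge on may help. The function name and the module-level ATTENTION/SYSTEM stand-ins are illustrative assumptions, not TGI API.

import torch

# Hypothetical stand-ins for the globals referenced in the diff above.
ATTENTION = "paged"   # e.g. "flashinfer" or "paged"
SYSTEM = "rocm"       # e.g. "cuda" or "rocm"


def select_kv_cache_dtype(dtype: torch.dtype) -> tuple[str, torch.dtype]:
    """Return (kv_cache_dtype string passed to the kernels, storage dtype)."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        # FP8 KV cache only for flashinfer on CUDA or paged attention on ROCm.
        if not (
            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
            or (ATTENTION == "paged" and SYSTEM == "rocm")
        ):
            raise ValueError(
                "FP8 KV cache is currently only supported for flashinfer on CUDA "
                "and paged attention on ROCm."
            )
        if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 FP8 KV cache is not supported on AMD ROCm")

    kv_cache_dtype = "auto"
    if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
        # ROCm paged kernels store uint8 and receive an "fp8" dtype string.
        kv_cache_dtype = "fp8"
        dtype = torch.uint8
    return kv_cache_dtype, dtype


print(select_kv_cache_dtype(torch.float8_e4m3fn))  # -> ("fp8", torch.uint8)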

@@ -119,7 +119,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -154,7 +154,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -174,7 +174,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
             None,

@@ -398,16 +398,10 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
             )
-            ops.silu_and_mul(out, gate_up_states)
-            return self.down_proj(out, adapter_data)
-            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            # return self.down_proj(
-            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            # )
 
 
 class FlashLlamaLayer(nn.Module):
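Not part of the commit: the hunk above swaps the fused ops.silu_and_mul kernel for a plain view-based formulation. A small pure-PyTorch sanity sketch of why the two match, assuming the usual silu(gate) * up convention for the fused kernel; shapes and names here are illustrative only.

import torch
import torch.nn.functional as F

tokens, intermediate_size = 4, 8
gate_up_states = torch.randn(tokens, 2 * intermediate_size)

# Fused-kernel semantics (what silu_and_mul writes into a preallocated buffer):
# out = silu(x[..., :d]) * x[..., d:]
reference = (
    F.silu(gate_up_states[:, :intermediate_size])
    * gate_up_states[:, intermediate_size:]
)

# View-based form from the new code path: reshape to [tokens, 2, d] so index 0
# is the gate half and index 1 is the up half, then apply the activation.
gs = gate_up_states.view(-1, 2, intermediate_size)
out = F.silu(gs[:, 0]) * gs[:, 1]

torch.testing.assert_close(out, reference)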

@@ -520,46 +520,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if (
-            torch.distributed.get_rank() == 0
-            and input_ids.shape[0] == 262144
-            and cu_seqlen_prefill is not None
-        ):
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                record_shapes=True,
-            ) as prof:
-                true_max_s = max_s
-                if prefill_cache_indices is not None:
-                    # Slots also need to be sliced as it has the same size as the whole kv tensor
-                    slots = slots[prefill_cache_indices]
-                elif self.max_past is not None:
-                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                    # kernel requires the true values
-                    seqlen = seqlen.clamp(max=self.max_past_tensor)
-                hidden_states = self.model(
-                    input_ids,
-                    position_ids,
-                    cu_seqlen_prefill,
-                    kv_cache,
-                    block_tables,
-                    slots,
-                    seqlen,
-                    max_s,
-                    true_max_s,
-                    prefill_cache_indices,
-                )
-                if lm_head_indices is not None:
-                    hidden_states = hidden_states[lm_head_indices]
-                logits = self.lm_head(hidden_states)
-            prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
-        else:
-            true_max_s = max_s
-            if prefill_cache_indices is not None:
-                # Slots also need to be sliced as it has the same size as the whole kv tensor
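Not part of the commit: the block deleted above hard-wired a one-off torch.profiler run into the Mixtral forward pass. If a similar trace is ever needed again, the same profiler pattern can live outside the model code; this is only a sketch, and profile_prefill_once, model, inputs and the trace path are all hypothetical names.

import torch


def profile_prefill_once(model, inputs, trace_path="/tmp/trace_prefill.json"):
    """Run one forward pass under torch.profiler and dump a Chrome trace."""
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
    ) as prof:
        with torch.no_grad():
            out = model(**inputs)
    # The resulting JSON can be opened in chrome://tracing or Perfetto.
    prof.export_chrome_trace(trace_path)
    return out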