improvements

Mohit Sharma 2025-01-03 11:58:14 +00:00
parent fa14d71ac8
commit 43370a1f82
4 changed files with 41 additions and 97 deletions

View File

@@ -53,20 +53,21 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
-            if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
-                ATTENTION == "paged" and SYSTEM == "rocm"
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )

-        self.kv_cache_dtype_str = "auto"
+        self.kv_cache_dtype = "auto"
         if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-            self.kv_cache_dtype_str = "fp8"
+            self.kv_cache_dtype = "fp8"
             dtype = torch.uint8

         element_size = torch.tensor([], dtype=dtype).element_size()
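Note on this hunk: with the restructured guard, an FP8 KV cache is accepted only for flashinfer on CUDA or paged attention on ROCm, and on ROCm an e4m3 cache is stored as uint8 with the "fp8" tag, while other configurations keep the requested dtype and the "auto" tag. A standalone sketch of that resolution logic (illustrative only, not the TGI class itself; system and attention_impl stand in for the module-level SYSTEM and ATTENTION globals):

import torch

def resolve_kv_cache_dtype(dtype: torch.dtype, system: str, attention_impl: str):
    """Mirror the constructor logic above: return (storage dtype, kv_cache_dtype tag)."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        if not (
            (attention_impl == "flashinfer" and system == "cuda")
            or (attention_impl == "paged" and system == "rocm")
        ):
            raise ValueError("FP8 KV cache needs flashinfer on CUDA or paged attention on ROCm")
        if system == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 KV cache is not supported on ROCm")

    kv_cache_dtype = "auto"
    if system == "rocm" and dtype == torch.float8_e4m3fn:
        # ROCm paged attention takes a byte-typed cache plus the "fp8" tag.
        kv_cache_dtype = "fp8"
        dtype = torch.uint8
    return dtype, kv_cache_dtype

print(resolve_kv_cache_dtype(torch.float8_e4m3fn, "rocm", "paged"))       # (torch.uint8, 'fp8')
print(resolve_kv_cache_dtype(torch.float8_e4m3fn, "cuda", "flashinfer"))  # (torch.float8_e4m3fn, 'auto')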
@@ -123,27 +124,16 @@ class KVCache:
         elif (
             self.dtype == torch.float8_e4m3fn
             and ATTENTION == "flashinfer"
             and SYSTEM == "cuda"
+        ) or (
+            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
-            return True
-        elif (
-            self.kv_cache_dtype_str == "fp8"
-            and ATTENTION == "paged"
-            and SYSTEM == "rocm"
-        ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for flashinfer on CUDA and paged attention on ROCm",
             )
             return False
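For reference, the two branches that previously enabled scale application (flashinfer on CUDA with a float8_e4m3fn cache, paged attention on ROCm with the "fp8"-tagged cache) now collapse into a single predicate. A sketch of that predicate as a pure function (illustrative, not the method itself):

import torch

def fp8_scales_supported(cache_dtype, kv_cache_dtype, attention_impl, system):
    # The consolidated condition from the hunk above.
    return (
        cache_dtype == torch.float8_e4m3fn
        and attention_impl == "flashinfer"
        and system == "cuda"
    ) or (kv_cache_dtype == "fp8" and attention_impl == "paged" and system == "rocm")

# The two configurations that previously lived in separate elif branches:
assert fp8_scales_supported(torch.float8_e4m3fn, "auto", "flashinfer", "cuda")
assert fp8_scales_supported(torch.uint8, "fp8", "paged", "rocm")
# Everything else falls through to the "Ignoring FP8 KV cache scales" warning:
assert not fp8_scales_supported(torch.float8_e5m2, "auto", "flashinfer", "cuda")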
@@ -213,7 +203,7 @@ class KVCache:
             key_cache,
             value_cache,
             slots,
-            self.kv_cache_dtype_str,
+            self.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )

View File

@@ -119,7 +119,7 @@ def paged_attention(
         block_size,
         max_s,
         None,
-        kv_cache.kv_cache_dtype_str,
+        kv_cache.kv_cache_dtype,
         kv_scales.key_scale_cpu,
         kv_scales.value_scale_cpu,
     )
@@ -154,7 +154,7 @@ def paged_attention(
         block_size,
         max_s,
         None,
-        kv_cache.kv_cache_dtype_str,
+        kv_cache.kv_cache_dtype,
         kv_scales.key_scale_cpu,
         kv_scales.value_scale_cpu,
     )
@@ -174,7 +174,7 @@ def paged_attention(
         block_size,
         max_s,
         None,
-        kv_cache.kv_cache_dtype_str,
+        kv_cache.kv_cache_dtype,
         kv_scales.key_scale_cpu,
         kv_scales.value_scale_cpu,
         None,

View File

@@ -398,16 +398,10 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
-            )
-            ops.silu_and_mul(out, gate_up_states)
-            return self.down_proj(out, adapter_data)
-            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            # return self.down_proj(
-            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            # )
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


 class FlashLlamaLayer(nn.Module):
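The replacement drops the manually allocated output buffer and the ops.silu_and_mul call in favor of a view plus an elementwise product. Assuming self.act is SiLU (as it is for Llama's SwiGLU MLP) and the fused gate_up projection puts the gate in the first half of the last dimension and the up projection in the second half, the two forms are numerically equivalent; a quick pure-PyTorch check (standalone sketch with illustrative shapes):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
intermediate_size = 8
gate_up_states = torch.randn(4, 2 * intermediate_size)

# Reference: explicit halves, the computation the fused silu_and_mul kernel performs.
ref = F.silu(gate_up_states[:, :intermediate_size]) * gate_up_states[:, intermediate_size:]

# The view-based form used in the new code.
split = gate_up_states.view(-1, 2, intermediate_size)
out = F.silu(split[:, 0]) * split[:, 1]

torch.testing.assert_close(out, ref)
print("view-based SwiGLU matches the explicit-halves reference")

The trade-off is kernel fusion: ops.silu_and_mul computes the activation and product in one kernel, while the view-based form leaves that to regular PyTorch ops.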

View File

@@ -520,46 +520,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if (
-            torch.distributed.get_rank() == 0
-            and input_ids.shape[0] == 262144
-            and cu_seqlen_prefill is not None
-        ):
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                record_shapes=True,
-            ) as prof:
-                true_max_s = max_s
-                if prefill_cache_indices is not None:
-                    # Slots also need to be sliced as it has the same size as the whole kv tensor
-                    slots = slots[prefill_cache_indices]
-                elif self.max_past is not None:
-                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                    # kernel requires the true values
-                    seqlen = seqlen.clamp(max=self.max_past_tensor)
-                hidden_states = self.model(
-                    input_ids,
-                    position_ids,
-                    cu_seqlen_prefill,
-                    kv_cache,
-                    block_tables,
-                    slots,
-                    seqlen,
-                    max_s,
-                    true_max_s,
-                    prefill_cache_indices,
-                )
-                if lm_head_indices is not None:
-                    hidden_states = hidden_states[lm_head_indices]
-                logits = self.lm_head(hidden_states)
-                prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
-        else:
-            true_max_s = max_s
-            if prefill_cache_indices is not None:
-                # Slots also need to be sliced as it has the same size as the whole kv tensor
+        true_max_s = max_s
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as it has the same size as the whole kv tensor