Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00

improvements

This commit is contained in:
  parent fa14d71ac8
  commit 43370a1f82
@@ -53,20 +53,21 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
-            if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
-                ATTENTION == "paged" and SYSTEM == "rocm"
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
-                    "float8_e5m2 FP8 KV cache is not supported on AMD Rocm"
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )

-        self.kv_cache_dtype_str = "auto"
+        self.kv_cache_dtype = "auto"
         if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-            self.kv_cache_dtype_str = "fp8"
+            self.kv_cache_dtype = "fp8"
             dtype = torch.uint8

         element_size = torch.tensor([], dtype=dtype).element_size()
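For reference, the constructor logic in this hunk can be exercised on its own. The following is a minimal sketch, not the TGI implementation: `validate_fp8_kv_dtype` is an illustrative name, and the hard-coded ATTENTION/SYSTEM values stand in for the module-level constants used in the real code.

import torch

ATTENTION = "paged"   # stand-in for the module-level constant
SYSTEM = "rocm"       # stand-in for the module-level constant

def validate_fp8_kv_dtype(dtype: torch.dtype) -> tuple[str, torch.dtype]:
    """Return the (kv_cache_dtype tag, storage dtype) pair, mirroring the hunk above."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        if not (
            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
            or (ATTENTION == "paged" and SYSTEM == "rocm")
        ):
            raise ValueError(
                "FP8 KV cache is currently only supported for flashinfer on CUDA "
                "and paged attention on ROCm."
            )
        if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 FP8 KV cache is not supported on AMD ROCm")

    kv_cache_dtype = "auto"
    if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
        # On ROCm the FP8 cache is stored as raw uint8 and tagged "fp8" for the kernels.
        kv_cache_dtype = "fp8"
        dtype = torch.uint8
    return kv_cache_dtype, dtype

print(validate_fp8_kv_dtype(torch.float8_e4m3fn))  # ('fp8', torch.uint8)

In other words, float8_e4m3fn on ROCm paged attention is accepted, but stored as uint8 and tagged "fp8", which is the representation the later call sites forward to the kernels.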
@@ -123,27 +124,16 @@ class KVCache:
             self.dtype == torch.float8_e4m3fn
             and ATTENTION == "flashinfer"
             and SYSTEM == "cuda"
+        ) or (
+            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
-            return True
-        elif (
-            self.kv_cache_dtype_str == "fp8"
-            and ATTENTION == "paged"
-            and SYSTEM == "rocm"
-        ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for flashinfer on CUDA and paged attention on ROCm",
             )
             return False
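The two branches merged in this hunk reduce to a single predicate. A standalone sketch under that reading (the function name and parameters are illustrative, replacing the `self.*` attributes and module constants):

import torch

def should_use_fp8_scales(
    dtype: torch.dtype, kv_cache_dtype: str, attention: str, system: str
) -> bool:
    # Scales apply to float8_e4m3fn with flashinfer on CUDA, or to the
    # uint8-backed "fp8" cache with paged attention on ROCm.
    return (
        dtype == torch.float8_e4m3fn and attention == "flashinfer" and system == "cuda"
    ) or (kv_cache_dtype == "fp8" and attention == "paged" and system == "rocm")

assert should_use_fp8_scales(torch.float8_e4m3fn, "auto", "flashinfer", "cuda")
assert should_use_fp8_scales(torch.uint8, "fp8", "paged", "rocm")         # ROCm branch folded in
assert not should_use_fp8_scales(torch.float16, "auto", "paged", "cuda")  # falls through to the warning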
@@ -213,7 +203,7 @@ class KVCache:
             key_cache,
             value_cache,
             slots,
-            self.kv_cache_dtype_str,
+            self.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -119,7 +119,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -154,7 +154,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -174,7 +174,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
             None,
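The `store` path and the three `paged_attention` call sites above all read the renamed `kv_cache_dtype` attribute. A hypothetical stub illustrating the values forwarded to the kernels (`FakeKVCache`, `FakeKVScales`, and `kernel_tail_args` are stand-in names, not TGI APIs):

import torch
from dataclasses import dataclass

@dataclass
class FakeKVCache:                 # illustrative stand-in, not the TGI class
    kv_cache_dtype: str = "auto"   # "fp8" when the uint8-backed FP8 cache is active

@dataclass
class FakeKVScales:                # illustrative stand-in, not the TGI class
    key_scale_cpu: torch.Tensor
    value_scale_cpu: torch.Tensor

def kernel_tail_args(kv_cache: FakeKVCache, kv_scales: FakeKVScales) -> tuple:
    # Mirrors the trailing arguments forwarded in the call sites above.
    return kv_cache.kv_cache_dtype, kv_scales.key_scale_cpu, kv_scales.value_scale_cpu

print(kernel_tail_args(FakeKVCache("fp8"), FakeKVScales(torch.tensor(1.0), torch.tensor(1.0))))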
@@ -398,16 +398,10 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
             )
-            ops.silu_and_mul(out, gate_up_states)
-            return self.down_proj(out, adapter_data)
-            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            # return self.down_proj(
-            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            # )


 class FlashLlamaLayer(nn.Module):
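This hunk drops the fused `ops.silu_and_mul` kernel in favor of computing the gated activation in plain PyTorch. Assuming the activation is SiLU (as in Llama's gated MLP), the two formulations agree, which the following small check demonstrates with random data:

import torch
import torch.nn.functional as F

intermediate_size = 8
gate_up_states = torch.randn(4, 2 * intermediate_size)  # [gate | up] packed on the last dim

# Restored path: view as (-1, 2, intermediate_size), then act(gate) * up.
packed = gate_up_states.view(-1, 2, intermediate_size)
plain = F.silu(packed[:, 0]) * packed[:, 1]

# Same result when splitting the halves explicitly (the computation silu_and_mul fuses).
gate, up = gate_up_states.chunk(2, dim=-1)
assert torch.allclose(plain, F.silu(gate) * up)

The plain-PyTorch form also avoids allocating the intermediate `out` buffer by hand, since the multiply produces it directly.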
@@ -520,46 +520,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        if (
-            torch.distributed.get_rank() == 0
-            and input_ids.shape[0] == 262144
-            and cu_seqlen_prefill is not None
-        ):
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                record_shapes=True,
-            ) as prof:
-                true_max_s = max_s
-                if prefill_cache_indices is not None:
-                    # Slots also need to be sliced as it has the same size as the whole kv tensor
-                    slots = slots[prefill_cache_indices]
-                elif self.max_past is not None:
-                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                    # kernel requires the true values
-                    seqlen = seqlen.clamp(max=self.max_past_tensor)
-
-                hidden_states = self.model(
-                    input_ids,
-                    position_ids,
-                    cu_seqlen_prefill,
-                    kv_cache,
-                    block_tables,
-                    slots,
-                    seqlen,
-                    max_s,
-                    true_max_s,
-                    prefill_cache_indices,
-                )
-                if lm_head_indices is not None:
-                    hidden_states = hidden_states[lm_head_indices]
-                logits = self.lm_head(hidden_states)
-
-            prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
-        else:
-            true_max_s = max_s
-            if prefill_cache_indices is not None:
-                # Slots also need to be sliced as it has the same size as the whole kv tensor
+        true_max_s = max_s
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
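The block deleted in this hunk wrapped one forward pass in the standard `torch.profiler` context manager and dumped a Chrome trace. For reference, a minimal standalone sketch of that pattern (the workload and output path are placeholders, not values from the original code):

import torch
from torch.profiler import ProfilerActivity, profile

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    x = torch.randn(1024, 1024)
    y = x @ x  # placeholder workload standing in for the model forward pass

# Writes a Chrome-trace JSON viewable in chrome://tracing or Perfetto.
prof.export_chrome_trace("/tmp/trace_example.json")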