mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
* wip
* rollback
* refactor to use prefix/postfix namming + fix all_input_ids_tensor
* maybe patching vlms?
* fix filter and concat
* wip, no filter, no concat
* current
* add prepare_for_prefill
* working
* load tested
* re-create slots
* re-create slots
* fix slot_filtering_indices
* feedback loop
* remove log
* fix benchmarker
* fix vlm and seq2seq
* rename to cache and input lengths
* fix prefill logprobs
* fix launcher
* fix logprobs?
* idk at this point
* max input length
* omfg
* remove debugging lines
* fix tests
* fix mllama
* fix cargo tests
* remove support chunking for paged
* Fixing non blocked attentions
* Fixing dtype + AMD, Ipex targets.
* lint fix.
* rename
* Fix prefix_caching variable, remove defaults in server (confusing a lot of the times).
* Add simple resolution when user specifies ATTENTION=paged.
* Put back non default simple tests.
* Fix env name

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
93 lines
2.2 KiB
Python
from typing import Optional

import intel_extension_for_pytorch as ipex
import torch

from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False

def attention(
    q: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale,
    window_size_left=-1,
    causal=True,
    softcap: Optional[float] = None,
):
    out = torch.empty_like(q)

    # We do not need to check window_size_left here (windowing is not supported);
    # it is already validated ahead of time at model load.
    ipex.llm.functional.varlen_attention(
        q.contiguous() if q.device.type == "xpu" else q,
        key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
        value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
        out,
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_q,
        seqlen.max_q,
        seqlen.max_q,
        0.0,  # dropout probability
        softmax_scale,
        False,  # zero_tensors
        causal,  # is_causal
        False,  # return_softmax
        None,  # generator
    )

    return out

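# Note on the Seqlen fields used above (assuming the usual flash-attn varlen
# convention): cu_seqlen_q holds cumulative sequence lengths, so the query rows
# of sequence i live in q[cu_seqlen_q[i] : cu_seqlen_q[i + 1]], and max_q is the
# longest sequence in the batch. For example, for prompt lengths [3, 5, 2]:
#   cu_seqlen_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
#   max_q = 5
# cu_seqlen_q / max_q are passed for both the query and key sides above, i.e.
# queries and keys are assumed to have identical ragged lengths during prefill.
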
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )

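# Illustration only (not used by the model code): a pure-torch sketch of what
# reshape_and_cache does conceptually, assuming a flattened cache view of shape
# [num_slots, num_kv_heads, head_dim] addressed by flat slot indices. The real
# ipex kernel writes into its own block-wise cache layout.
def _reference_reshape_and_cache(
    key: torch.Tensor,          # [num_tokens, num_kv_heads, head_dim]
    value: torch.Tensor,        # [num_tokens, num_kv_heads, head_dim]
    key_cache: torch.Tensor,    # [num_slots, num_kv_heads, head_dim] (assumed layout)
    value_cache: torch.Tensor,  # [num_slots, num_kv_heads, head_dim] (assumed layout)
    slots: torch.Tensor,        # [num_tokens] flat slot index assigned to each new token
):
    # Scatter each incoming key/value row into its assigned cache slot.
    key_cache[slots] = key
    value_cache[slots] = value
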
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    out = torch.empty_like(query)
    # Attended length per sequence = newly added tokens + tokens already in the cache.
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
        max_s,
        None,  # alibi_slopes
    )
    return out

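# Illustration only (not used by the model code): a pure-torch sketch of the
# decode-time computation behind paged_attention, assuming a simplified cache
# layout of [num_blocks, BLOCK_SIZE, num_heads, head_dim] and ignoring grouped
# KV heads (kv_head_mapping). The real ipex kernel uses its own cache layout
# and handles GQA/MQA via kv_head_mapping.
def _reference_paged_attention(
    query: torch.Tensor,          # [num_seqs, num_heads, head_dim], one token per sequence
    key_cache: torch.Tensor,      # [num_blocks, BLOCK_SIZE, num_heads, head_dim] (assumed layout)
    value_cache: torch.Tensor,    # [num_blocks, BLOCK_SIZE, num_heads, head_dim] (assumed layout)
    block_tables: torch.Tensor,   # [num_seqs, max_blocks_per_seq] block ids per sequence
    input_lengths: torch.Tensor,  # [num_seqs] number of cached tokens per sequence
    softmax_scale: float,
    block_size: int = BLOCK_SIZE,
):
    outputs = []
    for i in range(query.shape[0]):
        length = int(input_lengths[i])
        num_blocks = (length + block_size - 1) // block_size
        blocks = block_tables[i, :num_blocks]
        # Gather this sequence's cached keys/values and drop padding in the last block.
        k = key_cache[blocks].reshape(-1, *key_cache.shape[2:])[:length]
        v = value_cache[blocks].reshape(-1, *value_cache.shape[2:])[:length]
        scores = torch.einsum("hd,lhd->hl", query[i], k) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("hl,lhd->hd", probs, v))
    return torch.stack(outputs)
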
__all__ = [
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]