runnable version

This commit is contained in:
System administrator 2024-12-12 14:27:07 +00:00
parent 3a636ed165
commit 649cb1f5f1
2 changed files with 74 additions and 15 deletions

View File

@ -8,7 +8,7 @@ from text_generation_server.utils.log import log_master
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in {
"1",
"true",
}

View File

@ -5,12 +5,17 @@ from typing import Optional, Tuple, Dict, Any
import torch
from opentelemetry import trace
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import (
empty_cache,
synchronize,
@ -57,7 +62,7 @@ def _flash_attention_forward_patched(
query_length: int,
is_causal: bool,
softmax_scale: Optional[float] = None,
sliding_window: int = -1,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
):
@ -67,11 +72,11 @@ def _flash_attention_forward_patched(
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
# Correctly reshape the states
_, _, num_heads, head_dim = query_states.size()
_, _, num_kv_heads, _ = key_states.size()
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
_, num_heads, head_dim = query_states.size()
# _, num_kv_heads, _ = key_states.size()
# query_states = query_states.view(-1, num_heads, head_dim)
# key_states = key_states.view(-1, num_kv_heads, head_dim)
# value_states = value_states.view(-1, num_kv_heads, head_dim)
# Take care of updating the cache in-place
kv_cache.store(
@ -82,6 +87,7 @@ def _flash_attention_forward_patched(
)
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
sliding_window = -1 if sliding_window is None else sliding_window
if kwargs["cu_seqlen_prefill"] is not None:
attn_output = attention(
@ -109,7 +115,8 @@ def _flash_attention_forward_patched(
softcap=softcap,
)
attn_output = attn_output.view(attn_output.shape[0], -1)
# attn_output = attn_output.view(attn_output.shape[0], -1)
attn_output = attn_output.view(-1, num_heads * head_dim)
return attn_output
@ -122,14 +129,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
kv_cache_dtype: Optional[torch.dtype] = None,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
device_count = 0
if torch.cuda.is_available():
device = torch.device("cuda")
device = torch.device("cuda:0")
device_count = torch.cuda.device_count()
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
@ -157,6 +171,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
device_map=("auto" if device_count > 1 else None),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
attn_implementation="flash_attention_2"
)
if device_count == 1 and quantize != "bitsandbytes":
model = model.to(device)
@ -174,10 +189,44 @@ class TransformersFlashCausalLM(FlashCausalLM):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.num_layers = len(model.model.layers)
self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads // self.process_group.size()
self.num_kv_heads = model.config.num_key_value_heads
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)
self.head_size = model.config.hidden_size // model.config.num_attention_heads
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_prefill_state,
create_decode_state,
create_prefill_with_paged_kv_state,
)
self.prefill_state = create_prefill_state(device=device)
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
device=device
)
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.num_groups = self.num_heads // self.num_kv_heads
self.kv_head_mapping = torch.arange(
0, self.num_kv_heads, dtype=torch.int32, device=device
).repeat_interleave(self.num_groups)
torch.distributed.barrier(group=self.process_group)
# Skip FlashCausalLM init.
super(FlashCausalLM, self).__init__(
model_id=model_id,
@ -186,6 +235,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
@ -198,11 +249,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
return cls(
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def warmup(self, batch: FlashCausalLMBatch):
def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
super().warmup(batch)
return super().warmup(batch, max_input_tokens, max_total_tokens)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
@ -270,7 +328,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
kv_head_mapping=self.kv_head_mapping,
).logits
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, None