runnable version

System administrator committed 2024-12-12 14:27:07 +00:00
parent 3a636ed165
commit 649cb1f5f1
2 changed files with 74 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ from text_generation_server.utils.log import log_master
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
-PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
+PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in {
     "1",
     "true",
 }
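For reference, a minimal standalone sketch of the flag-parsing pattern in this hunk. Variable names are taken from the diff; `os.environ.get(..., "0")` is used here instead of the mandatory `os.environ[...]` lookup so the snippet runs even when the launcher has not exported the variable.

```python
import os

# REQUEST_LOGPROBS falls back to "0" when unset; the prefix-caching flag is
# read the same way here only so this sketch runs without the launcher.
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
PREFIX_CACHING = os.environ.get("USE_PREFIX_CACHING", "0").lower() in {
    "1",
    "true",
}

print(f"request logprobs: {REQUEST_LOGPROBS}, prefix caching: {PREFIX_CACHING}")
```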

View File

@@ -5,12 +5,17 @@ from typing import Optional, Tuple, Dict, Any
 import torch
 from opentelemetry import trace
 from loguru import logger
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
 )
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
 from text_generation_server.utils.import_utils import (
     empty_cache,
     synchronize,
@@ -57,7 +62,7 @@ def _flash_attention_forward_patched(
     query_length: int,
     is_causal: bool,
     softmax_scale: Optional[float] = None,
-    sliding_window: int = -1,
+    sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     **kwargs,
 ):
@@ -67,11 +72,11 @@ def _flash_attention_forward_patched(
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
     # Correctly reshape the states
-    _, _, num_heads, head_dim = query_states.size()
-    _, _, num_kv_heads, _ = key_states.size()
-    query_states = query_states.view(-1, num_heads, head_dim)
-    key_states = key_states.view(-1, num_kv_heads, head_dim)
-    value_states = value_states.view(-1, num_kv_heads, head_dim)
+    _, num_heads, head_dim = query_states.size()
+    # _, num_kv_heads, _ = key_states.size()
+    # query_states = query_states.view(-1, num_heads, head_dim)
+    # key_states = key_states.view(-1, num_kv_heads, head_dim)
+    # value_states = value_states.view(-1, num_kv_heads, head_dim)
 
     # Take care of updating the cache in-place
     kv_cache.store(
@@ -82,6 +87,7 @@ def _flash_attention_forward_patched(
     )
 
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
+    sliding_window = -1 if sliding_window is None else sliding_window
 
     if kwargs["cu_seqlen_prefill"] is not None:
         attn_output = attention(
@@ -109,7 +115,8 @@ def _flash_attention_forward_patched(
             softcap=softcap,
         )
 
-    attn_output = attn_output.view(attn_output.shape[0], -1)
+    # attn_output = attn_output.view(attn_output.shape[0], -1)
+    attn_output = attn_output.view(-1, num_heads * head_dim)
 
     return attn_output
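The reshape changes in the hunks above assume the hidden states arrive already flattened to `(total_tokens, num_heads, head_dim)` rather than `(batch, seq_len, num_heads, head_dim)`. A self-contained sketch of that layout and of the final `view(-1, num_heads * head_dim)` flattening; the sizes are made up and a random tensor stands in for the real `attention(...)` output.

```python
import torch

# Hypothetical sizes standing in for one ragged batch of sequences.
total_tokens, num_heads, head_dim = 7, 8, 64

# Query states already flattened over batch/sequence dimensions,
# i.e. (total_tokens, num_heads, head_dim) as assumed by the patched forward.
query_states = torch.randn(total_tokens, num_heads, head_dim)
_, num_heads, head_dim = query_states.size()

# Stand-in for the attention(...) call; same shape as its real output.
attn_output = torch.randn(total_tokens, num_heads, head_dim)

# Merge the head dimensions back into a single hidden dimension,
# matching `attn_output.view(-1, num_heads * head_dim)` in the diff.
attn_output = attn_output.view(-1, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([7, 512])
```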
@@ -122,14 +129,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
         trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
+        config_class=AutoConfig,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
+        self.quantize = quantize
+        self.process_group, rank, world_size = initialize_torch_distributed()
+
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
         device_count = 0
         if torch.cuda.is_available():
-            device = torch.device("cuda")
+            device = torch.device("cuda:0")
             device_count = torch.cuda.device_count()
             dtype = torch.float16 if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
@@ -157,6 +171,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
+            attn_implementation="flash_attention_2"
         )
         if device_count == 1 and quantize != "bitsandbytes":
             model = model.to(device)
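`attn_implementation="flash_attention_2"` is a standard `from_pretrained` argument in recent `transformers` releases; it requires the `flash-attn` package and a CUDA device with fp16/bf16 weights. A hedged sketch of the loading call in isolation, with a hypothetical model id standing in for the one actually served:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical model id; the real one comes from the launcher's --model-id.
model_id = "meta-llama/Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # FlashAttention-2 needs fp16/bf16
    device_map=None,                          # single device, moved afterwards
    attn_implementation="flash_attention_2",
)
model = model.to(torch.device("cuda:0"))
```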
@@ -174,10 +189,44 @@ class TransformersFlashCausalLM(FlashCausalLM):
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        self.num_layers = len(model.model.layers)
+        self.num_layers = model.config.num_hidden_layers
+        self.num_heads = model.config.num_attention_heads // self.process_group.size()
         self.num_kv_heads = model.config.num_key_value_heads
+        self.num_kv_heads = (
+            self.num_kv_heads // self.process_group.size()
+            if self.num_kv_heads > 1
+            else self.num_kv_heads
+        )
         self.head_size = model.config.hidden_size // model.config.num_attention_heads
 
+        self.cuda_graphs = {}
+        self.kv_cache = []
+        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+
+        if ATTENTION == "flashinfer":
+            from text_generation_server.layers.attention.flashinfer import (
+                create_prefill_state,
+                create_decode_state,
+                create_prefill_with_paged_kv_state,
+            )
+
+            self.prefill_state = create_prefill_state(device=device)
+            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+                device=device
+            )
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+
+        self.num_groups = self.num_heads // self.num_kv_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_kv_heads, dtype=torch.int32, device=device
+        ).repeat_interleave(self.num_groups)
+
+        torch.distributed.barrier(group=self.process_group)
+
         # Skip FlashCausalLM init.
         super(FlashCausalLM, self).__init__(
             model_id=model_id,
@@ -186,6 +235,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @classmethod
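The `kv_head_mapping` built in the `__init__` hunk above tells the paged-attention kernels which KV head each query head reads from under grouped-query attention. A standalone sketch with made-up head counts (32 query heads, 8 KV heads on a single rank):

```python
import torch

# Hypothetical head counts for a grouped-query-attention model on one rank.
num_heads, num_kv_heads = 32, 8

# Each KV head is shared by num_groups consecutive query heads.
num_groups = num_heads // num_kv_heads  # 4

# Query head i attends against KV head kv_head_mapping[i].
kv_head_mapping = torch.arange(
    0, num_kv_heads, dtype=torch.int32
).repeat_interleave(num_groups)

print(kv_head_mapping)
# tensor([0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7], dtype=torch.int32)
```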
@@ -198,11 +249,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
+        return cls(
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
 
-    def warmup(self, batch: FlashCausalLMBatch):
+    def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
         patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
-        super().warmup(batch)
+        return super().warmup(batch, max_input_tokens, max_total_tokens)
 
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
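`patch_everywhere` is imported elsewhere in this file and is not part of this diff. The snippet below is a hypothetical sketch of that kind of helper, included only to illustrate how `_flash_attention_forward_patched` can take effect in every module that imported the original function by name:

```python
import sys

def patch_everywhere(attribute_name: str, patch) -> None:
    """Hypothetical sketch: rebind `attribute_name` to `patch` in every loaded
    module that exposes it, so `from some_module import fn` call sites also
    pick up the patched function."""
    for module in list(sys.modules.values()):
        if module is not None and hasattr(module, attribute_name):
            setattr(module, attribute_name, patch)
```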
@@ -270,7 +328,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
-        )
+            kv_head_mapping=self.kv_head_mapping,
+        ).logits
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
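The trailing `.logits` in the last hunk is needed because a Hugging Face causal LM returns a model-output object rather than a raw tensor. A small standalone illustration; the tiny model id is a hypothetical stand-in for the model actually served.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical tiny model, used only to show the shape of the returned object.
model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)   # a ModelOutput object, not a tensor

logits = output.logits         # (batch, seq_len, vocab_size) tensor
print(logits.shape)
```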