mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
runnable version
This commit is contained in:
parent
3a636ed165
commit
649cb1f5f1
@ -8,7 +8,7 @@ from text_generation_server.utils.log import log_master
|
|||||||
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
|
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
|
||||||
ATTENTION = os.environ["ATTENTION"]
|
ATTENTION = os.environ["ATTENTION"]
|
||||||
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||||
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
|
PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in {
|
||||||
"1",
|
"1",
|
||||||
"true",
|
"true",
|
||||||
}
|
}
|
||||||
|
@ -5,12 +5,17 @@ from typing import Optional, Tuple, Dict, Any
|
|||||||
import torch
|
import torch
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||||
|
|
||||||
from text_generation_server.models.flash_causal_lm import (
|
from text_generation_server.models.flash_causal_lm import (
|
||||||
FlashCausalLMBatch,
|
FlashCausalLMBatch,
|
||||||
FlashCausalLM,
|
FlashCausalLM,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
from text_generation_server.utils.import_utils import (
|
from text_generation_server.utils.import_utils import (
|
||||||
empty_cache,
|
empty_cache,
|
||||||
synchronize,
|
synchronize,
|
||||||
@ -57,7 +62,7 @@ def _flash_attention_forward_patched(
|
|||||||
query_length: int,
|
query_length: int,
|
||||||
is_causal: bool,
|
is_causal: bool,
|
||||||
softmax_scale: Optional[float] = None,
|
softmax_scale: Optional[float] = None,
|
||||||
sliding_window: int = -1,
|
sliding_window: Optional[int] = None,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -67,11 +72,11 @@ def _flash_attention_forward_patched(
|
|||||||
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
||||||
|
|
||||||
# Correctly reshape the states
|
# Correctly reshape the states
|
||||||
_, _, num_heads, head_dim = query_states.size()
|
_, num_heads, head_dim = query_states.size()
|
||||||
_, _, num_kv_heads, _ = key_states.size()
|
# _, num_kv_heads, _ = key_states.size()
|
||||||
query_states = query_states.view(-1, num_heads, head_dim)
|
# query_states = query_states.view(-1, num_heads, head_dim)
|
||||||
key_states = key_states.view(-1, num_kv_heads, head_dim)
|
# key_states = key_states.view(-1, num_kv_heads, head_dim)
|
||||||
value_states = value_states.view(-1, num_kv_heads, head_dim)
|
# value_states = value_states.view(-1, num_kv_heads, head_dim)
|
||||||
|
|
||||||
# Take care of updating the cache in-place
|
# Take care of updating the cache in-place
|
||||||
kv_cache.store(
|
kv_cache.store(
|
||||||
@ -82,6 +87,7 @@ def _flash_attention_forward_patched(
|
|||||||
)
|
)
|
||||||
|
|
||||||
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
|
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
|
||||||
|
sliding_window = -1 if sliding_window is None else sliding_window
|
||||||
|
|
||||||
if kwargs["cu_seqlen_prefill"] is not None:
|
if kwargs["cu_seqlen_prefill"] is not None:
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
@ -109,7 +115,8 @@ def _flash_attention_forward_patched(
|
|||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.view(attn_output.shape[0], -1)
|
# attn_output = attn_output.view(attn_output.shape[0], -1)
|
||||||
|
attn_output = attn_output.view(-1, num_heads * head_dim)
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
@ -122,14 +129,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
speculator: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
default_dtype=torch.float16,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
|
tokenizer_class=AutoTokenizer,
|
||||||
|
config_class=AutoConfig,
|
||||||
|
kv_cache_dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
|
self.quantize = quantize
|
||||||
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
|
||||||
if speculator:
|
if speculator:
|
||||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
device_count = 0
|
device_count = 0
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda:0")
|
||||||
device_count = torch.cuda.device_count()
|
device_count = torch.cuda.device_count()
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
@ -157,6 +171,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
device_map=("auto" if device_count > 1 else None),
|
device_map=("auto" if device_count > 1 else None),
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
attn_implementation="flash_attention_2"
|
||||||
)
|
)
|
||||||
if device_count == 1 and quantize != "bitsandbytes":
|
if device_count == 1 and quantize != "bitsandbytes":
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
@ -174,10 +189,44 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
|
|
||||||
self.num_layers = len(model.model.layers)
|
self.num_layers = model.config.num_hidden_layers
|
||||||
|
self.num_heads = model.config.num_attention_heads // self.process_group.size()
|
||||||
self.num_kv_heads = model.config.num_key_value_heads
|
self.num_kv_heads = model.config.num_key_value_heads
|
||||||
|
self.num_kv_heads = (
|
||||||
|
self.num_kv_heads // self.process_group.size()
|
||||||
|
if self.num_kv_heads > 1
|
||||||
|
else self.num_kv_heads
|
||||||
|
)
|
||||||
self.head_size = model.config.hidden_size // model.config.num_attention_heads
|
self.head_size = model.config.hidden_size // model.config.num_attention_heads
|
||||||
|
|
||||||
|
self.cuda_graphs = {}
|
||||||
|
self.kv_cache = []
|
||||||
|
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||||
|
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
|
create_prefill_state,
|
||||||
|
create_decode_state,
|
||||||
|
create_prefill_with_paged_kv_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.prefill_state = create_prefill_state(device=device)
|
||||||
|
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decode_state = create_decode_state(
|
||||||
|
device=device,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_groups = self.num_heads // self.num_kv_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_kv_heads, dtype=torch.int32, device=device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
# Skip FlashCausalLM init.
|
# Skip FlashCausalLM init.
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
@ -186,6 +235,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -198,11 +249,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
|
return cls(
|
||||||
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
def warmup(self, batch: FlashCausalLMBatch):
|
def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
|
||||||
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
|
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
|
||||||
super().warmup(batch)
|
return super().warmup(batch, max_input_tokens, max_total_tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||||
@ -270,7 +328,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
prefill_cache_indices=batch.prefill_cache_indices,
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
)
|
kv_head_mapping=self.kv_head_mapping,
|
||||||
|
).logits
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
return logits, None
|
return logits, None
|
||||||
|
Loading…
Reference in New Issue
Block a user