adjust warmup and enable vlm
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent f95aa42660
commit 36b6612f97
@@ -23,7 +23,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution

-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -172,7 +172,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -279,8 +279,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
@@ -31,6 +31,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.attention import (
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
@@ -678,23 +679,23 @@ class MllamaTextCrossAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""
         # hidden_states = hidden_states.unsqueeze(0)
         # bsz, q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        query_states = query_states.view(-1, self.num_heads, self.head_size)
-        query_states = self.q_norm(query_states)
-
         (
             cross_attention_states,
             cu_seqlen_q,
             cu_seqlen_k,
-            max_q,
-            max_k,
             indices,
         ) = cross_attention_states
+        bs = cu_seqlen_q.size(0) - 1
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
+        query_states = self.q_norm(query_states)

         key_states = self.k_proj(cross_attention_states)
         value_states = self.v_proj(cross_attention_states)
-        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
-        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+        key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
+        value_states = value_states.view(
+            bs, -1, self.num_key_value_heads, self.head_size
+        )
         key_states = self.k_norm(key_states)

         # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
@@ -705,9 +706,9 @@ class MllamaTextCrossAttention(nn.Module):
         # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
         # execute sdpa
-        query_states = query_states.unsqueeze(0).transpose(1, 2)
-        key_states = key_states.unsqueeze(0).transpose(1, 2)
-        value_states = value_states.unsqueeze(0).transpose(1, 2)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
         attn_output = fsdpa_op(
             query_states,
@@ -803,9 +804,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         block_tables,
         slots,
         seqlen,
-        max_s,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
+        prefill_cache_indices,
+        hpu_attention_meta,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
             return hidden_states, residual
@@ -912,7 +914,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
@@ -949,8 +951,6 @@ class FlashMllamaForConditionalGeneration(nn.Module):
                 )
                 * seqlen_k
             )
-            max_q = cu_seqlen_q[-1].item()
-            max_k = seqlen_k
         else:
             cu_seqlen_q = torch.arange(
                 seqlen_q + 1, device=device, dtype=torch.int32
@@ -965,16 +965,12 @@ class FlashMllamaForConditionalGeneration(nn.Module):
                 )
                 * seqlen_k
             )
-            max_q = seqlen_q
-            max_k = seqlen_k
             indices = image_indices[:]

             cross_attention_states = (
                 cross_attention_states,
                 cu_seqlen_q,
                 cu_seqlen_k,
-                max_q,
-                max_k,
                 indices,
             )

@@ -986,7 +982,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
@@ -19,7 +19,7 @@ from torch import nn
 from typing import Optional, List, Tuple

 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -72,7 +72,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -85,7 +85,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:
-            max_s += 1
             position_ids += 1

         if pixel_values is not None:
@@ -110,7 +109,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
         )

         if lm_head_indices is not None:
@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
 )
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

 from text_generation_server.layers import (
@@ -742,7 +742,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -829,8 +829,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
@@ -24,7 +24,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
 )
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

 from text_generation_server.layers import (
@@ -485,7 +485,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -573,8 +573,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
@@ -40,6 +40,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.attention import (
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
@@ -906,7 +907,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
@@ -937,8 +938,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=prefill_cache_indices,
         )
         if lm_head_indices is not None:
@@ -39,6 +39,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.attention import (
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
@@ -482,7 +483,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
@@ -512,8 +513,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=prefill_cache_indices,
         )
         if lm_head_indices is not None:
@@ -61,7 +61,6 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.quantization import get_loader
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
-
 from text_generation_server.utils.import_utils import (
     empty_cache,
     synchronize,
@@ -77,10 +76,6 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None


-def small_power_of_2(n: int):
-    return 1 << ((n - 1).bit_length() - 1)
-
-
 def set_sliding_window(sliding_window: int):
     global SLIDING_WINDOW
     SLIDING_WINDOW = sliding_window
@@ -91,40 +86,6 @@ def get_sliding_windows() -> int:
     return SLIDING_WINDOW


-def init_cpu_threads_env(rank_id: int, world_size: int):
-    import importlib.util
-
-    if importlib.util.find_spec("numa") is not None:
-        import numa
-        import psutil
-
-        nodes = numa.info.get_max_node() + 1
-        rank_per_node = math.ceil(world_size / nodes)
-        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
-        node_id = int(rank_id / rank_per_node)
-        rank_offset_per_node = rank_id % rank_per_node
-        if os.getenv("OMP_NUM_THREADS") is None:
-            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
-        else:
-            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
-        if len(numa.memory.get_membind_nodes()) == nodes:
-            numa.memory.set_membind_nodes((node_id))
-        torch.set_num_threads(num_cpus_per_rank)
-        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
-            cpu_start = num_cpus_per_rank * rank_offset_per_node
-            numa.schedule.run_on_cpus(
-                0,
-                *(
-                    numa.info.node_to_cpus(node_id)[
-                        cpu_start : cpu_start + num_cpus_per_rank
-                    ]
-                ),
-            )
-        logger.info(
-            f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
-        )
-
-
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
@@ -1447,16 +1408,13 @@ class FlashCausalLM(Model):

     def warmup(
         self,
-        request: generate_pb2.WarmupRequest,
+        batch: FlashCausalLMBatch,
+        max_input_tokens: Optional[int],
+        max_total_tokens: Optional[int],
     ):
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()
-        max_input_tokens = request.max_input_tokens
-        max_total_tokens = request.max_total_tokens
-        batch = self.batch_type.from_pb(
-            request.batch, self.tokenizer, self.dtype, self.device
-        )

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
@@ -1505,10 +1463,10 @@ class FlashCausalLM(Model):
         )

         log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
-        if max_total_tokens is None or max_total_tokens == 0:
+        if max_total_tokens is None:
             max_total_tokens = sum(batch.cache_lengths)

-        if max_input_tokens is None or max_input_tokens == 0:
+        if max_input_tokens is None:
             max_input_tokens = max_total_tokens - 1

         del _batch, batch
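A minimal caller-side sketch of the reworked FlashCausalLM.warmup signature above, mirroring the gRPC Warmup handler later in this diff. The helper name run_warmup is hypothetical, and model, batch, and request stand for an already-built FlashCausalLM, FlashCausalLMBatch, and WarmupRequest, none of which are constructed here:

from typing import Optional, Tuple


def run_warmup(model, batch, request) -> Tuple[int, int, int]:
    # Hypothetical helper: unset proto fields become None rather than 0, so
    # warmup can tell "not provided" apart from an explicit limit.
    max_input_tokens: Optional[int] = (
        request.max_input_tokens if request.HasField("max_input_tokens") else None
    )
    max_total_tokens: Optional[int] = (
        request.max_total_tokens if request.HasField("max_total_tokens") else None
    )
    # warmup now returns (max_supported_total_tokens, max_input_tokens, max_total_tokens).
    return model.warmup(batch, max_input_tokens, max_total_tokens)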
@@ -16,7 +16,8 @@ from text_generation_server.models.globals import PREFIX_CACHING
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+import habana_frameworks.torch as htorch

 tracer = trace.get_tracer(__name__)

@@ -447,6 +448,10 @@ class FlashVlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = False
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
@@ -459,13 +464,15 @@ class FlashVlmCausalLM(FlashCausalLM):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            seqlen=seqlen,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=batch.hpu_attn_meta,
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             pixel_values=batch.pixel_values,
             pixel_attention_mask=batch.pixel_attention_mask,
             image_sizes=batch.image_sizes,
             image_grid_thw=batch.image_grid_thw,
+            **kwargs,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
@@ -17,8 +17,8 @@ from text_generation_server.models.flash_vlm_causal_lm import (
     FlashVlmCausalLM,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.layers.attention import Seqlen
-
+from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+import habana_frameworks.torch as htorch

 tracer = trace.get_tracer(__name__)

@@ -279,6 +279,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

         cross_attention_states = batch.cross_attention_states

+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = False
+
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -286,13 +290,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            seqlen=seqlen,
-            max_s=max_s,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=batch.hpu_attn_meta,
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             cross_attention_states=cross_attention_states,
-            adapter_data=adapter_data,
+            # TODO list
+            adapter_data=None,
             image_indices=batch.image_indices[:],
+            **kwargs,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
@@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
-if max_batch_size_str is not None:
-    MAX_BATCH_SIZE = int(max_batch_size_str)
-else:
-    raise ValueError("MAX_BATCH_SIZE is not set")
+

 PREFILL_WARMUP_BATCH_SIZE_LIST = []
 PREFILL_WARMUP_SEQLEN_LIST = []
@@ -1464,6 +1460,11 @@ class VlmCausalLM(Model):
         batch = self.batch_from_pb(request.batch, is_warmup=True)
         max_input_tokens = request.max_input_tokens
         max_prefill_batch_size = batch.input_ids.shape[0]
+        max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
+        if max_batch_size_str is not None:
+            MAX_BATCH_SIZE = int(max_batch_size_str)
+        else:
+            raise ValueError("MAX_BATCH_SIZE is not set")

         try:
             # max prefill batch size warmup
@@ -18,10 +18,11 @@ from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id
+from text_generation_server.models.globals import set_model_id, ATTENTION
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -109,6 +110,42 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
-            self.model.warmup(request)
-        )
+        if ATTENTION == "paged":
+            set_max_prefill_tokens(request.max_prefill_tokens)
+            if (
+                self.model.batch_type in VLM_BATCH_TYPES
+            ):  # Hack, i would rather use kwargs in the `from_pb` call
+                batch = self.model.batch_type.from_pb_processor(
+                    request.batch,
+                    self.model.tokenizer,
+                    self.model.processor,
+                    self.model.model.config,
+                    self.model.dtype,
+                    self.model.device,
+                )
+            else:
+                batch = self.model.batch_type.from_pb(
+                    request.batch,
+                    self.model.tokenizer,
+                    self.model.dtype,
+                    self.model.device,
+                )
+
+            # Override default values with None for clearer semantics.
+            max_input_tokens = (
+                request.max_input_tokens
+                if request.HasField("max_input_tokens")
+                else None
+            )
+            max_total_tokens = (
+                request.max_total_tokens
+                if request.HasField("max_total_tokens")
+                else None
+            )
+            max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+                self.model.warmup(batch, max_input_tokens, max_total_tokens)
+            )
+        else:
+            max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+                self.model.warmup(request)
+            )
@@ -0,0 +1,24 @@
+from typing import Optional
+
+SUPPORT_CHUNKING: Optional[bool] = None
+MAX_PREFILL_TOKENS: Optional[int] = None
+
+
+def set_support_chunking(support_chunking: bool):
+    global SUPPORT_CHUNKING
+    SUPPORT_CHUNKING = support_chunking
+
+
+def get_support_chunking() -> bool:
+    global SUPPORT_CHUNKING
+    return SUPPORT_CHUNKING
+
+
+def set_max_prefill_tokens(max_prefill_tokens: int):
+    global MAX_PREFILL_TOKENS
+    MAX_PREFILL_TOKENS = max_prefill_tokens
+
+
+def get_max_prefill_tokens() -> int:
+    global MAX_PREFILL_TOKENS
+    return MAX_PREFILL_TOKENS
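The new prefill_chunking module above is just module-level state behind setters and getters. A short hedged usage sketch; the import path matches the one server.py adds earlier in this commit, and the token count is an arbitrary example value:

from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    set_max_prefill_tokens,
)

# Record the prefill token budget once (e.g. during Warmup)...
set_max_prefill_tokens(4096)
# ...and read it back from anywhere else in the server process.
assert get_max_prefill_tokens() == 4096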