Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-08 18:32:06 +00:00)
Parent commit: 17f0d57581
This commit: 1d3a4ab851

Changed files include Cargo.lock (generated, 14 changed lines); the file-by-file diff follows.
Cargo.lock (generated):
@@ -4177,7 +4177,7 @@ dependencies = [

 [[package]]
 name = "text-generation-backends-trtllm"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4200,7 +4200,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "average",
  "clap 4.5.18",
@@ -4220,7 +4220,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4238,7 +4238,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "clap 4.5.18",
  "ctrlc",
@@ -4259,7 +4259,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4308,7 +4308,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v2"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4357,7 +4357,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v3"
-version = "2.3.1-dev0"
+version = "2.3.1"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -41,7 +41,7 @@ COPY launcher launcher
 RUN cargo build --profile release-opt

 # Text Generation Inference base image
-FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest AS base
+FROM vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest AS base

 ENV ATTENTION=default
 ENV PREFIX_CACHING=0
@@ -75,7 +75,7 @@ RUN cd server && \
     make gen-server && \
     pip install --no-deps -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
-    pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0 && \
+    pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.20.0 && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir

@@ -123,7 +123,6 @@ impl Client {
         let mut input_chunks = Vec::new();
         input_chunks
             .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
-        if n_tokens == 0 {
             input_chunks.push(
                 Chunk::Image(Image {
                     // Safe unwrap, because we control the data.
@@ -132,7 +131,6 @@ impl Client {
                 })
                 .into(),
             );
-        }

         // Send stringly-typed inputs for compatibility for backends that haven't
         // been updated to support chunks.
@@ -22,7 +22,7 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 peft = "^0.10"
-optimum-habana = "1.15.0"
+#optimum-habana = "1.15.0"
 transformers = "4.45.2"
 numpy = "1.26.4"
 accelerate = "0.33.0"
@@ -46,7 +46,7 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana @ git+https://github.com/huggingface/optimum-habana.git@v1.16-release ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
 pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -87,3 +87,18 @@ wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
+outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13"
+interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
+cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13"
+diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
+numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
+llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
+nest-asyncio==1.6.0; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.10.6; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13"
+referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
+rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -17,14 +17,11 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+from text_generation_server.models.custom_modeling.mllama import MllamaForConditionalGeneration
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )
 #from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-# from text_generation_server.models.custom_modeling.mllama import (
-#     MllamaForConditionalGeneration,
-# )
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -196,6 +193,17 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

+    if model_type == "mllama":
+        return VlmCausalLM(
+            model_class=MllamaForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=None,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
@@ -274,7 +274,6 @@ class CausalLMBatch(Batch):
     top_n_tokens: List[int]
     top_n_tokens_tensor: torch.Tensor

-    input_length: int

     # Past metadata
     logits = None
@@ -924,7 +923,6 @@ class CausalLM(Model):
             "past_key_values": past_key_values,
             "token_idx": token_idx,
         }
-
         # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
         if self.model.config.model_type == "llama":
             kwargs["lazy_mode"] = LAZY_MODE == 1
@@ -20,6 +20,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn

+from loguru import logger
 from transformers.activations import ACT2FN
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
@@ -240,10 +241,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             # Retrieve the first layer to inspect the logits and mask out the hidden states
             # that are set to 0
             first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
             # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
             batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
             # Get the target length
             past_length = first_layer_past_key_value.shape[-1]
             extended_attention_mask = torch.ones(

[One file's diff was suppressed because it is too large.]
@@ -76,19 +76,20 @@ IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048))
 MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
 MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))

 PREFILL_WARMUP_BATCH_SIZE_LIST = []
 PREFILL_WARMUP_SEQLEN_LIST = []
 DECODE_WARMUP_BATCH_SIZE_LIST = []
+CROSS_ATTENTION_LAYERS= []
 def round_up(warmup_list:list, num) :
     i = 0
     for i in warmup_list:
         if num <= i :
             break
-    return i
+    return i if i > 0 else num

 def split(string) -> List[Dict[str, str]]:
     parts = []
@@ -106,25 +107,19 @@ def split(string) -> List[Dict[str, str]]:

     return parts

-def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+def image_text_replacement(config) -> str:
     if config.model_type == "idefics2":
-        image_seq_len = 64
         image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
-        if processor.image_processor.do_image_splitting:
-            image_str *= 5
         return image_str
     elif config.model_type == "llava_next":
-        height, width = image_input["image_sizes"][image_id]
-        num_features = get_number_of_features(height, width, config)
-        from loguru import logger
-        return "<image>" * num_features
-
+        return "<image>"
     elif config.model_type == "paligemma":
-        return "<image>" * config.text_config.num_image_tokens
+        return "<image>"
+    elif config.model_type == "mllama":
+        return "<|image|>"
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
-

 def image_text_replacement_fixup(config, text: str) -> str:
     if config.model_type == "idefics2":
         return text.replace(
@@ -192,6 +187,95 @@ class VlmCausalLMBatch(CausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
+    aspect_ratio_ids: Optional[torch.Tensor] = None
+    aspect_ratio_mask: Optional[torch.Tensor] = None
+    cross_attention_mask: Optional[torch.Tensor] = None
+    prefilling: bool = True
+    token_idx: torch.Tensor = None
+
+    def __init__(self,
+                 batch_id,
+                 requests,
+                 input_ids,
+                 attention_mask,
+                 position_ids,
+                 past_key_values,
+                 merged_kv_cache,
+                 next_token_chooser,
+                 top_n_tokens,
+                 top_n_tokens_tensor,
+                 input_length,
+                 pixel_values: Optional[List[torch.Tensor]] = None,
+                 pixel_attention_mask: Optional[List[torch.Tensor]] = None,
+                 image_sizes: Optional[List[Tuple[int, int]]] = None,
+                 aspect_ratio_ids: Optional[torch.Tensor] = None,
+                 aspect_ratio_mask: Optional[torch.Tensor] = None,
+                 cross_attention_mask: Optional[torch.Tensor] = None,
+                 prefilling: Optional[bool] = True,
+                 ):
+        super().__init__(
+            batch_id = batch_id,
+            requests = requests,
+            input_ids = input_ids,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_values = past_key_values,
+            merged_kv_cache = merged_kv_cache,
+            next_token_chooser = next_token_chooser,
+            top_n_tokens = top_n_tokens,
+            top_n_tokens_tensor = top_n_tokens_tensor,
+            input_length = input_length)
+
+        self.pixel_values = pixel_values
+        self.pixel_attention_mask = pixel_attention_mask
+        self.image_sizes = image_sizes
+        self.aspect_ratio_ids = aspect_ratio_ids
+        self.aspect_ratio_mask = aspect_ratio_mask
+        self.cross_attention_mask = cross_attention_mask
+        self.prefilling = prefilling
+
+    @property
+    def token_idx(self):
+        if self.prefilling:
+            # no right padding for prefill
+            token_idx_scalar = self.attention_mask.shape[-1] - 1
+            return torch.tensor(token_idx_scalar).to(self.attention_mask.device)
+        else:
+            token_idx_scalar = (
+                self.attention_mask.shape[-1] - self.right_padding
+            )
+            return torch.tensor(token_idx_scalar).to(self.attention_mask.device)
+
+    def padding_process(self, pad_id:int):
+        #self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1)
+        right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1]
+        self.input_ids = torch.nn.functional.pad(self.input_ids, (0, right_padding), value=pad_id)
+        self.attention_mask = torch.nn.functional.pad(
+            self.attention_mask, (0, right_padding), value=0
+        )
+        # if self.position_ids is not None:
+        #     self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1
+        if self.cross_attention_mask is not None:
+            self.cross_attention_mask = torch.nn.functional.pad(
+                self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0
+            )
+        if self.past is not None:
+            past_key_values_list = list(self.past_key_values)
+            for layer_id in range(len(self.past)):
+                past_key_value_list = list(self.past_key_values[layer_id])
+                if layer_id not in CROSS_ATTENTION_LAYERS:
+                    past_key_value_list[0] = torch.nn.functional.pad(
+                        self.past_key_values[layer_id][0], (0, 0, 0, right_padding), value=0
+                    )
+                    past_key_value_list[1] = torch.nn.functional.pad(
+                        self.past_key_values[layer_id][1], (0, 0, 0, right_padding), value=0
+                    )
+                past_key_values_list[layer_id] = tuple(past_key_value_list)
+            self.past_key_values = tuple(past_key_values_list)
+
+        self.prefilling = False
+        self.input_length = self.input_length
+
+
     @classmethod
     def from_tokenized(
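Review note: `padding_process` right-pads a prefilled batch out to the static MAX_TOTAL_TOKENS so that later decode steps see fixed shapes (useful for HPU graph reuse). A minimal, self-contained sketch of the same padding pattern; the value of MAX_TOTAL_TOKENS and the tensor shapes here are assumptions, not the class above:

    # Illustrative sketch only; MAX_TOTAL_TOKENS and the shapes are assumptions.
    import torch
    import torch.nn.functional as F

    MAX_TOTAL_TOKENS = 8
    pad_id = 0

    input_ids = torch.tensor([[11, 12, 13]])      # [batch, seq]
    attention_mask = torch.ones_like(input_ids)

    right_padding = MAX_TOTAL_TOKENS - attention_mask.shape[1]
    # F.pad pads the last dimension with (left, right) amounts.
    input_ids = F.pad(input_ids, (0, right_padding), value=pad_id)
    attention_mask = F.pad(attention_mask, (0, right_padding), value=0)

    print(input_ids.shape)           # torch.Size([1, 8])
    print(attention_mask.tolist())   # [[1, 1, 1, 0, 0, 0, 0, 0]]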
@@ -234,24 +318,24 @@ class VlmCausalLMBatch(CausalLMBatch):
             bucket_size = max_input_length
             left_padding = max_input_length - input_len
             if is_warmup is False:
-                if input_len < max_input_length :
                 rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
-                    if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
-                    else:
-                        bucket_size = max_input_length - 1
                 left_padding = bucket_size - input_len

         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
+        cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None)
         # Allocate space for first token
-        if left_padding > 0:
         input_ids = torch.nn.functional.pad(
             input_ids, (left_padding, 1), value=tokenizer.pad_token_id
         )
         attention_mask = torch.nn.functional.pad(
             attention_mask, (left_padding, 1), value=0
         )
+        if cross_attention_mask is not None:
+            cross_attention_mask = torch.nn.functional.pad(
+                cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0
+            )
         all_input_ids = torch.nn.functional.pad(
             input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
         ).T.split(1, dim=1)
@@ -265,9 +349,9 @@ class VlmCausalLMBatch(CausalLMBatch):
             r.all_input_ids = all_input_ids[r.idx]
         input_ids = input_ids.to(device)
         attention_mask = attention_mask.to(device)
+        cross_attention_mask = cross_attention_mask.to(device) if cross_attention_mask is not None else None
         position_ids = attention_mask.long().cumsum(-1) - 1
         position_ids.masked_fill_(attention_mask == 0, 1)
-
         htorch.core.mark_step()

         return cls(
@@ -282,53 +366,46 @@ class VlmCausalLMBatch(CausalLMBatch):
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
             input_length=input_len,
+            cross_attention_mask=cross_attention_mask,
         )

     @classmethod
     def batch_tokenized_inputs(
         cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup
     ):
-        # Process images first. We need all of them so that the processor
-        # can make the image splits the same size. And we need the final
-        # sizes to insert correct number of image tokens.
+        image_inputs = {}
+        texts = []
         images = []
-        for r in requests:
+        image_indices = []
+        batch_tokenized_inputs = {}
+
+        for i, r in enumerate(requests):
+            # Each input is encoded into a list, where each element of this input list is either a string or a URL
+            curr_text = ""
+            curr_image = None
+            curr_i = None
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
-                    pass
+                    curr_text += chunk.text
                 elif chunk_type == "image":
                     image = Image.open(BytesIO(chunk.image.data))
-                    if config.model_type == "llava_next":
-                        images.append(image)
+                    # TODO unsure about BOS
+                    if config.model_type == "mllama":
+                        curr_text = image_text_replacement(config) + curr_text
                     else:
-                        images.append([image])
+                        curr_text += image_text_replacement(config)
+                    curr_image = image
+                    curr_i = i
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")

-        image_inputs = None
-        if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
-
-        batch_inputs = []
-        max_truncation = 0
-        image_id = 0
-        for r in requests:
-            full_text = ""
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += chunk.text
-                elif chunk_type == "image":
-                    full_text += image_text_replacement(
-                        processor, image_inputs, config, image_id
-                    )
-                    image_id += 1
-            full_text = image_text_replacement_fixup(config, full_text)
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
+            texts.append(curr_text)
+            if curr_image is not None:
+                if config.model_type == "mllama":
+                    images.append([curr_image])
+                else:
+                    images.append(curr_image)

         missing_inputs = 0
         dummy_images = None
@@ -337,45 +414,37 @@ class VlmCausalLMBatch(CausalLMBatch):
         missing_inputs = new_bs - len(requests)
         if missing_inputs > 0:
             dummy_inputs = []
-            if len(batch_inputs) > 0:
-                dummy_inputs = [batch_inputs[0]] * missing_inputs
-
-            batch_inputs += dummy_inputs
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
+            if len(texts) > 0:
+                dummy_inputs = [texts[0]] * missing_inputs
+                if config.model_type == "mllama":
+                    dummy_images = [images[0]] * missing_inputs
+                else:
+                    dummy_images = [images[0]] * missing_inputs
+                texts += dummy_inputs
+                images += dummy_images
+
+        processor_output = processor(images,
+            texts,
             truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=not config.model_type == "paligemma",
+            max_length=r.truncate,
+            add_special_tokens=r.add_special_tokens,
             return_tensors="pt",
-            padding="longest",
-            return_token_type_ids=False,
-        )
-
-        if missing_inputs > 0 and image_inputs is not None:
-            dummy_shape = list(image_inputs['pixel_values'].shape)
-            dummy_shape[0] = missing_inputs
-            dummy_images = torch.rand(dummy_shape)
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    (image_inputs['pixel_values'], dummy_images), dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_inputs:
-                dummy_shape = list(image_inputs['pixel_attention_mask'].shape)
-                dummy_shape[0] = missing_inputs
-                dummy_attention = torch.zeros(dummy_shape)
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
-                )
-            if "image_sizes" in image_inputs:
-                dummy_shape = list(list(image_inputs['image_sizes'])[0])
-                dummy_shape = missing_inputs*[dummy_shape]
-                dummy_sizes = torch.IntTensor(dummy_shape)
-                new_image_inputs["image_sizes"] = torch.cat(
-                    (image_inputs["image_sizes"], dummy_sizes), dim=0
-                )
-            image_inputs = new_image_inputs
+            padding="longest")
+        if "input_ids" in processor_output:
+            batch_tokenized_inputs.update({"input_ids" : processor_output["input_ids"]})
+        if "attention_mask" in processor_output:
+            batch_tokenized_inputs.update({"attention_mask" : processor_output["attention_mask"]})
+        if "cross_attention_mask" in processor_output:
+            batch_tokenized_inputs.update({"cross_attention_mask" : processor_output["cross_attention_mask"]})
+        if "pixel_values" in processor_output:
+            image_inputs.update({"pixel_values" : processor_output["pixel_values"]})
+        if "pixel_attention_mask" in processor_output:
+            image_inputs.update({"pixel_attention_mask" : processor_output["pixel_attention_mask"]})
+        if "aspect_ratio_ids" in processor_output:
+            image_inputs.update({"aspect_ratio_ids" : processor_output["aspect_ratio_ids"]})
+        if "aspect_ratio_mask" in processor_output:
+            image_inputs.update({"aspect_ratio_mask" : processor_output["aspect_ratio_mask"]})
+        if "image_sizes" in processor_output:
+            image_inputs.update({"image_sizes" : processor_output["image_sizes"]})

         return batch_tokenized_inputs, image_inputs

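Review note: the rewrite above switches from tokenizing pre-expanded text to calling the full multimodal processor on paired image/text lists, and then splits its output between token-side and image-side inputs. A minimal sketch of that call shape, assuming a Hugging Face AutoProcessor for an mllama-style checkpoint; the checkpoint name and the exact returned keys are assumptions:

    # Illustrative sketch only; checkpoint name and returned keys are assumptions.
    from transformers import AutoProcessor
    from PIL import Image

    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    image = Image.new("RGB", (448, 448))
    texts = ["<|image|>Describe this image."]

    out = processor([[image]], texts, return_tensors="pt", padding="longest")
    # Token-side tensors feed the tokenized batch; image-side tensors feed image_inputs.
    token_keys = {k: v for k, v in out.items() if k in ("input_ids", "attention_mask", "cross_attention_mask")}
    image_keys = {k: v for k, v in out.items() if k in ("pixel_values", "aspect_ratio_ids", "aspect_ratio_mask")}
    print(sorted(token_keys), sorted(image_keys))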
@@ -393,7 +462,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
             pb.requests, tokenizer, processor, config, is_warmup
         )
-        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
             if "pixel_attention_mask" in image_inputs:
@@ -406,10 +475,22 @@ class VlmCausalLMBatch(CausalLMBatch):
                 batch.image_sizes = image_inputs["image_sizes"].to(device=device)
             else:
                 batch.image_sizes = None
+            if "aspect_ratio_ids" in image_inputs:
+                batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
+            else:
+                batch.aspect_ratio_ids = None
+            if "aspect_ratio_mask" in image_inputs:
+                batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(device=device)
+            else:
+                batch.aspect_ratio_mask = None
         else:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
             batch.image_sizes = None
+            batch.aspect_ratio_ids = None
+            batch.aspect_ratio_mask = None
+            batch.cross_attention_mask = None

         return batch

     @classmethod
@@ -423,93 +504,220 @@ class VlmCausalLMBatch(CausalLMBatch):
     def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
+        # Used for padding

         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        if is_warmup is False :
+        if not is_warmup:
             new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST,total_requests)
-        batch_id = batches[0].batch_id
-        device = batches[0].input_ids.device

-        input_lengths = [b.input_length for b in batches]
-        max_input_length = max(input_lengths)
-        offsets = [max_input_length - b.input_length for b in batches]
-
-        cur_padding = [b.right_padding for b in batches]
-        # For prefill there is a space allocated only for first token
-        # Need to add padding to the max total tokens before first decode
-
-        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
-        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
-        reshape = (batches[dst_batch_idx].batch_size < new_bs)
-
-        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
-        # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
-            scenario = 'CONCAT'
-        elif reshape:
-            scenario = 'RESHAPE'
-        elif cur_padding[dst_batch_idx] <= 0:
-            scenario = 'SHIFT'
-            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
-            max_input_length = max_input_length + offsets[dst_batch_idx]
+            scenario = "CONCAT"
+        elif batches[0].prefilling:
+            scenario = "SHIFT"
         else:
-            # Nothing to do
             return batches[0]

         dbg_trace(
-            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
-            f' reqs:{[len(b) for b in batches]}'
-            f' offsets:{offsets}'
-            f' input_lengths:{input_lengths}'
-            f' cur_padding:{cur_padding}'
-            f' dst_batch:{dst_batch_idx}')
+            scenario,
+            f"bs:{[b.batch_size for b in batches]}->{new_bs}"
+            f" reqs:{[len(b) for b in batches]}"
+        )

-        grouped_requests = [[req for req in batch.requests] for batch in batches]
-        flat_requests = list(itertools.chain(*grouped_requests))
+        if scenario == "SHIFT":
+            batch = batches[0]
+            batch.padding_process(pad_token_id)
+            return batch

-        for i in range(len(batches)):
-            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
-            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
-            batches[i].realign(target_bs, offsets[i], pad_token_id)
-            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
-        batches[dst_batch_idx].expand_bs(new_bs)
-        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+        total_batch_size = 0
+        max_input_length = 0
+        for i, batch in enumerate(batches):
+            total_batch_size += len(batch)
+            max_input_length = max(max_input_length, batch.input_length)
+        # Batch attributes
+        requests = []
+        input_lengths = []
+        top_n_tokens = []
+        max_tokens = 0
+        parameters = []
+        fsm_grammar_states = []

-        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        # Batch tensors
+        input_ids = None
+        attention_mask = None
+        position_ids = None
+        past_key_values = []
+        top_n_tokens_tensor = None
+        cross_attention_mask = None
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            keep_indices = []
+            for req in batch.requests:
+                keep_indices.append(req.idx)

-        parameters = [r.data.parameters for r in flat_requests]
-        # append the dummy parameters for dummy requests
-        batch_size = batches[dst_batch_idx].batch_size
-        parameters = pad_next_token_chooser_parameters(parameters, batch_size)
+            requests.extend(batch.requests)
+            parameters.extend([r.data.parameters for r in batch.requests])
+            fsm_grammar_states.extend([batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices])
+            input_lengths.extend([batch.input_length])
+            top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices])

-        # update past grammar states
-        fsm_grammar_states = [0] * batch_size
+            # Slicing end index for this batch
+            end_index = start_index + len(batch)
+
+            # We only concatenate batches that did at least one step
+            if batch.past_key_values is None:
+                raise ValueError("only concatenate prefilled batches")
+
+            # Create empty tensor
+            # input_ids is always of shape [batch_size, 1]
+            # We do not need to pad it
+            if input_ids is None:
+                input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS))
+            # # Copy to correct indices
+
+            left_offset = max_input_length - batch.input_length
+            right_padding = MAX_TOTAL_TOKENS - max_input_length
+            input_ids[start_index:end_index, left_offset:-right_padding] = batch.input_ids[keep_indices, :batch.input_length]
+
+            # Create padded tensor
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    new_bs,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[keep_indices]
+
+            if attention_mask is None:
+                attention_mask = batch.attention_mask.new_zeros(
+                    (new_bs, MAX_TOTAL_TOKENS),
+                )
+
+            attention_mask[
+                start_index:end_index,
+                left_offset:-right_padding,
+            ] = batch.attention_mask[
+                keep_indices,
+                :batch.input_length,
+            ]
+
+            if batch.cross_attention_mask is not None:
+                cross_attention_mask_shape = list(batch.cross_attention_mask.shape)
+                cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS
+                cross_attention_mask_shape[0] = new_bs
+                cross_attention_mask_shape = torch.Size(cross_attention_mask_shape)
+                if cross_attention_mask is None:
+                    cross_attention_mask = batch.cross_attention_mask.new_zeros(
+                        cross_attention_mask_shape,
+                    )
+                cross_attention_mask[
+                    start_index:end_index,
+                    left_offset:-right_padding,
+                ] = batch.cross_attention_mask[
+                    keep_indices,
+                    :batch.input_length,
+                ]
+
+            # Create empty tensor
+            # position_ids is always of shape [batch_size, 1]
+            if position_ids is None:
+                position_ids = batch.position_ids.new_empty((new_bs, 1))
+            position_ids[start_index:end_index] = batch.position_ids[keep_indices, :]
+
+            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
+            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+            # And ensure that we can update tensors in-place
+            if isinstance(batch.past_key_values, tuple):
+                batch.past_key_values = [
+                    [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
+                ]
+            elif len(batch.past_key_values[0][0].shape) == 3:
+                for layer in batch.past_key_values:
+                    for k, t in enumerate(layer):
+                        layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:])
+
+            start_index = end_index
+
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
+        past_key_values = []
+        for layer_id in range(len(batches[0].past_key_values)):
+            if layer_id in CROSS_ATTENTION_LAYERS:
+                padded_past_keys_shape = list(batches[0].past_key_values[layer_id][0].shape)
+                padded_past_keys_shape[0] = new_bs
+                padded_past_keys_shape = torch.Size(padded_past_keys_shape)
+            else:
+                padded_past_keys_shape = (
+                    new_bs,
+                    num_heads,
+                    MAX_TOTAL_TOKENS,
+                    head_dim,
+                )
+
+            padded_past_keys = first_past_kvs[layer_id][0].new_zeros(padded_past_keys_shape)
+            padded_past_values = first_past_kvs[layer_id][1].new_zeros(padded_past_keys_shape)
+            start_index = 0
             for batch in batches:
-            for i, req in enumerate(batch.requests):
-                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+                keep_indices = []
+                for req in batch.requests:
+                    keep_indices.append(req.idx)
+
+                left_offset = max_input_length - batch.input_length
+                right_padding = MAX_TOTAL_TOKENS - max_input_length
+                past_keys = batch.past_key_values[layer_id][0]
+                past_values = batch.past_key_values[layer_id][1]
+                # Clear reference to the original tensor
+                batch.past_key_values[layer_id] = None
+
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the keys to remove the padding from previous batches
+                if layer_id in CROSS_ATTENTION_LAYERS:
+                    padded_past_keys[start_index:end_index, :, :, :] = (
+                        past_keys[keep_indices, :, :, :]
+                    )
+                    padded_past_values[start_index:end_index, :, :, :] = (
+                        past_values[keep_indices, :, :, :]
+                    )
+
+                else:
+                    padded_past_keys[start_index:end_index, :, left_offset:-right_padding, :] = (
+                        past_keys[keep_indices, :, :batch.input_length, :]
+                    )
+                    padded_past_values[start_index:end_index, :, left_offset:-right_padding, :] = (
+                        past_values[keep_indices, :, :batch.input_length, :]
+                    )
+
+                start_index = end_index
+
+            past_key_values.append(tuple([padded_past_keys, padded_past_values]))
+        past_key_values = tuple(past_key_values)
+
+        batch_id = batches[0].batch_id
+        top_n_tokens.extend([-1] * (new_bs - total_batch_size))
+        fsm_grammar_states.extend([-1] * (new_bs - total_batch_size))
+
+        for idx, req in enumerate(requests):
+            req.idx = idx
+
+        parameters = pad_next_token_chooser_parameters(parameters, new_bs)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
-            batches[dst_batch_idx].next_token_chooser.dtype,
-            batches[dst_batch_idx].next_token_chooser.device,
-            batches[dst_batch_idx].next_token_chooser.tokenizer,
+            batches[0].next_token_chooser.dtype,
+            batches[0].next_token_chooser.device,
+            batches[0].next_token_chooser.tokenizer,
             fsm_grammar_states,
             quantization_enabled=hq_env.is_quantization_enabled,
         )

-        input_ids = batches[dst_batch_idx].input_ids
-        attention_mask = batches[dst_batch_idx].attention_mask
-        position_ids = batches[dst_batch_idx].position_ids
-        past_key_values = batches[dst_batch_idx].past_key_values
         input_length = max_input_length

         htorch.core.mark_step()

         return cls(
             batch_id=batch_id,
-            requests=flat_requests,
+            requests=requests,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
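Review note: the new CONCAT path above preallocates padded per-layer key/value tensors with `new_zeros` and copies each source batch into its row and sequence slice, left-aligned to the longest input and right-padded out to MAX_TOTAL_TOKENS; cross-attention layers are copied whole because their sequence axis is not the text axis. A toy sketch of that copy pattern; all shapes and values here are made up:

    # Illustrative sketch only; head counts and lengths are arbitrary.
    import torch

    MAX_TOTAL_TOKENS = 8
    new_bs, num_heads, head_dim = 3, 2, 4
    max_input_length = 5

    # Two source "batches" with different prefill lengths.
    src = [torch.ones(1, num_heads, 3, head_dim), torch.ones(2, num_heads, 5, head_dim) * 2]

    padded_keys = src[0].new_zeros((new_bs, num_heads, MAX_TOTAL_TOKENS, head_dim))
    start = 0
    for keys in src:
        bs, _, input_length, _ = keys.shape
        left_offset = max_input_length - input_length
        right_padding = MAX_TOTAL_TOKENS - max_input_length
        end = start + bs
        # Copy the live tokens into the padded layout: [left pad | tokens | right pad]
        padded_keys[start:end, :, left_offset:-right_padding, :] = keys
        start = end

    print(padded_keys.shape)                 # torch.Size([3, 2, 8, 4])
    print(padded_keys[0, 0, :, 0].tolist())  # [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]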
@@ -519,6 +727,13 @@ class VlmCausalLMBatch(CausalLMBatch):
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
             input_length=input_length,
+            pixel_values=None,
+            pixel_attention_mask=None,
+            image_sizes=None,
+            aspect_ratio_ids=None,
+            aspect_ratio_mask=None,
+            cross_attention_mask=cross_attention_mask,
+            prefilling=False,
         )

 class VlmCausalLM(Model):
@@ -644,6 +859,11 @@ class VlmCausalLM(Model):
             self.kwargs["flash_attention_recompute"] = True

         self.speculate = get_speculate()
+        if model.config.model_type == "mllama":
+            global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS
+            CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers
+            BASE_IMAGE_TOKENS = 0
+
         super(VlmCausalLM, self).__init__(
             model_id=model_id,
             model=model,
@@ -763,39 +983,39 @@ class VlmCausalLM(Model):

     def forward(
         self,
-        input_ids,
-        attention_mask,
-        position_ids,
-        token_idx,
-        past_key_values: Optional[List[Tuple]] = None,
-        pixel_values: Optional[List[torch.Tensor]] = None,
-        image_sizes: Optional[List[Tuple[int, int]]] = None,
+        batch: VlmCausalLMBatch,
         bypass_hpu_graph: Optional[bool] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "token_idx": token_idx,
-            "pixel_values": pixel_values,
-            "image_sizes": image_sizes,
+            "input_ids": batch.input_ids,
+            "attention_mask": batch.attention_mask,
+            "past_key_values": batch.past_key_values,
+            "token_idx": batch.token_idx,
+            "pixel_values": batch.pixel_values,
         }

+        if self.model.config.model_type == "mllama":
+            kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids
+            kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask
+            kwargs["cross_attention_mask"] = batch.cross_attention_mask
+        else:
+            kwargs["image_sizes"] = batch.image_sizes
+
         hpu_kwargs = {}
         # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
         if self.model.config.model_type == "llama" :
             hpu_kwargs["lazy_mode"] = LAZY_MODE == 1

         if self.has_position_ids:
-            kwargs["position_ids"] = position_ids
+            kwargs["position_ids"] = batch.position_ids

         if bypass_hpu_graph != None:
             hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph

         kwargs.update(self.kwargs)
         model_inputs = self.model.prepare_inputs_for_generation(**kwargs)
-        if past_key_values is not None:
+        if batch.past_key_values is not None:
             return self.model.forward(**model_inputs, **hpu_kwargs)
         else:
             outputs = self.model.forward(**model_inputs, **hpu_kwargs)
@@ -803,8 +1023,9 @@ class VlmCausalLM(Model):

     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[VlmCausalLMBatch], is_warmup: bool = False
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
+        self, batches: list[VlmCausalLMBatch], is_warmup: bool = False
+    ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]:
+
         start = time.time_ns()
         # Results
         generations: List[Generation] = []
@@ -870,9 +1091,16 @@ class VlmCausalLM(Model):
             # Update attention_mask as we added a new token to input_ids
             batch.attention_mask.index_fill_(1, token_idx, 1)

+            # add cross-attn mask for new token
+            if batch.cross_attention_mask is not None:
+                cross_attention_mask_prev = batch.cross_attention_mask
+                if token_idx is not None:
+                    mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
+                    cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
+                    batch.cross_attention_mask = cross_attention_mask_prev
+
             # Adjust lengths
             batch.input_length += 1

             # Update position_ids
             if prefill:
                 batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
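Review note: the block added above carries the cross-attention mask forward one decode step by copying the mask row of the last real token into the slot of the newly generated token, using `Tensor.index_copy_` along the text dimension. A tiny sketch of that copy; the shapes and indices here are illustrative, not the exact offsets used in the diff:

    # Illustrative sketch only; a real mllama mask is [batch, text_len, num_images, num_tiles].
    import torch

    mask = torch.zeros(1, 5, 1, 2)
    mask[:, :3] = 1                       # positions 0..2 already attend to the image
    token_idx = torch.tensor([3])         # slot reserved for the newly generated token

    last = mask[:, token_idx - 1, ...]    # mask row of the last real token (position 2)
    mask.index_copy_(1, token_idx, last)  # replicate it into position 3

    print(mask[0, :, 0, 0].tolist())      # [1.0, 1.0, 1.0, 1.0, 0.0]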
@@ -894,7 +1122,7 @@ class VlmCausalLM(Model):

         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
+            batch = self.batch_type.recombine([batch], self.tokenizer.pad_token_id, is_warmup)

         scenario = 'PREFILL' if prefill else 'GENERATE'
         if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
@@ -907,32 +1135,30 @@ class VlmCausalLM(Model):
         # Execute batch
         if prefill:
             # no right padding for prefill
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+            #token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
             batch.logits, batch.past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                token_idx,
-                batch.past_key_values,
-                batch.pixel_values,
-                batch.image_sizes,
+                batch,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )

         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
             # Don't schedule next forward if max_new_tokens for all requests equals 1
             # - we've already generated the first and only needed token in the prefill phase
             pass
         else:
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            #token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             batch.logits = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                token_idx,
-                batch.past_key_values,
+                batch,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )

+        if batch.pixel_values is not None:
+            batch.pixel_values = None
+        if batch.aspect_ratio_ids is not None:
+            batch.aspect_ratio_ids = None
+        if batch.aspect_ratio_mask is not None:
+            batch.aspect_ratio_mask = None
+
         htorch.core.mark_step()

         start_decode = time.time_ns()
|
|||||||
return generations, batch if not stopped else None, (forward_ns, decode_ns)
|
return generations, batch if not stopped else None, (forward_ns, decode_ns)
|
||||||
|
|
||||||
def batch_from_pb(self, batch, is_warmup):
|
def batch_from_pb(self, batch, is_warmup):
|
||||||
return VlmCausalLMBatch.from_pb_processor(
|
return self.batch_type.from_pb_processor(
|
||||||
batch,
|
batch,
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
self.processor,
|
self.processor,
|
||||||
@@ -1112,21 +1338,23 @@ class VlmCausalLM(Model):
         return self.batch_from_pb(batch, is_warmup)

     def warmup(self, request) -> None:
-        is_warmup = True
-        batch = self.batch_from_pb(request.batch, is_warmup)
+        global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS
+        MAX_TOTAL_TOKENS = request.max_total_tokens
+        MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
+        batch = self.batch_from_pb(request.batch, is_warmup=True)
+        max_input_length = batch.input_ids.shape[1]
+        max_prefill_batch_size = batch.input_ids.shape[0]
+
         try:
             # max prefill batch size warmup
-            _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
         except:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )

-        global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
-        max_input_length = batch.input_ids.shape[1]
-        max_prefill_batch_size = batch.input_ids.shape[0]
+        global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
         PREFILL_WARMUP_BATCH_SIZE_LIST = []
         batch_size = 1
         while batch_size <= max_prefill_batch_size:
@@ -1135,7 +1363,11 @@ class VlmCausalLM(Model):
         if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
             PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)

+        if self.model.config.model_type == "mllama":
+            seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
+        else:
             seq_len = BASE_IMAGE_TOKENS

         PREFILL_WARMUP_SEQLEN_LIST = []
         i = 0
         while seq_len <= max_input_length:
@@ -1152,9 +1384,9 @@ class VlmCausalLM(Model):
         try:
             for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
                 for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
-                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup=True)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
+                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup=True)

                     DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)

@@ -1182,33 +1414,28 @@ class VlmCausalLM(Model):
         try:
             if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
                 batches = []
-                for i in range(int(batch_size/max_prefill_batch_size)) :
-                    batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    batches.append(prefill_batch)
                 while batch_size <= max_decode_batch_size:
-                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    for i in range(int(batch_size/max_prefill_batch_size)) :
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0]-1, max_prefill_batch_size, is_warmup=False)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
+                        batches.append(prefill_batch)
+
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup=True)
                     DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                     batch_size = batch_size * 2
                     batches.clear()

-                    for i in range(int(batch_size/max_prefill_batch_size)) :
-                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
-                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                        batches.append(prefill_batch)

-                batches.clear()
                 if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
                     max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
                     batch_size = max_decode_batch_size
                     for i in range(int(max_decode_batch_size / 2)) :
-                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
-                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0]-1, 2, is_warmup=False)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
                         batches.append(prefill_batch)
-                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup=True)
                     DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
-                    max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
-                    MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
+                    MAX_BATCH_TOTAL_TOKENS = max_decode_batch_size * MAX_TOTAL_TOKENS
         except :
             raise RuntimeError(
                 f"Not enough memory to handle batch_size({batch_size}) decode warmup."