fix prehooks issues

baptiste 2025-02-25 15:24:35 +00:00
parent 31535bcde2
commit 77dca4dfbe
24 changed files with 666 additions and 395 deletions

View File

@@ -113,12 +113,14 @@ def serve(
     logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))

     if sharded:
        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
        num_shard = int(os.getenv("WORLD_SIZE", "1"))
        logger.info("CLI SHARDED = {}".format(num_shard))
        import subprocess

-        cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
+        cmd = (
+            f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
+        )
        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
        cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
@@ -130,6 +132,7 @@ def serve(
        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
            do_terminate = False
            current_handler = signal.getsignal(signal.SIGTERM)
+
            def terminate_handler(sig, frame):
                nonlocal do_terminate
                do_terminate = True

View File

@@ -17,7 +17,9 @@ if is_quantization_enabled:

 def patch_scoped_linear_all_reduce(model):
     from deepspeed.module_inject.layers import LinearAllreduce
-    from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+    from optimum.habana.transformers.models.modeling_all_models import (
+        ScopedLinearAllReduce,
+    )

     for name, module in model.named_children():
         if type(module) is LinearAllreduce:
@@ -36,7 +38,13 @@ def setup_quantization(model):

 def prepare_model_for_quantization(model):
     if is_quantization_enabled:
-        if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
+        if model.config.model_type in [
+            "llama",
+            "falcon",
+            "qwen2",
+            "starcoder2",
+            "gemma",
+        ]:
             patch_scoped_linear_all_reduce(model)
         from neural_compressor.torch.quantization import FP8Config, convert

View File

@@ -24,7 +24,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             response = method(request_or_iterator, context)
             return await response
         except Exception as err:
-            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
+            trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
@@ -36,7 +36,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
                 torch.cuda.empty_cache()
             from .utils.debug import dbg_trace
-            dbg_trace('EXCEPTION', traceback.format_exc())
+
+            dbg_trace("EXCEPTION", traceback.format_exc())
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)

View File

@@ -8,6 +8,7 @@ from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
 from typing import List, Dict
+
 # Needed to properly setup habana_frameworks
 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -16,10 +17,12 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+
+# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )
+
 # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
 # from text_generation_server.models.custom_modeling.mllama import (
 #     MllamaForConditionalGeneration,

View File

@@ -59,7 +59,12 @@ CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
 BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
-MAX_BATCH_SIZE = int(os.environ.get('MAX_BATCH_SIZE')) if os.environ.get('MAX_BATCH_SIZE') is not None else None
+MAX_BATCH_SIZE = (
+    int(os.environ.get("MAX_BATCH_SIZE"))
+    if os.environ.get("MAX_BATCH_SIZE") is not None
+    else None
+)
+

 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
@@ -564,9 +569,9 @@ class CausalLMBatch(Batch):
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
-                "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            )
+            assert (
+                PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
+            ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
             rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
@@ -1080,9 +1085,9 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph
-                if self.enable_hpu_graph
-                else None,
+                bypass_hpu_graph=(
+                    prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                ),
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
             # Don't schedule next forward if max_new_tokens for all requests equals 1
@@ -1099,9 +1104,9 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph
-                if self.enable_hpu_graph
-                else None,
+                bypass_hpu_graph=(
+                    prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                ),
             )
         if self.model.config.model_type in ["gpt_bigcode"]:
             batch.logits, batch.past = logits
@@ -1289,8 +1294,12 @@ class CausalLM(Model):
         return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)

-    def warmup(self, request: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]:
-        assert MAX_BATCH_SIZE is not None, "MAX_BATCH_SIZE is not set, it should be set in the launcher"
+    def warmup(
+        self, request: generate_pb2.WarmupRequest
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+        assert (
+            MAX_BATCH_SIZE is not None
+        ), "MAX_BATCH_SIZE is not set, it should be set in the launcher"
         MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens
         logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}")
         logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}")
@@ -1313,7 +1322,14 @@ class CausalLM(Model):

         # Warmup prefill batch_size
         max_input_tokens = request.max_input_tokens
-        prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
+        prefill_batch_size_list = [
+            batch
+            for batch in range(
+                PREFILL_BATCH_BUCKET_SIZE,
+                max_prefill_batch_size,
+                PREFILL_BATCH_BUCKET_SIZE,
+            )
+        ]
         prefill_batch_size_list.append(max_prefill_batch_size)
         prefill_seqlen_list = [
             seq
@@ -1399,7 +1415,7 @@ class CausalLM(Model):
             f"Memory stats: {mem_stats} "
         )

-        max_input_tokens=max_input_tokens
-        max_total_tokens=MAX_TOTAL_TOKENS
+        max_input_tokens = max_input_tokens
+        max_total_tokens = MAX_TOTAL_TOKENS

         return max_supported_total_tokens, max_input_tokens, max_total_tokens

View File

@@ -25,6 +25,7 @@ from transformers.models.llava_next.modeling_llava_next import (
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution

+
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
     Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -89,11 +90,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
     ):
         if token_idx is not None:
-            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-            output_hidden_states = (
-                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-            )
-            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            output_attentions = (
+                output_attentions
+                if output_attentions is not None
+                else self.config.output_attentions
+            )
+            output_hidden_states = (
+                output_hidden_states
+                if output_hidden_states is not None
+                else self.config.output_hidden_states
+            )
+            return_dict = (
+                return_dict if return_dict is not None else self.config.use_return_dict
+            )

             if inputs_embeds is None:
                 inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -120,175 +129,214 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

         return outputs

     def prepare_inputs_for_generation(
         self,
         input_ids,
         past_key_values=None,
         inputs_embeds=None,
         pixel_values=None,
         image_sizes=None,
         attention_mask=None,
         **kwargs,
     ):
         """
         Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
         The only differences are:
         - add new args token_idx
         - add the process of merging images into inputs_embeds
         """
         token_idx = kwargs.get("token_idx", None)
         if token_idx is None:
             return super().prepare_inputs_for_generation(
                 input_ids=input_ids,
                 past_key_values=past_key_values,
                 inputs_embeds=inputs_embeds,
                 pixel_values=pixel_values,
                 image_sizes=image_sizes,
                 attention_mask=attention_mask,
                 **kwargs,
             )
         else:
             use_flash_attention = kwargs.get("use_flash_attention", False)
             flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
-            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
-                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
+            if (
+                past_key_values is None
+                and pixel_values is not None
+                and input_ids.shape[1] != 1
+            ):
+                vision_feature_select_strategy = kwargs.get(
+                    "vision_feature_select_strategy", None
+                )
                 vision_feature_layer = kwargs.get("vision_feature_layer", None)
                 vision_feature_select_strategy = (
                     vision_feature_select_strategy
                     if vision_feature_select_strategy is not None
                     else self.config.vision_feature_select_strategy
                 )
                 vision_feature_layer = (
-                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+                    vision_feature_layer
+                    if vision_feature_layer is not None
+                    else self.config.vision_feature_layer
                 )

                 # 1. Extract the input embeddings
                 inputs_embeds = self.get_input_embeddings()(input_ids)
                 # 2. Merge text and images
-                batch_size, num_patches, num_channels, height, width = pixel_values.shape
-                reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
+                batch_size, num_patches, num_channels, height, width = (
+                    pixel_values.shape
+                )
+                reshaped_pixel_values = pixel_values.view(
+                    batch_size * num_patches, num_channels, height, width
+                )
                 image_features = self.vision_tower(
                     reshaped_pixel_values,
                     output_hidden_states=True,
                     use_flash_attention=use_flash_attention,
                     flash_attention_recompute=flash_attention_recompute,
                 )
-                selected_image_feature = image_features.hidden_states[vision_feature_layer]
+                selected_image_feature = image_features.hidden_states[
+                    vision_feature_layer
+                ]

                 if vision_feature_select_strategy == "default":
                     selected_image_feature = selected_image_feature[:, 1:]
                 elif vision_feature_select_strategy == "full":
                     selected_image_feature = selected_image_feature

                 image_features = self.multi_modal_projector(selected_image_feature)

                 # split up image_features for each of the individual images
                 # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
                 # if we assume each image has 5 image features (base image + 4 patches)
                 split_sizes = [image.shape[0] for image in pixel_values]
                 image_features = torch.split(image_features, split_sizes, dim=0)

                 # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )

                 new_image_features = []
                 for image_idx, image_feature in enumerate(image_features):
                     if image_feature.shape[0] > 1:
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]
                         if height * width != base_image_feature.shape[0]:
-                            raise ValueError("The number of patches is not consistent with the image size.")
+                            raise ValueError(
+                                "The number of patches is not consistent with the image size."
+                            )
                         num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                             image_sizes[image_idx].tolist(),
                             self.config.image_grid_pinpoints,
                             self.config.vision_config.image_size,
                         )
-                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                        image_feature = image_feature.view(
+                            num_patch_height, num_patch_width, height, width, -1
+                        )
+                        image_feature = image_feature.permute(
+                            4, 0, 2, 1, 3
+                        ).contiguous()
+                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                        image_feature = unpad_image(
+                            image_feature, image_sizes[image_idx]
+                        )
                         image_feature = torch.cat(
                             (
                                 image_feature,
-                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
+                                self.image_newline[:, None, None].expand(
+                                    *image_feature.shape[:-1], 1
+                                ),
                             ),
                             dim=-1,
                         )
                         image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                        image_feature = torch.cat(
+                            (base_image_feature, image_feature), dim=0
+                        )
                     else:
                         image_feature = image_feature[0]
-                        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
+                        image_feature = torch.cat(
+                            (image_feature, self.image_newline[None]), dim=0
+                        )
                     new_image_features.append(image_feature)
                 image_features = torch.stack(new_image_features, dim=0)
-                inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
-                self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position.
+                inputs_embeds = self._merge_input_ids_with_image_features(
+                    inputs_embeds, image_features, input_ids
+                )
+                self.image_offset = (
+                    image_features.shape[1] - 1
+                )  # image_token has occupied 1 token position.
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
             elif past_key_values is not None:
                 seq_len = input_ids.shape[1]
                 pad_len = seq_len - token_idx.item()
                 input_ids = torch.index_select(input_ids, 1, token_idx - 1)
                 # Retrieve the first layer to inspect the logits and mask out the hidden states
                 # that are set to 0
                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+                batch_index, non_attended_tokens = torch.where(
+                    first_layer_past_key_value.float().sum(-2) == 0
+                )

                 # Get the target length
                 past_length = first_layer_past_key_value.shape[-1]
                 extended_attention_mask = torch.ones(
                     (attention_mask.shape[0], past_length),
                     dtype=attention_mask.dtype,
                     device=attention_mask.device,
                 )
                 # Filter out only the tokens that can be un-attended, this can happen
                 # if one uses Llava + Fused modules where the cache on the
                 # first iteration is already big enough, or if one passes custom cache
                 valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                 new_batch_index = batch_index[valid_indices]
                 new_non_attended_tokens = non_attended_tokens[valid_indices]
                 # Zero-out the places where we don't need to attend
                 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
                 attention_mask = extended_attention_mask
                 attention_mask[:, -pad_len:] = 0

             if attention_mask is not None and position_ids is None:
                 # create position_ids on the fly for batch generation
                 position_ids = attention_mask.long().cumsum(-1) - 1
                 position_ids.masked_fill_(attention_mask == 0, 1)
                 if past_key_values:
                     if token_idx is not None:
-                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                        position_ids = (
+                            torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                        )
                     else:
                         position_ids = position_ids[:, -input_ids.shape[1] :]

             # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
             if inputs_embeds is not None and past_key_values is None:
                 model_inputs = {"inputs_embeds": inputs_embeds}
             else:
                 model_inputs = {"input_ids": input_ids}

             model_inputs.update(
                 {
                     "position_ids": position_ids,
                     "past_key_values": past_key_values,
                     "use_cache": kwargs.get("use_cache"),
                     "attention_mask": attention_mask,
                     "token_idx": token_idx,
                     "labels": labels,
                     "use_flash_attention": use_flash_attention,
                     "flash_attention_recompute": flash_attention_recompute,
                 }
             )

             return model_inputs

View File

@@ -58,6 +58,7 @@ def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id

+
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

View File

@@ -12,6 +12,7 @@ from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
 from text_generation_server.pb import generate_pb2

+
 BASE_MODEL_ADAPTER_ID = "__base_model__"
@@ -93,7 +94,9 @@ class Model(ABC):
     ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError

-    def warmup(self, batch: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+    def warmup(
+        self, batch: generate_pb2.WarmupRequest
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         self.generate_token(batch)
         return None, None, None

View File

@@ -13,7 +13,7 @@ class StarCoderCausalLMBatch(CausalLMBatch):
     def detach_kv_cache(self):
         past_keys = []
         past_values = []
-        last_dim = int(self.past_key_values[0].size(dim=-1)/2)
+        last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
         for key_value in self.past_key_values:
             past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
             past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
@@ -23,7 +23,9 @@ class StarCoderCausalLMBatch(CausalLMBatch):
     def attach_kv_cache(self, past_keys, past_values):
         self.past_key_values = [
-            torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)]
+            torch.cat((key, value), dim=-1)
+            for key, value in zip(past_keys, past_values)
+        ]


 class StarCoder(CausalLM):

View File

@ -66,23 +66,26 @@ IDEFICS2_IMAGE_TOKEN = "<image>"
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048)) BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048))
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072)) MAX_BATCH_TOTAL_TOKENS = int(os.environ.get("MAX_BATCH_TOTAL_TOKENS", 131072))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = []
DECODE_WARMUP_BATCH_SIZE_LIST = [] DECODE_WARMUP_BATCH_SIZE_LIST = []
def round_up(warmup_list:list, num) :
def round_up(warmup_list: list, num):
i = 0 i = 0
for i in warmup_list: for i in warmup_list:
if num <= i : if num <= i:
break break
return i return i
def split(string) -> List[Dict[str, str]]: def split(string) -> List[Dict[str, str]]:
parts = [] parts = []
cursor = 0 cursor = 0
@ -99,6 +102,7 @@ def split(string) -> List[Dict[str, str]]:
return parts return parts
def image_text_replacement(processor, image_input, config, image_id: int) -> str: def image_text_replacement(processor, image_input, config, image_id: int) -> str:
if config.model_type == "idefics2": if config.model_type == "idefics2":
image_seq_len = 64 image_seq_len = 64
@ -196,14 +200,19 @@ class VlmCausalLMBatch(CausalLMBatch):
is_warmup: bool = False, is_warmup: bool = False,
) -> "VlmCausalLMBatch": ) -> "VlmCausalLMBatch":
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] requests = [
CausalLMRequest.from_pb(idx, req, tokenizer)
for idx, req in enumerate(pb.requests)
]
max_input_length = max(r.data.truncate for r in requests) max_input_length = max(r.data.truncate for r in requests)
max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
# TODO: Add support for sparse batches # TODO: Add support for sparse batches
top_n_tokens = [r.top_n_tokens for r in pb.requests] top_n_tokens = [r.top_n_tokens for r in pb.requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
# TODO: by tokenizing all inputs at once we loose information on actual input lengths # TODO: by tokenizing all inputs at once we loose information on actual input lengths
# this means that we cannot shift inputs to the left after a long input sequence # this means that we cannot shift inputs to the left after a long input sequence
@ -226,7 +235,7 @@ class VlmCausalLMBatch(CausalLMBatch):
bucket_size = max_input_length bucket_size = max_input_length
left_padding = max_input_length - input_len left_padding = max_input_length - input_len
if is_warmup is False: if is_warmup is False:
if input_len < max_input_length : if input_len < max_input_length:
rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
if rounded_seq_len <= max_input_length: if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1 bucket_size = rounded_seq_len - 1
@ -276,10 +285,14 @@ class VlmCausalLMBatch(CausalLMBatch):
input_length=input_len, input_length=input_len,
) )
@classmethod @classmethod
def batch_tokenized_inputs( def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup cls,
requests: Iterable[generate_pb2.Request],
tokenizer,
processor,
config,
is_warmup,
): ):
# Process images first. We need all of them so that the processor # Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final # can make the image splits the same size. And we need the final
@ -345,24 +358,24 @@ class VlmCausalLMBatch(CausalLMBatch):
) )
if missing_inputs > 0 and image_inputs is not None: if missing_inputs > 0 and image_inputs is not None:
dummy_shape = list(image_inputs['pixel_values'].shape) dummy_shape = list(image_inputs["pixel_values"].shape)
dummy_shape[0] = missing_inputs dummy_shape[0] = missing_inputs
dummy_images = torch.rand(dummy_shape) dummy_images = torch.rand(dummy_shape)
new_image_inputs = { new_image_inputs = {
"pixel_values": torch.cat( "pixel_values": torch.cat(
(image_inputs['pixel_values'], dummy_images), dim=0 (image_inputs["pixel_values"], dummy_images), dim=0
), ),
} }
if "pixel_attention_mask" in image_inputs: if "pixel_attention_mask" in image_inputs:
dummy_shape = list(image_inputs['pixel_attention_mask'].shape) dummy_shape = list(image_inputs["pixel_attention_mask"].shape)
dummy_shape[0] = missing_inputs dummy_shape[0] = missing_inputs
dummy_attention = torch.zeros(dummy_shape) dummy_attention = torch.zeros(dummy_shape)
new_image_inputs["pixel_attention_mask"] = torch.cat( new_image_inputs["pixel_attention_mask"] = torch.cat(
(image_inputs["pixel_attention_mask"], dummy_attention), dim=0 (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
) )
if "image_sizes" in image_inputs: if "image_sizes" in image_inputs:
dummy_shape = list(list(image_inputs['image_sizes'])[0]) dummy_shape = list(list(image_inputs["image_sizes"])[0])
dummy_shape = missing_inputs*[dummy_shape] dummy_shape = missing_inputs * [dummy_shape]
dummy_sizes = torch.IntTensor(dummy_shape) dummy_sizes = torch.IntTensor(dummy_shape)
new_image_inputs["image_sizes"] = torch.cat( new_image_inputs["image_sizes"] = torch.cat(
(image_inputs["image_sizes"], dummy_sizes), dim=0 (image_inputs["image_sizes"], dummy_sizes), dim=0
@ -406,19 +419,27 @@ class VlmCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch": def concatenate(
cls,
batches: List["CausalLMBatch"],
pad_token_id: int = 0,
is_warmup: bool = False,
) -> "CausalLMBatch":
return cls.recombine(batches, pad_token_id, is_warmup) return cls.recombine(batches, pad_token_id, is_warmup)
@classmethod @classmethod
def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch": def recombine(
cls,
batches: List["VlmCausalLMBatch"],
pad_token_id: int,
is_warmup: bool = False,
) -> "VlmCausalLMBatch":
if not all(b.past_key_values is not None for b in batches): if not all(b.past_key_values is not None for b in batches):
raise ValueError("KV cache not allocated! Cannot recombine before prefill!") raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
total_requests = sum(len(b) for b in batches) total_requests = sum(len(b) for b in batches)
new_bs = total_requests new_bs = total_requests
if is_warmup is False : if is_warmup is False:
new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
batch_id = batches[0].batch_id batch_id = batches[0].batch_id
device = batches[0].input_ids.device device = batches[0].input_ids.device
@ -431,31 +452,39 @@ class VlmCausalLMBatch(CausalLMBatch):
# For prefill there is a space allocated only for first token # For prefill there is a space allocated only for first token
# Need to add padding to the max total tokens before first decode # Need to add padding to the max total tokens before first decode
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] moves_needed = [
total_requests - len(b) if b.batch_size == new_bs else total_requests
for b in batches
]
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
reshape = (batches[dst_batch_idx].batch_size < new_bs) reshape = batches[dst_batch_idx].batch_size < new_bs
# TODO: Add support for changing max seq len, i.e. due to output length bucketing # TODO: Add support for changing max seq len, i.e. due to output length bucketing
# FIXME: max_seq_len for non optimized code # FIXME: max_seq_len for non optimized code
if len(batches) > 1: if len(batches) > 1:
scenario = 'CONCAT' scenario = "CONCAT"
elif reshape: elif reshape:
scenario = 'RESHAPE' scenario = "RESHAPE"
elif cur_padding[dst_batch_idx] <= 0: elif cur_padding[dst_batch_idx] <= 0:
scenario = 'SHIFT' scenario = "SHIFT"
offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] offsets = [
biggest_single_chunk(b.max_input_length - max_input_length)
for b in batches
]
max_input_length = max_input_length + offsets[dst_batch_idx] max_input_length = max_input_length + offsets[dst_batch_idx]
else: else:
# Nothing to do # Nothing to do
return batches[0] return batches[0]
dbg_trace( dbg_trace(
scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' scenario,
f' reqs:{[len(b) for b in batches]}' f"bs:{[b.batch_size for b in batches]}->{new_bs}"
f' offsets:{offsets}' f" reqs:{[len(b) for b in batches]}"
f' input_lengths:{input_lengths}' f" offsets:{offsets}"
f' cur_padding:{cur_padding}' f" input_lengths:{input_lengths}"
f' dst_batch:{dst_batch_idx}') f" cur_padding:{cur_padding}"
f" dst_batch:{dst_batch_idx}",
)
grouped_requests = [[req for req in batch.requests] for batch in batches] grouped_requests = [[req for req in batch.requests] for batch in batches]
flat_requests = list(itertools.chain(*grouped_requests)) flat_requests = list(itertools.chain(*grouped_requests))
@ -466,10 +495,14 @@ class VlmCausalLMBatch(CausalLMBatch):
batches[i].realign(target_bs, offsets[i], pad_token_id) batches[i].realign(target_bs, offsets[i], pad_token_id)
batches[i].split_kv_cache_if_needed(i == dst_batch_idx) batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
batches[dst_batch_idx].expand_bs(new_bs) batches[dst_batch_idx].expand_bs(new_bs)
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) batches[dst_batch_idx].move_data(
[batches[i] for i in range(len(batches)) if i != dst_batch_idx]
)
top_n_tokens = [r.data.top_n_tokens for r in flat_requests] top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
parameters = [r.data.parameters for r in flat_requests] parameters = [r.data.parameters for r in flat_requests]
# append the dummy parameters for dummy requests # append the dummy parameters for dummy requests
@ -480,7 +513,9 @@ class VlmCausalLMBatch(CausalLMBatch):
fsm_grammar_states = [0] * batch_size fsm_grammar_states = [0] * batch_size
for batch in batches: for batch in batches:
for i, req in enumerate(batch.requests): for i, req in enumerate(batch.requests):
fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] fsm_grammar_states[req.idx] = (
batch.next_token_chooser.fsm_grammar_states[i]
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
parameters, parameters,
@ -513,6 +548,7 @@ class VlmCausalLMBatch(CausalLMBatch):
input_length=input_length, input_length=input_length,
) )
class VlmCausalLM(Model): class VlmCausalLM(Model):
def __init__( def __init__(
self, self,
@ -561,18 +597,14 @@ class VlmCausalLM(Model):
htorch.core.hpu_set_env() htorch.core.hpu_set_env()
if world_size > 1: if world_size > 1:
model = self.get_deepspeed_model( model = self.get_deepspeed_model(model_class, model_id, dtype, revision)
model_class, model_id, dtype, revision
)
model = hq_env.prepare_model_for_quantization(model) model = hq_env.prepare_model_for_quantization(model)
else: else:
get_repo_root(model_id) get_repo_root(model_id)
# Check support for rope scaling # Check support for rope scaling
model_kwargs = {} model_kwargs = {}
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(model_id)
model_id
)
if hasattr(config, "rope_scaling"): if hasattr(config, "rope_scaling"):
model_kwargs["rope_scaling"] = self.get_rope_scaling() model_kwargs["rope_scaling"] = self.get_rope_scaling()
@ -581,23 +613,29 @@ class VlmCausalLM(Model):
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**model_kwargs **model_kwargs,
) )
model = hq_env.prepare_model_for_quantization(model) model = hq_env.prepare_model_for_quantization(model)
model = model.eval().to(device) model = model.eval().to(device)
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 self.enable_hpu_graph = (
os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
)
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
model = remove_kv_cache_from_output(model) model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph: if self.enable_hpu_graph:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True) model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else: else:
if LAZY_MODE == 0: if LAZY_MODE == 0:
# It is said that "keep_input_mutations" is safe for inference to be done # It is said that "keep_input_mutations" is safe for inference to be done
dbg_trace( dbg_trace("TORCH COMPILE", "Torch compiling of model")
"TORCH COMPILE", 'Torch compiling of model') model.model = torch.compile(
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) model.model,
backend="hpu_backend",
options={"keep_input_mutations": True},
)
model = hq_env.setup_quantization(model) model = hq_env.setup_quantization(model)
@ -647,11 +685,15 @@ class VlmCausalLM(Model):
) )
# Create profiler # Create profiler
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 self.profiling_warmup_steps = (
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
)
self.profiling_steps = (
int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
)
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
if self.profiling_steps > 0: if self.profiling_steps > 0:
self.hb_profiler = HabanaProfile( self.hb_profiler = HabanaProfile(
@ -659,14 +701,13 @@ class VlmCausalLM(Model):
warmup=self.profiling_warmup_steps, warmup=self.profiling_warmup_steps,
active=self.profiling_steps, active=self.profiling_steps,
output_dir=output_dir, output_dir=output_dir,
record_shapes=record_shapes record_shapes=record_shapes,
) )
self.hb_profiler.start() self.hb_profiler.start()
else: else:
self.hb_profiler = None self.hb_profiler = None
self.step = 0 self.step = 0
@property @property
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class return self.batch_class
@ -679,20 +720,20 @@ class VlmCausalLM(Model):
model_class, model_class,
model_id: str, model_id: str,
dtype: torch.dtype, dtype: torch.dtype,
revision: Optional[str] = None revision: Optional[str] = None,
) -> torch.nn.Module: ) -> torch.nn.Module:
import deepspeed import deepspeed
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu() world_size, rank, local_rank = initialize_distributed_hpu()
model_kwargs = { model_kwargs = {"revision": revision}
"revision": revision
}
# Initialize process(es) for DeepSpeed # Initialize process(es) for DeepSpeed
deepspeed.init_distributed(dist_backend="hccl") deepspeed.init_distributed(dist_backend="hccl")
logger.info( logger.info(
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
world_size, rank, local_rank
)
) )
config = AutoConfig.from_pretrained(model_id, **model_kwargs) config = AutoConfig.from_pretrained(model_id, **model_kwargs)
load_to_meta = model_on_meta(config) load_to_meta = model_on_meta(config)
@ -710,14 +751,18 @@ class VlmCausalLM(Model):
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
# TODO: revisit placement on CPU when auto-injection is possible # TODO: revisit placement on CPU when auto-injection is possible
with deepspeed.OnDevice(dtype=dtype, device="cpu"): with deepspeed.OnDevice(dtype=dtype, device="cpu"):
model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) model = model_class.from_pretrained(
model_id, torch_dtype=dtype, **model_kwargs
)
model = model.eval() model = model.eval()
# Initialize the model # Initialize the model
ds_inference_kwargs = {"dtype": dtype} ds_inference_kwargs = {"dtype": dtype}
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
ds_inference_kwargs["enable_cuda_graph"] = False ds_inference_kwargs["enable_cuda_graph"] = False
ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config) ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(
model.language_model.config
)
if load_to_meta: if load_to_meta:
# model loaded to meta is managed differently # model loaded to meta is managed differently
@ -734,12 +779,12 @@ class VlmCausalLM(Model):
return None return None
rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
return { return {"type": rope_scaling, "factor": float(rope_factor)}
'type': rope_scaling, 'factor': float(rope_factor)
}
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def decode_token( def decode_token(
self, self,
@ -748,7 +793,9 @@ class VlmCausalLM(Model):
read_offset: int = 0, read_offset: int = 0,
) -> Tuple[str, int, int]: ) -> Tuple[str, int, int]:
if is_tokenizer_transparent(self.tokenizer): if is_tokenizer_transparent(self.tokenizer):
new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) new_text = self.tokenizer.decode(
all_input_ids[read_offset:], skip_special_tokens=False
)
return new_text, read_offset, len(all_input_ids) return new_text, read_offset, len(all_input_ids)
else: else:
return super().decode_token(all_input_ids, prefix_offset, read_offset) return super().decode_token(all_input_ids, prefix_offset, read_offset)
@ -776,7 +823,7 @@ class VlmCausalLM(Model):
hpu_kwargs = {} hpu_kwargs = {}
# Optimum Habana got "lazy_mode" key-val only supported for llama type of models # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
if self.model.config.model_type == "llama" : if self.model.config.model_type == "llama":
hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 hpu_kwargs["lazy_mode"] = LAZY_MODE == 1
if self.has_position_ids: if self.has_position_ids:
@ -814,18 +861,26 @@ class VlmCausalLM(Model):
token_idx_scalar = batch.attention_mask.shape[-1] - 1 token_idx_scalar = batch.attention_mask.shape[-1] - 1
token_idx = torch.tensor(token_idx_scalar).to(self.device) token_idx = torch.tensor(token_idx_scalar).to(self.device)
else: else:
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding token_idx_scalar = (
batch.attention_mask.shape[-1] - batch.right_padding
)
token_idx = torch.tensor(token_idx_scalar).to(self.device) token_idx = torch.tensor(token_idx_scalar).to(self.device)
# Select next token # Select next token
input_length = batch.input_length input_length = batch.input_length
if logits.shape[-2] > 1: if logits.shape[-2] > 1:
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( next_token_ids, next_token_logprobs, logprobs, _, _ = (
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate batch.next_token_chooser(
batch.input_ids,
logits[:, input_length - 1 : input_length, :].squeeze(-2),
self.speculate,
)
) )
else: else:
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( next_token_ids, next_token_logprobs, logprobs, _, _ = (
batch.input_ids, logits.squeeze(-2), self.speculate batch.next_token_chooser(
batch.input_ids, logits.squeeze(-2), self.speculate
)
) )
# Speculation is not active for causal # Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0] accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
@ -836,23 +891,29 @@ class VlmCausalLM(Model):
accepted_ids, accepted_ids,
) )
prev_batches.append({ prev_batches.append(
'next_token_ids': next_token_ids, {
'next_token_logprobs': next_token_logprobs, "next_token_ids": next_token_ids,
}) "next_token_logprobs": next_token_logprobs,
}
)
for req_idx, req in enumerate(batch.requests): for req_idx, req in enumerate(batch.requests):
requests_to_generate.append({ requests_to_generate.append(
'req': req, {
'prev_req_idx': req.idx, "req": req,
'batch_id': batch_id, "prev_req_idx": req.idx,
'seed': batch.next_token_chooser.seeds[req_idx], "batch_id": batch_id,
'do_sample': batch.next_token_chooser.do_sample[req_idx], "seed": batch.next_token_chooser.seeds[req_idx],
'top_n_tokens': batch.top_n_tokens[req_idx], "do_sample": batch.next_token_chooser.do_sample[req_idx],
'top_token_ids': batch_top_token_ids[req_idx], "top_n_tokens": batch.top_n_tokens[req_idx],
'top_token_logprobs': batch_top_token_logprobs[req_idx], "top_token_ids": batch_top_token_ids[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], "top_token_logprobs": batch_top_token_logprobs[req_idx],
}) "grammar_state": batch.next_token_chooser.fsm_grammar_states[
req.idx
],
}
)
htorch.core.mark_step() htorch.core.mark_step()
@ -867,7 +928,9 @@ class VlmCausalLM(Model):
# Update position_ids # Update position_ids
if prefill: if prefill:
batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 batch.position_ids = (
torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
)
else: else:
batch.position_ids += 1 batch.position_ids += 1
# Update past key values # Update past key values
@ -878,7 +941,9 @@ class VlmCausalLM(Model):
# Stage 2. Prepare new batch for speculative scheduling # Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1: if len(batches) > 1:
batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) batch = self.batch_type.concatenate(
batches, self.tokenizer.pad_token_id, is_warmup
)
else: else:
batch = batches[0] batch = batches[0]
@ -886,15 +951,24 @@ class VlmCausalLM(Model):
# Check if we need to do any bookkeeping first # Check if we need to do any bookkeeping first
if not prefill: if not prefill:
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) batch = batch.__class__.recombine(
[batch], self.tokenizer.pad_token_id, is_warmup
)
scenario = 'PREFILL' if prefill else 'GENERATE' scenario = "PREFILL" if prefill else "GENERATE"
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: if (
self.enable_hpu_graph
and self.limit_hpu_graph
and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
!= self.prev_bs
):
self.model.clear_cache() self.model.clear_cache()
self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
dbg_trace( dbg_trace(
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') scenario,
#assert batch.right_padding > 0, 'No more room for next token!' f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
)
# assert batch.right_padding > 0, 'No more room for next token!'
# Execute batch # Execute batch
if prefill: if prefill:
@ -908,21 +982,27 @@ class VlmCausalLM(Model):
batch.past_key_values, batch.past_key_values,
batch.pixel_values, batch.pixel_values,
batch.image_sizes, batch.image_sizes,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, bypass_hpu_graph=(
prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
),
) )
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
# Don't schedule next forward if max_new_tokens for all requests equals 1 # Don't schedule next forward if max_new_tokens for all requests equals 1
# - we've already generated the first and only needed token in the prefill phase # - we've already generated the first and only needed token in the prefill phase
pass pass
else: else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) token_idx = torch.tensor(
batch.attention_mask.shape[-1] - batch.right_padding
).to(self.device)
batch.logits = self.forward( batch.logits = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, batch.attention_mask,
batch.position_ids, batch.position_ids,
token_idx, token_idx,
batch.past_key_values, batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, bypass_hpu_graph=(
prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
),
) )
htorch.core.mark_step() htorch.core.mark_step()
@ -932,40 +1012,45 @@ class VlmCausalLM(Model):
# Stage 3. Finish and return previous generations # Stage 3. Finish and return previous generations
stopped = len(requests_to_generate) > 0 stopped = len(requests_to_generate) > 0
for prev_batch in prev_batches: for prev_batch in prev_batches:
prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() prev_batch["next_token_logprobs"] = prev_batch[
prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() "next_token_logprobs"
].tolist()
prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
htorch.core.mark_step() htorch.core.mark_step()
for req_data in requests_to_generate: for req_data in requests_to_generate:
req = req_data['req'] req = req_data["req"]
i = req_data['prev_req_idx'] i = req_data["prev_req_idx"]
prev_batch_id = req_data['batch_id'] prev_batch_id = req_data["batch_id"]
assert len(prev_batches) > prev_batch_id assert len(prev_batches) > prev_batch_id
next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
request = req.data request = req.data
input_length = req.input_length input_length = req.input_length
prefix_offset = req.prefix_offset prefix_offset = req.prefix_offset
read_offset = req.read_offset read_offset = req.read_offset
do_sample = req_data['do_sample'] do_sample = req_data["do_sample"]
seed = req_data['seed'] seed = req_data["seed"]
stopping_criteria = req.stopping_criteria stopping_criteria = req.stopping_criteria
all_input_ids = req.all_input_ids all_input_ids = req.all_input_ids
next_token_id = next_token_ids_cpu[i] next_token_id = next_token_ids_cpu[i]
next_token_logprob = next_token_logprobs[i] next_token_logprob = next_token_logprobs[i]
top_n_tokens = req_data['top_n_tokens'] top_n_tokens = req_data["top_n_tokens"]
top_token_ids = req_data['top_token_ids'] top_token_ids = req_data["top_token_ids"]
top_token_logprobs = req_data['top_token_logprobs'] top_token_logprobs = req_data["top_token_logprobs"]
grammar_state = req_data['grammar_state'] grammar_state = req_data["grammar_state"]
# Append next token to all tokens # Append next token to all tokens
all_input_ids[input_length] = next_token_id all_input_ids[input_length] = next_token_id
new_input_length = input_length + 1 new_input_length = input_length + 1
# Generated token # Generated token
if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: if (
next_token_text = '' is_tokenizer_transparent(self.tokenizer)
and len(stopping_criteria.stop_sequence_criterias) == 0
):
next_token_text = ""
else: else:
next_token_text, prefix_offset, read_offset = self.decode_token( next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
@ -989,7 +1074,11 @@ class VlmCausalLM(Model):
output_text = None output_text = None
else: else:
output_text = self.decode( output_text = self.decode(
all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] all_input_ids[
new_input_length
- stopping_criteria.current_tokens : new_input_length,
0,
]
) )
generated_text = GeneratedText( generated_text = GeneratedText(
output_text, output_text,
@ -1004,7 +1093,7 @@ class VlmCausalLM(Model):
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + next_token_logprobs prefill_logprobs = [float("nan")] + next_token_logprobs
prefill_token_ids = all_input_ids[0: new_input_length - 1] prefill_token_ids = all_input_ids[0 : new_input_length - 1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids, prefill_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
@ -1073,7 +1162,12 @@ class VlmCausalLM(Model):
htorch.core.mark_step() htorch.core.mark_step()
self.step = self.step + 1 self.step = self.step + 1
if self.hb_profiler is not None: if self.hb_profiler is not None:
if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: if (
self.step
> self.profiling_wait_steps
+ self.profiling_warmup_steps
+ self.profiling_steps
):
self.hb_profiler.stop() self.hb_profiler.stop()
else: else:
self.hb_profiler.step() self.hb_profiler.step()
@ -1090,7 +1184,7 @@ class VlmCausalLM(Model):
self.model.config, self.model.config,
self.dtype, self.dtype,
self.device, self.device,
is_warmup is_warmup,
) )
def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
@ -1117,14 +1211,14 @@ class VlmCausalLM(Model):
) )
global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
max_input_length = batch.input_ids.shape[1] max_input_length = batch.input_ids.shape[1]
max_prefill_batch_size = batch.input_ids.shape[0] max_prefill_batch_size = batch.input_ids.shape[0]
PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_BATCH_SIZE_LIST = []
batch_size = 1 batch_size = 1
while batch_size <= max_prefill_batch_size: while batch_size <= max_prefill_batch_size:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2 batch_size = batch_size * 2
if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size : if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
seq_len = BASE_IMAGE_TOKENS seq_len = BASE_IMAGE_TOKENS
@ -1132,19 +1226,21 @@ class VlmCausalLM(Model):
i = 0 i = 0
while seq_len <= max_input_length: while seq_len <= max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
i += 1 i += 1
if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
#Prefill and decode warmup # Prefill and decode warmup
DECODE_WARMUP_BATCH_SIZE_LIST = [] DECODE_WARMUP_BATCH_SIZE_LIST = []
prefill_batch = None prefill_batch = None
decode_batch = None decode_batch = None
try: try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
for seq_len in PREFILL_WARMUP_SEQLEN_LIST : for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup) batch = self.generate_warmup_batch(
request, seq_len, batch_size, is_warmup
)
_, prefill_batch, _ = self.generate_token([batch], is_warmup) _, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
@ -1161,21 +1257,29 @@ class VlmCausalLM(Model):
mem_stats = get_hpu_memory_stats(self.device) mem_stats = get_hpu_memory_stats(self.device)
logger.info( logger.info(
f"\nFollowing prefill and decode warmup successfully.\n" f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats} " f"Memory stats: {mem_stats} "
) )
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
batch_size = max_prefill_batch_size * 2 batch_size = max_prefill_batch_size * 2
# Decode warmup with bigger batch_size # Decode warmup with bigger batch_size
try: try:
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size: if (
DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size
and batch_size <= max_decode_batch_size
):
batches = [] batches = []
for i in range(int(batch_size/max_prefill_batch_size)) : for i in range(int(batch_size / max_prefill_batch_size)):
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) batch = self.generate_warmup_batch(
request,
PREFILL_WARMUP_SEQLEN_LIST[0],
DECODE_WARMUP_BATCH_SIZE_LIST[-1],
is_warmup,
)
_, prefill_batch, _ = self.generate_token([batch], is_warmup) _, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch) batches.append(prefill_batch)
while batch_size <= max_decode_batch_size: while batch_size <= max_decode_batch_size:
@ -1184,17 +1288,24 @@ class VlmCausalLM(Model):
batch_size = batch_size * 2 batch_size = batch_size * 2
batches.clear() batches.clear()
for i in range(int(batch_size/max_prefill_batch_size)) : for i in range(int(batch_size / max_prefill_batch_size)):
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) batch = self.generate_warmup_batch(
request,
PREFILL_WARMUP_SEQLEN_LIST[0],
DECODE_WARMUP_BATCH_SIZE_LIST[-1],
is_warmup,
)
_, prefill_batch, _ = self.generate_token([batch], is_warmup) _, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch) batches.append(prefill_batch)
batches.clear() batches.clear()
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2 max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2
batch_size = max_decode_batch_size batch_size = max_decode_batch_size
for i in range(int(max_decode_batch_size / 2)) : for i in range(int(max_decode_batch_size / 2)):
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup) batch = self.generate_warmup_batch(
request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup
)
_, prefill_batch, _ = self.generate_token([batch], is_warmup) _, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch) batches.append(prefill_batch)
_, decode_batch, _ = self.generate_token(batches, is_warmup) _, decode_batch, _ = self.generate_token(batches, is_warmup)
@ -1211,9 +1322,9 @@ class VlmCausalLM(Model):
mem_stats = get_hpu_memory_stats(self.device) mem_stats = get_hpu_memory_stats(self.device)
logger.info( logger.info(
f"\nFollowing decode warmup successfully.\n" f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats}" f"Memory stats: {mem_stats}"
) )
return MAX_BATCH_TOTAL_TOKENS return MAX_BATCH_TOTAL_TOKENS
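The warmup shape lists assembled above follow a simple doubling pattern; a standalone sketch with illustrative values (the real max_prefill_batch_size, BASE_IMAGE_TOKENS, PAD_SEQUENCE_TO_MULTIPLE_OF and max_input_length come from the warmup request and server configuration):

# Prefill batch sizes: powers of two up to the largest prefill batch, plus that maximum itself.
max_prefill_batch_size = 5  # illustrative only
prefill_batch_sizes, bs = [], 1
while bs <= max_prefill_batch_size:
    prefill_batch_sizes.append(bs)
    bs *= 2
if prefill_batch_sizes[-1] < max_prefill_batch_size:
    prefill_batch_sizes.append(max_prefill_batch_size)
# -> [1, 2, 4, 5]

# Prefill sequence lengths: start at the image-token budget and grow by
# PAD_SEQUENCE_TO_MULTIPLE_OF * 2**i, plus the maximum input length itself.
BASE_IMAGE_TOKENS = 576            # illustrative only
PAD_SEQUENCE_TO_MULTIPLE_OF = 256  # illustrative only
max_input_length = 2048            # illustrative only
prefill_seqlens, seq_len, i = [], BASE_IMAGE_TOKENS, 0
while seq_len <= max_input_length:
    prefill_seqlens.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
    i += 1
if prefill_seqlens[-1] < max_input_length:
    prefill_seqlens.append(max_input_length)
# -> [576, 832, 1344, 2048]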
View File
@ -38,7 +38,10 @@ try:
except (ImportError, NotImplementedError): except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash. # These imports can fail on CPU/Non flash.
VLM_BATCH_TYPES = set() VLM_BATCH_TYPES = set()
from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION from text_generation_server.utils.version import (
is_driver_compatible,
MIN_TGI_GAUDI_SYNAPSE_VERSION,
)
class SignalHandler: class SignalHandler:
@ -72,7 +75,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# Force inference mode for the lifetime of TextGenerationService # Force inference mode for the lifetime of TextGenerationService
# self._inference_mode_raii_guard = torch._C._InferenceMode(True) # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def Info(self, request, context): async def Info(self, request, context):
return self.model.info return self.model.info
@ -101,7 +103,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context): async def Warmup(self, request, context):
max_supported_total_tokens, max_input_tokens, max_total_tokens = self.model.warmup(request) max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(request)
)
# W/A for the skip tokenizer path # W/A for the skip tokenizer path
# We need to call make_tokenizer_optional after the warmup, # We need to call make_tokenizer_optional after the warmup,
@ -194,7 +198,9 @@ def serve(
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if not is_driver_compatible(): if not is_driver_compatible():
logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures") logger.warning(
f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures"
)
unix_socket_template = "unix://{}-{}" unix_socket_template = "unix://{}-{}"
adapter_to_index = {} adapter_to_index = {}
@ -204,14 +210,19 @@ def serve(
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
logger.info("Server:server_inner: rank ={}".format(rank)) logger.info("Server:server_inner: rank ={}".format(rank))
server_urls = [ server_urls = [
unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"])) unix_socket_template.format(uds_path, rank)
for rank in range(int(os.environ["WORLD_SIZE"]))
] ]
local_url = server_urls[int(os.environ["RANK"])] local_url = server_urls[int(os.environ["RANK"])]
else: else:
local_url = unix_socket_template.format(uds_path, 0) local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url] server_urls = [local_url]
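Both branches above end up with one unix-domain-socket URL per shard. A small sketch of the sharded case, assuming uds_path="/tmp/tgi-uds" and WORLD_SIZE=2 purely for illustration:

import os

os.environ.setdefault("WORLD_SIZE", "2")  # set by the launcher in practice
os.environ.setdefault("RANK", "0")

uds_path = "/tmp/tgi-uds"  # hypothetical path for the example
unix_socket_template = "unix://{}-{}"
server_urls = [
    unix_socket_template.format(uds_path, rank)
    for rank in range(int(os.environ["WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["RANK"])]
# server_urls -> ['unix:///tmp/tgi-uds-0', 'unix:///tmp/tgi-uds-1']
# local_url   -> 'unix:///tmp/tgi-uds-0' for rank 0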
logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url)) logger.info(
"Server:server_inner: data type = {}, local_url = {}".format(
dtype, local_url
)
)
if dtype == "bfloat16" or None: if dtype == "bfloat16" or None:
data_type = torch.bfloat16 data_type = torch.bfloat16
else: else:
View File
@ -10,7 +10,13 @@ def main(args):
logger.info("TGIService: starting tgi service .... ") logger.info("TGIService: starting tgi service .... ")
logger.info( logger.info(
"TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format( "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path args.model_id,
args.revision,
args.sharded,
args.speculate,
args.dtype,
args.trust_remote_code,
args.uds_path,
) )
) )
lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
@ -24,7 +30,7 @@ def main(args):
dtype=args.dtype, dtype=args.dtype,
trust_remote_code=args.trust_remote_code, trust_remote_code=args.trust_remote_code,
uds_path=args.uds_path, uds_path=args.uds_path,
max_input_tokens=args.max_input_tokens max_input_tokens=args.max_input_tokens,
) )
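The flags logged and forwarded above are the contract between the launch command and this entry point. A hypothetical parser that would satisfy main(), inferred from the log line rather than taken from the real argument-parsing code (names, types and defaults are assumptions):

import argparse
from pathlib import Path

def str_to_bool(value: str) -> bool:
    # argparse's bare `type=bool` treats any non-empty string as True,
    # so this sketch uses a small converter instead.
    return value.lower() in ("1", "true", "yes")

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser("tgi_service")
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--sharded", type=str_to_bool, default=False)
    parser.add_argument("--speculate", type=int, default=None)
    parser.add_argument("--dtype", type=str, default=None)
    parser.add_argument("--trust_remote_code", type=str_to_bool, default=False)
    parser.add_argument("--uds_path", type=Path, default=Path("/tmp/tgi-uds"))  # hypothetical default
    parser.add_argument("--max_input_tokens", type=int, default=None)
    return parser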
View File
@ -8,14 +8,14 @@ from optimum.habana.utils import to_gb_rounded
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
START_TS = None START_TS = None
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME') DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
if 'GRAPH_VISUALIZATION' in os.environ: if "GRAPH_VISUALIZATION" in os.environ:
for f in glob.glob('.graph_dumps/*'): for f in glob.glob(".graph_dumps/*"):
os.remove(f) os.remove(f)
def count_hpu_graphs(): def count_hpu_graphs():
return len(glob.glob('.graph_dumps/*PreGraph*')) return len(glob.glob(".graph_dumps/*PreGraph*"))
def dbg_trace(tag, txt): def dbg_trace(tag, txt):
@ -25,7 +25,11 @@ def dbg_trace(tag, txt):
START_TS = time.perf_counter() START_TS = time.perf_counter()
time_offset = time.perf_counter() - START_TS time_offset = time.perf_counter() - START_TS
mem_stats = htorch.hpu.memory.memory_stats() mem_stats = htorch.hpu.memory.memory_stats()
mem_used = to_gb_rounded(mem_stats['InUse']) mem_used = to_gb_rounded(mem_stats["InUse"])
max_mem_used = to_gb_rounded(mem_stats['MaxInUse']) max_mem_used = to_gb_rounded(mem_stats["MaxInUse"])
print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB ' print(
f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a')) f"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB "
f"mmu:{max_mem_used:.1f}GB | {tag} | {txt}",
flush=True,
file=open(DBG_TRACE_FILENAME, "a"),
)
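dbg_trace appends one line per call with elapsed time, the number of dumped HPU graphs and current/peak device memory; the DBG_TRACE_FILENAME guard itself sits outside this hunk. A usage sketch with illustrative values:

import os

# Opt-in tracing: both variables are read when debug.py is imported.
os.environ["DBG_TRACE_FILENAME"] = "/tmp/tgi_trace.log"  # hypothetical path
os.environ["GRAPH_VISUALIZATION"] = "1"                  # also wipes old .graph_dumps/* files at import

# A call such as dbg_trace("decode", "bs:4 seq_len:1024") then appends a line shaped roughly like
#   ts:12.345s g:17 mu:38.2GB mmu:41.0GB | decode | bs:4 seq_len:1024
# where g counts dumped *PreGraph* files and mu/mmu are in-use and peak HPU memory in GB.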
View File
@ -64,10 +64,12 @@ def initialize_torch_distributed():
backend = "hccl" backend = "hccl"
n_hpus = torch.hpu.device_count() n_hpus = torch.hpu.device_count()
if world_size > n_hpus: if world_size > n_hpus:
raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") raise ValueError(
f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})."
)
else: else:
try: try:
import oneccl_bindings_for_pytorch # noqa: F401 import oneccl_bindings_for_pytorch # noqa: F401
backend = "ccl" backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None: if os.getenv("CCL_WORKER_COUNT", None) is None:
View File
@ -63,7 +63,9 @@ class StaticWarper:
self.static_warped_scores.copy_(local_scores) self.static_warped_scores.copy_(local_scores)
# Compute logprobs # Compute logprobs
self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1)) self.static_next_logprob.copy_(
torch.log_softmax(self.static_warped_scores, -1)
)
self.static_scores.copy_(scores) self.static_scores.copy_(scores)
self.hpu_graph.replay() self.hpu_graph.replay()
@ -78,7 +80,9 @@ def static_warper(
top_p: Optional[float], top_p: Optional[float],
typical_p: Optional[float], typical_p: Optional[float],
) -> StaticWarper: ) -> StaticWarper:
return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) return StaticWarper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
@ -95,13 +99,17 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = penalty self.penalty = penalty
self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) self.penalty_tensor = torch.tensor(
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids) score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) score = torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
scores.scatter_(1, input_ids, score) scores.scatter_(1, input_ids, score)
return scores return scores
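The asymmetric update above always pushes previously generated tokens down: negative logits are multiplied by the penalty, positive ones divided by it. A tiny worked example with a single request and a penalty of 1.2 (all numbers illustrative):

import torch

penalty_tensor = torch.tensor([[1.2]])        # shape (batch, 1), as built in __init__
scores = torch.tensor([[2.4, -1.2, 0.6]])     # logits over a 3-token vocabulary
input_ids = torch.tensor([[0, 1]])            # tokens 0 and 1 were already generated

score = torch.gather(scores, 1, input_ids)    # [[2.4, -1.2]]
score = torch.where(score < 0, score * penalty_tensor, score / penalty_tensor)
scores.scatter_(1, input_ids, score)
print(scores)                                 # tensor([[ 2.0000, -1.4400,  0.6000]])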
@ -163,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
batch_size, vocab_size, dtype=scores.dtype, device=scores.device batch_size, vocab_size, dtype=scores.dtype, device=scores.device
) )
token_freq.scatter_add_( token_freq.scatter_add_(
1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device) 1,
input_ids,
torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device),
) )
token_freq /= input_size token_freq /= input_size
@ -190,9 +200,13 @@ class HeterogeneousTemperatureLogitsWarper:
The value used to module the logits distribution. The value used to module the logits distribution.
""" """
def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): def __init__(
self, temperature: List[float], dtype: torch.dtype, device: torch.device
):
self.temperature = temperature self.temperature = temperature
self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) self.temperature_tensor = torch.tensor(
temperature, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores.div_(self.temperature_tensor) scores.div_(self.temperature_tensor)
@ -231,7 +245,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
min_tokens_to_keep: int = 1, min_tokens_to_keep: int = 1,
): ):
self.top_p = top_p self.top_p = top_p
self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) self.top_p_opposite = 1 - torch.tensor(
top_p, dtype=dtype, device=device
).unsqueeze(1)
self.filter_value = filter_value self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep self.min_tokens_to_keep = min_tokens_to_keep
@ -248,7 +264,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return warped_scores return warped_scores
@ -296,7 +314,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
disabled = [x == 0 for x in top_k] disabled = [x == 0 for x in top_k]
if any(disabled): if any(disabled):
self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1) self.top_k_disabled_mask = torch.tensor(
disabled, dtype=torch.bool, device=device
).view(-1, 1)
else: else:
self.top_k_disabled_mask = None self.top_k_disabled_mask = None
@ -332,7 +352,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
self.max_top_k = max(self.top_k) self.max_top_k = max(self.top_k)
if self.top_k_disabled_mask is not None: if self.top_k_disabled_mask is not None:
self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None self.top_k_disabled_mask = (
self.top_k_disabled_mask[indices] if any(disabled) else None
)
return self return self
return None return None
@ -398,11 +420,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
if self.disabled_mask is not None: if self.disabled_mask is not None:
last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1: if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
@ -416,7 +442,9 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
self.mass_tensor = self.mass_tensor[indices] self.mass_tensor = self.mass_tensor[indices]
if self.disabled_mask is not None: if self.disabled_mask is not None:
self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None self.disabled_mask = (
self.disabled_mask[indices] if any(disabled) else None
)
return self return self
return None return None
View File
@ -254,7 +254,7 @@ class HeterogeneousNextTokenChooser:
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
grammars: List[str], grammars: List[str],
grammar_types: List[int], grammar_types: List[int],
fsm_grammar_states:List[int], fsm_grammar_states: List[int],
quantization_enabled: bool, quantization_enabled: bool,
): ):
warpers = [] warpers = []
@ -447,7 +447,9 @@ class HeterogeneousNextTokenChooser:
next_id = next_id.item() next_id = next_id.item()
self.fsm_grammar_states[grammar_state_index] = ( self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index( self.grammar_processor.advance_at_index(
next_id, past_state, grammar_state_index, next_id,
past_state,
grammar_state_index,
) )
) )
return self return self
@ -544,9 +546,7 @@ def pad_next_token_chooser_parameters(
grammar="", grammar="",
grammar_type=0, grammar_type=0,
) )
parameters.extend( parameters.extend([empty_parameters] * (expected_size - len(parameters)))
[empty_parameters] * (expected_size - len(parameters))
)
return parameters return parameters
@ -701,24 +701,47 @@ def make_tokenizer_optional(tokenizer):
padding, padding,
return_token_type_ids, return_token_type_ids,
truncation, truncation,
max_length max_length,
): ):
assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer" assert (
assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer" return_tensors == "pt"
assert not return_token_type_ids, "inccorrect input arguments when calling TransparentTokenizer" ), "inccorrect input arguments when calling TransparentTokenizer"
assert truncation, "inccorrect input arguments when calling TransparentTokenizer" assert (
padding == "max_length" or padding == "longest"
), "inccorrect input arguments when calling TransparentTokenizer"
assert (
not return_token_type_ids
), "inccorrect input arguments when calling TransparentTokenizer"
assert (
truncation
), "inccorrect input arguments when calling TransparentTokenizer"
def str_token_to_int(i): def str_token_to_int(i):
if i == '?': if i == "?":
return tokenizer.pad_token_id return tokenizer.pad_token_id
else: else:
return int(i) return int(i)
all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
for inner_text in text] all_tokens = [
[str_token_to_int(i.strip()) for i in inner_text.split(",")]
for inner_text in text
]
if padding == "longest": if padding == "longest":
max_length = max(len(tokens) for tokens in all_tokens) max_length = max(len(tokens) for tokens in all_tokens)
return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]), return {
"attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])} "input_ids": torch.tensor(
[
[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
for tokens in all_tokens
]
),
"attention_mask": torch.tensor(
[
[0] * (max_length - len(tokens)) + [1] * len(tokens)
for tokens in all_tokens
]
),
}
def decode( def decode(
self, self,
@ -728,9 +751,10 @@ def make_tokenizer_optional(tokenizer):
**kwargs, **kwargs,
) -> str: ) -> str:
# I don't think this method is used anywhere and should be removed when doing refactoring # I don't think this method is used anywhere and should be removed when doing refactoring
return ','.join(str(i) for i in to_py_obj(token_ids)) # noqa: F821 return ",".join(str(i) for i in to_py_obj(token_ids)) # noqa: F821
import os import os
if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true": if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
tokenizer.__class__ = _ tokenizer.__class__ = _
tokenizer.is_transparent = True tokenizer.is_transparent = True
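When SKIP_TOKENIZER_IN_TGI is set, callers are expected to send comma-separated token ids with "?" standing in for padding. A sketch of what the transparent __call__ produces, assuming pad_token_id is 0 and padding="longest" (inputs are illustrative):

import torch

pad_token_id = 0                              # assumed for the example; the real value comes from the tokenizer
text = ["12, 7, 99", "?, 5"]                  # pre-tokenized ids from the caller, '?' = pad

def str_token_to_int(i):
    return pad_token_id if i == "?" else int(i)

all_tokens = [[str_token_to_int(i.strip()) for i in t.split(",")] for t in text]
max_length = max(len(t) for t in all_tokens)  # padding == "longest"

input_ids = torch.tensor(
    [[pad_token_id] * (max_length - len(t)) + t for t in all_tokens]
)
attention_mask = torch.tensor(
    [[0] * (max_length - len(t)) + [1] * len(t) for t in all_tokens]
)
# input_ids      -> [[12, 7, 99], [0, 0, 5]]
# attention_mask -> [[ 1, 1,  1], [0, 1, 1]]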
View File
@ -1,7 +1,7 @@
from optimum.habana.utils import get_driver_version from optimum.habana.utils import get_driver_version
from packaging.version import Version from packaging.version import Version
MIN_TGI_GAUDI_SYNAPSE_VERSION=Version("1.16.0") MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")
def is_driver_compatible(): def is_driver_compatible():
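The body of is_driver_compatible falls outside this hunk; a minimal sketch of the gate it implies, assuming get_driver_version returns something Version can parse (or None when the driver cannot be queried):

from packaging.version import Version

MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")

def is_driver_compatible_sketch(driver_version) -> bool:
    # Hypothetical check, not the upstream implementation: unknown drivers pass
    # (the caller only warns), anything else must meet the minimum version.
    if driver_version is None:
        return True
    return Version(str(driver_version)) >= MIN_TGI_GAUDI_SYNAPSE_VERSION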