Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 06:12:07 +00:00)

Commit 77dca4dfbe (parent 31535bcde2): fix prehooks issues
@@ -118,7 +118,9 @@ def serve(
         logger.info("CLI SHARDED = {}".format(num_shard))
         import subprocess

-        cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
+        cmd = (
+            f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
+        )
         cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
         cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
         cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
@@ -130,6 +132,7 @@ def serve(
         with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
             do_terminate = False
             current_handler = signal.getsignal(signal.SIGTERM)

             def terminate_handler(sig, frame):
                 nonlocal do_terminate
                 do_terminate = True
@@ -17,7 +17,9 @@ if is_quantization_enabled:

 def patch_scoped_linear_all_reduce(model):
     from deepspeed.module_inject.layers import LinearAllreduce
-    from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+    from optimum.habana.transformers.models.modeling_all_models import (
+        ScopedLinearAllReduce,
+    )

     for name, module in model.named_children():
         if type(module) is LinearAllreduce:
@@ -36,7 +38,13 @@ def setup_quantization(model):

 def prepare_model_for_quantization(model):
     if is_quantization_enabled:
-        if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
+        if model.config.model_type in [
+            "llama",
+            "falcon",
+            "qwen2",
+            "starcoder2",
+            "gemma",
+        ]:
             patch_scoped_linear_all_reduce(model)
         from neural_compressor.torch.quantization import FP8Config, convert

@@ -24,7 +24,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             response = method(request_or_iterator, context)
             return await response
         except Exception as err:
-            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
+            trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")

@@ -36,7 +36,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
                 torch.cuda.empty_cache()

             from .utils.debug import dbg_trace
-            dbg_trace('EXCEPTION', traceback.format_exc())
+            dbg_trace("EXCEPTION", traceback.format_exc())
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
@@ -8,6 +8,7 @@ from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
 from typing import List, Dict

 # Needed to properly setup habana_frameworks

 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -16,10 +17,12 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )

 # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
 # from text_generation_server.models.custom_modeling.mllama import (
 # MllamaForConditionalGeneration,
@@ -59,7 +59,12 @@ CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
 BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
-MAX_BATCH_SIZE = int(os.environ.get('MAX_BATCH_SIZE')) if os.environ.get('MAX_BATCH_SIZE') is not None else None
+MAX_BATCH_SIZE = (
+    int(os.environ.get("MAX_BATCH_SIZE"))
+    if os.environ.get("MAX_BATCH_SIZE") is not None
+    else None
+)


 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
@@ -564,9 +569,9 @@ class CausalLMBatch(Batch):
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
-                "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            )
+            assert (
+                PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
+            ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
             rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
@@ -1080,9 +1085,9 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph
-                if self.enable_hpu_graph
-                else None,
+                bypass_hpu_graph=(
+                    prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                ),
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
             # Don't schedule next forward if max_new_tokens for all requests equals 1
@@ -1099,9 +1104,9 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph
-                if self.enable_hpu_graph
-                else None,
+                bypass_hpu_graph=(
+                    prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                ),
             )
         if self.model.config.model_type in ["gpt_bigcode"]:
             batch.logits, batch.past = logits
@@ -1289,8 +1294,12 @@ class CausalLM(Model):

         return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)

-    def warmup(self, request: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]:
-        assert MAX_BATCH_SIZE is not None, "MAX_BATCH_SIZE is not set, it should be set in the launcher"
+    def warmup(
+        self, request: generate_pb2.WarmupRequest
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+        assert (
+            MAX_BATCH_SIZE is not None
+        ), "MAX_BATCH_SIZE is not set, it should be set in the launcher"
         MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens
         logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}")
         logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}")
@@ -1313,7 +1322,14 @@ class CausalLM(Model):

         # Warmup prefill batch_size
         max_input_tokens = request.max_input_tokens
-        prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
+        prefill_batch_size_list = [
+            batch
+            for batch in range(
+                PREFILL_BATCH_BUCKET_SIZE,
+                max_prefill_batch_size,
+                PREFILL_BATCH_BUCKET_SIZE,
+            )
+        ]
         prefill_batch_size_list.append(max_prefill_batch_size)
         prefill_seqlen_list = [
             seq
@@ -1399,7 +1415,7 @@ class CausalLM(Model):
             f"Memory stats: {mem_stats} "
         )

-        max_input_tokens=max_input_tokens
-        max_total_tokens=MAX_TOTAL_TOKENS
+        max_input_tokens = max_input_tokens
+        max_total_tokens = MAX_TOTAL_TOKENS

         return max_supported_total_tokens, max_input_tokens, max_total_tokens
@@ -25,6 +25,7 @@ from transformers.models.llava_next.modeling_llava_next import (
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution


 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
     Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -89,11 +90,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
     ):

         if token_idx is not None:
-            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-            output_hidden_states = (
-                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_attentions = (
+                output_attentions
+                if output_attentions is not None
+                else self.config.output_attentions
+            )
+            output_hidden_states = (
+                output_hidden_states
+                if output_hidden_states is not None
+                else self.config.output_hidden_states
+            )
+            return_dict = (
+                return_dict if return_dict is not None else self.config.use_return_dict
             )
-            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
             if inputs_embeds is None:
                 inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -152,8 +161,14 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

         position_ids = kwargs.get("position_ids", None)
         labels = kwargs.get("labels", None)
-        if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
-            vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
+        if (
+            past_key_values is None
+            and pixel_values is not None
+            and input_ids.shape[1] != 1
+        ):
+            vision_feature_select_strategy = kwargs.get(
+                "vision_feature_select_strategy", None
+            )
             vision_feature_layer = kwargs.get("vision_feature_layer", None)
             vision_feature_select_strategy = (
                 vision_feature_select_strategy
@@ -161,14 +176,20 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 else self.config.vision_feature_select_strategy
             )
             vision_feature_layer = (
-                vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+                vision_feature_layer
+                if vision_feature_layer is not None
+                else self.config.vision_feature_layer
             )

             # 1. Extract the input embeddings
             inputs_embeds = self.get_input_embeddings()(input_ids)
             # 2. Merge text and images
-            batch_size, num_patches, num_channels, height, width = pixel_values.shape
-            reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
+            batch_size, num_patches, num_channels, height, width = (
+                pixel_values.shape
+            )
+            reshaped_pixel_values = pixel_values.view(
+                batch_size * num_patches, num_channels, height, width
+            )
             image_features = self.vision_tower(
                 reshaped_pixel_values,
                 output_hidden_states=True,
@@ -176,7 +197,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 flash_attention_recompute=flash_attention_recompute,
             )

-            selected_image_feature = image_features.hidden_states[vision_feature_layer]
+            selected_image_feature = image_features.hidden_states[
+                vision_feature_layer
+            ]

             if vision_feature_select_strategy == "default":
                 selected_image_feature = selected_image_feature[:, 1:]
@@ -192,7 +215,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             image_features = torch.split(image_features, split_sizes, dim=0)

             # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+            height = width = (
+                self.config.vision_config.image_size
+                // self.config.vision_config.patch_size
+            )

             new_image_features = []
             for image_idx, image_feature in enumerate(image_features):
@@ -201,7 +227,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                     image_feature = image_feature[1:]

                     if height * width != base_image_feature.shape[0]:
-                        raise ValueError("The number of patches is not consistent with the image size.")
+                        raise ValueError(
+                            "The number of patches is not consistent with the image size."
+                        )

                     num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                         image_sizes[image_idx].tolist(),
@@ -209,26 +237,42 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                         self.config.vision_config.image_size,
                     )

-                    image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                    image_feature = image_feature.view(
+                        num_patch_height, num_patch_width, height, width, -1
+                    )
+                    image_feature = image_feature.permute(
+                        4, 0, 2, 1, 3
+                    ).contiguous()
                     image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                    image_feature = unpad_image(
+                        image_feature, image_sizes[image_idx]
+                    )
                     image_feature = torch.cat(
                         (
                             image_feature,
-                            self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
+                            self.image_newline[:, None, None].expand(
+                                *image_feature.shape[:-1], 1
+                            ),
                         ),
                         dim=-1,
                     )
                     image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                    image_feature = torch.cat(
+                        (base_image_feature, image_feature), dim=0
+                    )
                 else:
                     image_feature = image_feature[0]
-                    image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
+                    image_feature = torch.cat(
+                        (image_feature, self.image_newline[None]), dim=0
+                    )
                 new_image_features.append(image_feature)
             image_features = torch.stack(new_image_features, dim=0)
-            inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
-            self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position.
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                inputs_embeds, image_features, input_ids
+            )
+            self.image_offset = (
+                image_features.shape[1] - 1
+            )  # image_token has occupied 1 token position.
         # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
         # generation with cache
         elif past_key_values is not None:
@@ -240,7 +284,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

             # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+            batch_index, non_attended_tokens = torch.where(
+                first_layer_past_key_value.float().sum(-2) == 0
+            )

             # Get the target length
             past_length = first_layer_past_key_value.shape[-1]
@@ -268,7 +314,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
                 if token_idx is not None:
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                    position_ids = (
+                        torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                    )
                 else:
                     position_ids = position_ids[:, -input_ids.shape[1] :]

@@ -58,6 +58,7 @@ def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id


 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
@@ -12,6 +12,7 @@ from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
 from text_generation_server.pb import generate_pb2

 BASE_MODEL_ADAPTER_ID = "__base_model__"


@@ -93,7 +94,9 @@ class Model(ABC):
     ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError

-    def warmup(self, batch: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+    def warmup(
+        self, batch: generate_pb2.WarmupRequest
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         self.generate_token(batch)
         return None, None, None

@@ -13,7 +13,7 @@ class StarCoderCausalLMBatch(CausalLMBatch):
     def detach_kv_cache(self):
         past_keys = []
         past_values = []
-        last_dim = int(self.past_key_values[0].size(dim=-1)/2)
+        last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
         for key_value in self.past_key_values:
             past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
             past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
@@ -23,7 +23,9 @@ class StarCoderCausalLMBatch(CausalLMBatch):

     def attach_kv_cache(self, past_keys, past_values):
         self.past_key_values = [
-            torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)]
+            torch.cat((key, value), dim=-1)
+            for key, value in zip(past_keys, past_values)
+        ]


 class StarCoder(CausalLM):
@@ -66,23 +66,26 @@ IDEFICS2_IMAGE_TOKEN = "<image>"


 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048))
-MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
-MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
+BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048))
+MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
+MAX_BATCH_TOTAL_TOKENS = int(os.environ.get("MAX_BATCH_TOTAL_TOKENS", 131072))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
-LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
+LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))

 PREFILL_WARMUP_BATCH_SIZE_LIST = []
 PREFILL_WARMUP_SEQLEN_LIST = []
 DECODE_WARMUP_BATCH_SIZE_LIST = []
-def round_up(warmup_list:list, num) :
+
+
+def round_up(warmup_list: list, num):
     i = 0
     for i in warmup_list:
-        if num <= i :
+        if num <= i:
             break
     return i


 def split(string) -> List[Dict[str, str]]:
     parts = []
     cursor = 0
@@ -99,6 +102,7 @@ def split(string) -> List[Dict[str, str]]:

     return parts


 def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     if config.model_type == "idefics2":
         image_seq_len = 64
@@ -196,14 +200,19 @@ class VlmCausalLMBatch(CausalLMBatch):
         is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":

-        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
-        requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
+        dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
+        requests = [
+            CausalLMRequest.from_pb(idx, req, tokenizer)
+            for idx, req in enumerate(pb.requests)
+        ]

         max_input_length = max(r.data.truncate for r in requests)
         max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
         # TODO: Add support for sparse batches
         top_n_tokens = [r.top_n_tokens for r in pb.requests]
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
@@ -226,7 +235,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
         if is_warmup is False:
-            if input_len < max_input_length :
+            if input_len < max_input_length:
                 rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
                 if rounded_seq_len <= max_input_length:
                     bucket_size = rounded_seq_len - 1
@@ -276,10 +285,14 @@ class VlmCausalLMBatch(CausalLMBatch):
             input_length=input_len,
         )


     @classmethod
     def batch_tokenized_inputs(
-        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup
+        cls,
+        requests: Iterable[generate_pb2.Request],
+        tokenizer,
+        processor,
+        config,
+        is_warmup,
     ):
         # Process images first. We need all of them so that the processor
         # can make the image splits the same size. And we need the final
@@ -345,24 +358,24 @@ class VlmCausalLMBatch(CausalLMBatch):
         )

         if missing_inputs > 0 and image_inputs is not None:
-            dummy_shape = list(image_inputs['pixel_values'].shape)
+            dummy_shape = list(image_inputs["pixel_values"].shape)
             dummy_shape[0] = missing_inputs
             dummy_images = torch.rand(dummy_shape)
             new_image_inputs = {
                 "pixel_values": torch.cat(
-                    (image_inputs['pixel_values'], dummy_images), dim=0
+                    (image_inputs["pixel_values"], dummy_images), dim=0
                 ),
             }
             if "pixel_attention_mask" in image_inputs:
-                dummy_shape = list(image_inputs['pixel_attention_mask'].shape)
+                dummy_shape = list(image_inputs["pixel_attention_mask"].shape)
                 dummy_shape[0] = missing_inputs
                 dummy_attention = torch.zeros(dummy_shape)
                 new_image_inputs["pixel_attention_mask"] = torch.cat(
                     (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
                 )
             if "image_sizes" in image_inputs:
-                dummy_shape = list(list(image_inputs['image_sizes'])[0])
-                dummy_shape = missing_inputs*[dummy_shape]
+                dummy_shape = list(list(image_inputs["image_sizes"])[0])
+                dummy_shape = missing_inputs * [dummy_shape]
                 dummy_sizes = torch.IntTensor(dummy_shape)
                 new_image_inputs["image_sizes"] = torch.cat(
                     (image_inputs["image_sizes"], dummy_sizes), dim=0
@@ -406,19 +419,27 @@ class VlmCausalLMBatch(CausalLMBatch):

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch":
+    def concatenate(
+        cls,
+        batches: List["CausalLMBatch"],
+        pad_token_id: int = 0,
+        is_warmup: bool = False,
+    ) -> "CausalLMBatch":
         return cls.recombine(batches, pad_token_id, is_warmup)


     @classmethod
-    def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch":
+    def recombine(
+        cls,
+        batches: List["VlmCausalLMBatch"],
+        pad_token_id: int,
+        is_warmup: bool = False,
+    ) -> "VlmCausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")

         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        if is_warmup is False :
+        if is_warmup is False:
             new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
@@ -431,31 +452,39 @@ class VlmCausalLMBatch(CausalLMBatch):
         # For prefill there is a space allocated only for first token
         # Need to add padding to the max total tokens before first decode

-        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        moves_needed = [
+            total_requests - len(b) if b.batch_size == new_bs else total_requests
+            for b in batches
+        ]
         dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
-        reshape = (batches[dst_batch_idx].batch_size < new_bs)
+        reshape = batches[dst_batch_idx].batch_size < new_bs

         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
-            scenario = 'CONCAT'
+            scenario = "CONCAT"
         elif reshape:
-            scenario = 'RESHAPE'
+            scenario = "RESHAPE"
         elif cur_padding[dst_batch_idx] <= 0:
-            scenario = 'SHIFT'
-            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            scenario = "SHIFT"
+            offsets = [
+                biggest_single_chunk(b.max_input_length - max_input_length)
+                for b in batches
+            ]
             max_input_length = max_input_length + offsets[dst_batch_idx]
         else:
             # Nothing to do
             return batches[0]

         dbg_trace(
-            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
-            f' reqs:{[len(b) for b in batches]}'
-            f' offsets:{offsets}'
-            f' input_lengths:{input_lengths}'
-            f' cur_padding:{cur_padding}'
-            f' dst_batch:{dst_batch_idx}')
+            scenario,
+            f"bs:{[b.batch_size for b in batches]}->{new_bs}"
+            f" reqs:{[len(b) for b in batches]}"
+            f" offsets:{offsets}"
+            f" input_lengths:{input_lengths}"
+            f" cur_padding:{cur_padding}"
+            f" dst_batch:{dst_batch_idx}",
+        )

         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
@@ -466,10 +495,14 @@ class VlmCausalLMBatch(CausalLMBatch):
             batches[i].realign(target_bs, offsets[i], pad_token_id)
             batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
         batches[dst_batch_idx].expand_bs(new_bs)
-        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+        batches[dst_batch_idx].move_data(
+            [batches[i] for i in range(len(batches)) if i != dst_batch_idx]
+        )

         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         parameters = [r.data.parameters for r in flat_requests]
         # append the dummy parameters for dummy requests
@@ -480,7 +513,9 @@ class VlmCausalLMBatch(CausalLMBatch):
         fsm_grammar_states = [0] * batch_size
         for batch in batches:
             for i, req in enumerate(batch.requests):
-                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+                fsm_grammar_states[req.idx] = (
+                    batch.next_token_chooser.fsm_grammar_states[i]
+                )

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
@@ -513,6 +548,7 @@ class VlmCausalLMBatch(CausalLMBatch):
             input_length=input_length,
         )


 class VlmCausalLM(Model):
     def __init__(
         self,
@@ -561,18 +597,14 @@ class VlmCausalLM(Model):
         htorch.core.hpu_set_env()

         if world_size > 1:
-            model = self.get_deepspeed_model(
-                model_class, model_id, dtype, revision
-            )
+            model = self.get_deepspeed_model(model_class, model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)

             # Check support for rope scaling
             model_kwargs = {}
-            config = AutoConfig.from_pretrained(
-                model_id
-            )
+            config = AutoConfig.from_pretrained(model_id)
             if hasattr(config, "rope_scaling"):
                 model_kwargs["rope_scaling"] = self.get_rope_scaling()

@@ -581,23 +613,29 @@ class VlmCausalLM(Model):
                 revision=revision,
                 torch_dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                **model_kwargs
+                **model_kwargs,
             )
             model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)

-        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+        self.enable_hpu_graph = (
+            os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+        )
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
         model = remove_kv_cache_from_output(model)
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph

             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         else:
             if LAZY_MODE == 0:
                 # It is said that "keep_input_mutations" is safe for inference to be done
-                dbg_trace(
-                    "TORCH COMPILE", 'Torch compiling of model')
-                model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
+                dbg_trace("TORCH COMPILE", "Torch compiling of model")
+                model.model = torch.compile(
+                    model.model,
+                    backend="hpu_backend",
+                    options={"keep_input_mutations": True},
+                )

         model = hq_env.setup_quantization(model)

@@ -647,11 +685,15 @@ class VlmCausalLM(Model):
         )

         # Create profiler
-        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
+        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
         record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
         output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
-        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
-        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_warmup_steps = (
+            int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+        )
+        self.profiling_steps = (
+            int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        )
         self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
         if self.profiling_steps > 0:
             self.hb_profiler = HabanaProfile(
@@ -659,14 +701,13 @@ class VlmCausalLM(Model):
                 warmup=self.profiling_warmup_steps,
                 active=self.profiling_steps,
                 output_dir=output_dir,
-                record_shapes=record_shapes
+                record_shapes=record_shapes,
             )
             self.hb_profiler.start()
         else:
             self.hb_profiler = None
         self.step = 0


     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return self.batch_class
@@ -679,20 +720,20 @@ class VlmCausalLM(Model):
         model_class,
         model_id: str,
         dtype: torch.dtype,
-        revision: Optional[str] = None
+        revision: Optional[str] = None,
     ) -> torch.nn.Module:
         import deepspeed
         from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

         world_size, rank, local_rank = initialize_distributed_hpu()
-        model_kwargs = {
-            "revision": revision
-        }
+        model_kwargs = {"revision": revision}

         # Initialize process(es) for DeepSpeed
         deepspeed.init_distributed(dist_backend="hccl")
         logger.info(
-            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
+            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
+                world_size, rank, local_rank
+            )
         )
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
@@ -710,14 +751,18 @@ class VlmCausalLM(Model):
             get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
-                model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
+                model = model_class.from_pretrained(
+                    model_id, torch_dtype=dtype, **model_kwargs
+                )
                 model = model.eval()

         # Initialize the model
         ds_inference_kwargs = {"dtype": dtype}
         ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
         ds_inference_kwargs["enable_cuda_graph"] = False
-        ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config)
+        ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(
+            model.language_model.config
+        )

         if load_to_meta:
             # model loaded to meta is managed differently
@@ -734,12 +779,12 @@ class VlmCausalLM(Model):
             return None

         rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
-        return {
-            'type': rope_scaling, 'factor': float(rope_factor)
-        }
+        return {"type": rope_scaling, "factor": float(rope_factor)}

     def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def decode_token(
         self,
@@ -748,7 +793,9 @@ class VlmCausalLM(Model):
         read_offset: int = 0,
     ) -> Tuple[str, int, int]:
         if is_tokenizer_transparent(self.tokenizer):
-            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
+            new_text = self.tokenizer.decode(
+                all_input_ids[read_offset:], skip_special_tokens=False
+            )
             return new_text, read_offset, len(all_input_ids)
         else:
             return super().decode_token(all_input_ids, prefix_offset, read_offset)
@@ -776,7 +823,7 @@ class VlmCausalLM(Model):

         hpu_kwargs = {}
         # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
-        if self.model.config.model_type == "llama" :
+        if self.model.config.model_type == "llama":
             hpu_kwargs["lazy_mode"] = LAZY_MODE == 1

         if self.has_position_ids:
@@ -814,19 +861,27 @@ class VlmCausalLM(Model):
                 token_idx_scalar = batch.attention_mask.shape[-1] - 1
                 token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx_scalar = (
+                    batch.attention_mask.shape[-1] - batch.right_padding
+                )
                 token_idx = torch.tensor(token_idx_scalar).to(self.device)

             # Select next token
             input_length = batch.input_length
             if logits.shape[-2] > 1:
-                next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
-                    batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
+                next_token_ids, next_token_logprobs, logprobs, _, _ = (
+                    batch.next_token_chooser(
+                        batch.input_ids,
+                        logits[:, input_length - 1 : input_length, :].squeeze(-2),
+                        self.speculate,
+                    )
                 )
             else:
-                next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
+                next_token_ids, next_token_logprobs, logprobs, _, _ = (
+                    batch.next_token_chooser(
                         batch.input_ids, logits.squeeze(-2), self.speculate
                     )
+                )

             # Speculation is not active for causal
             accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
             batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@@ -836,23 +891,29 @@ class VlmCausalLM(Model):
                 accepted_ids,
             )

-            prev_batches.append({
-                'next_token_ids': next_token_ids,
-                'next_token_logprobs': next_token_logprobs,
-            })
+            prev_batches.append(
+                {
+                    "next_token_ids": next_token_ids,
+                    "next_token_logprobs": next_token_logprobs,
+                }
+            )

             for req_idx, req in enumerate(batch.requests):
-                requests_to_generate.append({
-                    'req': req,
-                    'prev_req_idx': req.idx,
-                    'batch_id': batch_id,
-                    'seed': batch.next_token_chooser.seeds[req_idx],
-                    'do_sample': batch.next_token_chooser.do_sample[req_idx],
-                    'top_n_tokens': batch.top_n_tokens[req_idx],
-                    'top_token_ids': batch_top_token_ids[req_idx],
-                    'top_token_logprobs': batch_top_token_logprobs[req_idx],
-                    'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
-                })
+                requests_to_generate.append(
+                    {
+                        "req": req,
+                        "prev_req_idx": req.idx,
+                        "batch_id": batch_id,
+                        "seed": batch.next_token_chooser.seeds[req_idx],
+                        "do_sample": batch.next_token_chooser.do_sample[req_idx],
+                        "top_n_tokens": batch.top_n_tokens[req_idx],
+                        "top_token_ids": batch_top_token_ids[req_idx],
+                        "top_token_logprobs": batch_top_token_logprobs[req_idx],
+                        "grammar_state": batch.next_token_chooser.fsm_grammar_states[
+                            req.idx
+                        ],
+                    }
+                )

             htorch.core.mark_step()

@@ -867,7 +928,9 @@ class VlmCausalLM(Model):

             # Update position_ids
             if prefill:
-                batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+                batch.position_ids = (
+                    torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+                )
             else:
                 batch.position_ids += 1
             # Update past key values
@@ -878,7 +941,9 @@ class VlmCausalLM(Model):

         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
+            batch = self.batch_type.concatenate(
+                batches, self.tokenizer.pad_token_id, is_warmup
+            )
         else:
             batch = batches[0]

@@ -886,15 +951,24 @@ class VlmCausalLM(Model):

         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
+            batch = batch.__class__.recombine(
+                [batch], self.tokenizer.pad_token_id, is_warmup
+            )

-        scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
+        scenario = "PREFILL" if prefill else "GENERATE"
+        if (
+            self.enable_hpu_graph
+            and self.limit_hpu_graph
+            and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
+            != self.prev_bs
+        ):
             self.model.clear_cache()
             self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
         dbg_trace(
-            scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
-        #assert batch.right_padding > 0, 'No more room for next token!'
+            scenario,
+            f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
+        )
+        # assert batch.right_padding > 0, 'No more room for next token!'

         # Execute batch
         if prefill:
@@ -908,21 +982,27 @@ class VlmCausalLM(Model):
                 batch.past_key_values,
                 batch.pixel_values,
                 batch.image_sizes,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
+                bypass_hpu_graph=(
+                    prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                ),
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
             # Don't schedule next forward if max_new_tokens for all requests equals 1
             # - we've already generated the first and only needed token in the prefill phase
             pass
         else:
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            token_idx = torch.tensor(
+                batch.attention_mask.shape[-1] - batch.right_padding
+            ).to(self.device)
             batch.logits = self.forward(
                 batch.input_ids,
                 batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
+                bypass_hpu_graph=(
+                    prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                ),
            )

         htorch.core.mark_step()
@@ -932,40 +1012,45 @@ class VlmCausalLM(Model):
         # Stage 3. Finish and return previous generations
         stopped = len(requests_to_generate) > 0
         for prev_batch in prev_batches:
-            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
-            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
+            prev_batch["next_token_logprobs"] = prev_batch[
+                "next_token_logprobs"
+            ].tolist()
+            prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
             htorch.core.mark_step()

         for req_data in requests_to_generate:
-            req = req_data['req']
-            i = req_data['prev_req_idx']
-            prev_batch_id = req_data['batch_id']
+            req = req_data["req"]
+            i = req_data["prev_req_idx"]
+            prev_batch_id = req_data["batch_id"]
             assert len(prev_batches) > prev_batch_id
-            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
-            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
+            next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
+            next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
|
||||||
|
|
||||||
request = req.data
|
request = req.data
|
||||||
input_length = req.input_length
|
input_length = req.input_length
|
||||||
prefix_offset = req.prefix_offset
|
prefix_offset = req.prefix_offset
|
||||||
read_offset = req.read_offset
|
read_offset = req.read_offset
|
||||||
do_sample = req_data['do_sample']
|
do_sample = req_data["do_sample"]
|
||||||
seed = req_data['seed']
|
seed = req_data["seed"]
|
||||||
stopping_criteria = req.stopping_criteria
|
stopping_criteria = req.stopping_criteria
|
||||||
all_input_ids = req.all_input_ids
|
all_input_ids = req.all_input_ids
|
||||||
next_token_id = next_token_ids_cpu[i]
|
next_token_id = next_token_ids_cpu[i]
|
||||||
next_token_logprob = next_token_logprobs[i]
|
next_token_logprob = next_token_logprobs[i]
|
||||||
top_n_tokens = req_data['top_n_tokens']
|
top_n_tokens = req_data["top_n_tokens"]
|
||||||
top_token_ids = req_data['top_token_ids']
|
top_token_ids = req_data["top_token_ids"]
|
||||||
top_token_logprobs = req_data['top_token_logprobs']
|
top_token_logprobs = req_data["top_token_logprobs"]
|
||||||
grammar_state = req_data['grammar_state']
|
grammar_state = req_data["grammar_state"]
|
||||||
|
|
||||||
# Append next token to all tokens
|
# Append next token to all tokens
|
||||||
all_input_ids[input_length] = next_token_id
|
all_input_ids[input_length] = next_token_id
|
||||||
new_input_length = input_length + 1
|
new_input_length = input_length + 1
|
||||||
|
|
||||||
# Generated token
|
# Generated token
|
||||||
if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
|
if (
|
||||||
next_token_text = ''
|
is_tokenizer_transparent(self.tokenizer)
|
||||||
|
and len(stopping_criteria.stop_sequence_criterias) == 0
|
||||||
|
):
|
||||||
|
next_token_text = ""
|
||||||
else:
|
else:
|
||||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
|
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
|
||||||
@ -989,7 +1074,11 @@ class VlmCausalLM(Model):
|
|||||||
output_text = None
|
output_text = None
|
||||||
else:
|
else:
|
||||||
output_text = self.decode(
|
output_text = self.decode(
|
||||||
all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
|
all_input_ids[
|
||||||
|
new_input_length
|
||||||
|
- stopping_criteria.current_tokens : new_input_length,
|
||||||
|
0,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
generated_text = GeneratedText(
|
generated_text = GeneratedText(
|
||||||
output_text,
|
output_text,
|
||||||
@ -1004,7 +1093,7 @@ class VlmCausalLM(Model):
|
|||||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
prefill_logprobs = [float("nan")] + next_token_logprobs
|
prefill_logprobs = [float("nan")] + next_token_logprobs
|
||||||
prefill_token_ids = all_input_ids[0: new_input_length - 1]
|
prefill_token_ids = all_input_ids[0 : new_input_length - 1]
|
||||||
prefill_texts = self.tokenizer.batch_decode(
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
prefill_token_ids,
|
prefill_token_ids,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
@ -1073,7 +1162,12 @@ class VlmCausalLM(Model):
|
|||||||
htorch.core.mark_step()
|
htorch.core.mark_step()
|
||||||
self.step = self.step + 1
|
self.step = self.step + 1
|
||||||
if self.hb_profiler is not None:
|
if self.hb_profiler is not None:
|
||||||
if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
|
if (
|
||||||
|
self.step
|
||||||
|
> self.profiling_wait_steps
|
||||||
|
+ self.profiling_warmup_steps
|
||||||
|
+ self.profiling_steps
|
||||||
|
):
|
||||||
self.hb_profiler.stop()
|
self.hb_profiler.stop()
|
||||||
else:
|
else:
|
||||||
self.hb_profiler.step()
|
self.hb_profiler.step()
|
||||||
@ -1090,7 +1184,7 @@ class VlmCausalLM(Model):
|
|||||||
self.model.config,
|
self.model.config,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
self.device,
|
self.device,
|
||||||
is_warmup
|
is_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
|
def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
|
||||||
@ -1124,7 +1218,7 @@ class VlmCausalLM(Model):
|
|||||||
while batch_size <= max_prefill_batch_size:
|
while batch_size <= max_prefill_batch_size:
|
||||||
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
|
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
|
||||||
batch_size = batch_size * 2
|
batch_size = batch_size * 2
|
||||||
if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
|
if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size:
|
||||||
PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
|
PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
|
||||||
|
|
||||||
seq_len = BASE_IMAGE_TOKENS
|
seq_len = BASE_IMAGE_TOKENS
|
||||||
@ -1132,19 +1226,21 @@ class VlmCausalLM(Model):
|
|||||||
i = 0
|
i = 0
|
||||||
while seq_len <= max_input_length:
|
while seq_len <= max_input_length:
|
||||||
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
|
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
|
||||||
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
|
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
|
||||||
i += 1
|
i += 1
|
||||||
if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
|
if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
|
||||||
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
|
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
|
||||||
|
|
||||||
#Prefill and decode warmup
|
# Prefill and decode warmup
|
||||||
DECODE_WARMUP_BATCH_SIZE_LIST = []
|
DECODE_WARMUP_BATCH_SIZE_LIST = []
|
||||||
prefill_batch = None
|
prefill_batch = None
|
||||||
decode_batch = None
|
decode_batch = None
|
||||||
try:
|
try:
|
||||||
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
|
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
|
||||||
for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
|
for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
|
||||||
batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
|
batch = self.generate_warmup_batch(
|
||||||
|
request, seq_len, batch_size, is_warmup
|
||||||
|
)
|
||||||
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
||||||
_, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
|
_, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
|
||||||
|
|
||||||
@ -1172,10 +1268,18 @@ class VlmCausalLM(Model):
|
|||||||
batch_size = max_prefill_batch_size * 2
|
batch_size = max_prefill_batch_size * 2
|
||||||
# Decode warmup with bigger batch_size
|
# Decode warmup with bigger batch_size
|
||||||
try:
|
try:
|
||||||
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
|
if (
|
||||||
|
DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size
|
||||||
|
and batch_size <= max_decode_batch_size
|
||||||
|
):
|
||||||
batches = []
|
batches = []
|
||||||
for i in range(int(batch_size/max_prefill_batch_size)) :
|
for i in range(int(batch_size / max_prefill_batch_size)):
|
||||||
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
|
batch = self.generate_warmup_batch(
|
||||||
|
request,
|
||||||
|
PREFILL_WARMUP_SEQLEN_LIST[0],
|
||||||
|
DECODE_WARMUP_BATCH_SIZE_LIST[-1],
|
||||||
|
is_warmup,
|
||||||
|
)
|
||||||
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
||||||
batches.append(prefill_batch)
|
batches.append(prefill_batch)
|
||||||
while batch_size <= max_decode_batch_size:
|
while batch_size <= max_decode_batch_size:
|
||||||
@ -1184,17 +1288,24 @@ class VlmCausalLM(Model):
|
|||||||
batch_size = batch_size * 2
|
batch_size = batch_size * 2
|
||||||
batches.clear()
|
batches.clear()
|
||||||
|
|
||||||
for i in range(int(batch_size/max_prefill_batch_size)) :
|
for i in range(int(batch_size / max_prefill_batch_size)):
|
||||||
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
|
batch = self.generate_warmup_batch(
|
||||||
|
request,
|
||||||
|
PREFILL_WARMUP_SEQLEN_LIST[0],
|
||||||
|
DECODE_WARMUP_BATCH_SIZE_LIST[-1],
|
||||||
|
is_warmup,
|
||||||
|
)
|
||||||
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
||||||
batches.append(prefill_batch)
|
batches.append(prefill_batch)
|
||||||
|
|
||||||
batches.clear()
|
batches.clear()
|
||||||
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
|
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
|
||||||
max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
|
max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2
|
||||||
batch_size = max_decode_batch_size
|
batch_size = max_decode_batch_size
|
||||||
for i in range(int(max_decode_batch_size / 2)) :
|
for i in range(int(max_decode_batch_size / 2)):
|
||||||
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
|
batch = self.generate_warmup_batch(
|
||||||
|
request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup
|
||||||
|
)
|
||||||
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
|
||||||
batches.append(prefill_batch)
|
batches.append(prefill_batch)
|
||||||
_, decode_batch, _ = self.generate_token(batches, is_warmup)
|
_, decode_batch, _ = self.generate_token(batches, is_warmup)
|
||||||
|
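
The warmup hunks above build two bucket ladders: prefill batch sizes grow as powers of two up to max_prefill_batch_size, and prefill sequence lengths start at BASE_IMAGE_TOKENS and grow by geometrically increasing multiples of PAD_SEQUENCE_TO_MULTIPLE_OF. A standalone sketch of that ladder logic, using illustrative values rather than the server's real limits:

# Standalone sketch (not part of the commit): how the warmup lists above grow.
# The values below are illustrative assumptions, not the server's defaults.
BASE_IMAGE_TOKENS = 2048
PAD_SEQUENCE_TO_MULTIPLE_OF = 256
max_prefill_batch_size = 4
max_input_length = 4096

prefill_batch_sizes = []
batch_size = 1
while batch_size <= max_prefill_batch_size:
    prefill_batch_sizes.append(batch_size)
    batch_size *= 2
if prefill_batch_sizes[-1] < max_prefill_batch_size:
    prefill_batch_sizes.append(max_prefill_batch_size)

prefill_seqlens = []
seq_len = BASE_IMAGE_TOKENS
i = 0
while seq_len <= max_input_length:
    prefill_seqlens.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
    i += 1
if prefill_seqlens[-1] < max_input_length:
    prefill_seqlens.append(max_input_length)

print(prefill_batch_sizes)  # [1, 2, 4]
print(prefill_seqlens)      # [2048, 2304, 2816, 3840, 4096]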
@ -38,7 +38,10 @@ try:

except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.
    VLM_BATCH_TYPES = set()
from text_generation_server.utils.version import (
    is_driver_compatible,
    MIN_TGI_GAUDI_SYNAPSE_VERSION,
)


class SignalHandler:

@ -72,7 +75,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

        # Force inference mode for the lifetime of TextGenerationService
        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

@ -101,7 +103,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
            self.model.warmup(request)
        )

        # W/A for the skip tokenizer path
        # We need to call make_tokenizer_optional after the warmup,

@ -194,7 +198,9 @@ def serve(

    trust_remote_code: bool = False,
):
    if not is_driver_compatible():
        logger.warning(
            f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures"
        )

    unix_socket_template = "unix://{}-{}"
    adapter_to_index = {}

@ -204,14 +210,19 @@ def serve(

        rank = int(os.environ["RANK"])
        logger.info("Server:server_inner: rank ={}".format(rank))
        server_urls = [
            unix_socket_template.format(uds_path, rank)
            for rank in range(int(os.environ["WORLD_SIZE"]))
        ]
        local_url = server_urls[int(os.environ["RANK"])]
    else:
        local_url = unix_socket_template.format(uds_path, 0)
        server_urls = [local_url]

    logger.info(
        "Server:server_inner: data type = {}, local_url = {}".format(
            dtype, local_url
        )
    )
    if dtype == "bfloat16" or None:
        data_type = torch.bfloat16
    else:
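
For the sharded path above, each rank gets its own unix socket derived from the shared uds_path. A quick sketch of the naming scheme, using placeholder values in place of the launcher's environment:

import os

# Illustrative values; in the server they come from the launcher's environment.
os.environ.setdefault("RANK", "1")
os.environ.setdefault("WORLD_SIZE", "4")

uds_path = "/tmp/text-generation-server"  # hypothetical path for the example
unix_socket_template = "unix://{}-{}"

server_urls = [
    unix_socket_template.format(uds_path, rank)
    for rank in range(int(os.environ["WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["RANK"])]
print(local_url)  # unix:///tmp/text-generation-server-1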
@ -10,7 +10,13 @@ def main(args):

    logger.info("TGIService: starting tgi service .... ")
    logger.info(
        "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
            args.model_id,
            args.revision,
            args.sharded,
            args.speculate,
            args.dtype,
            args.trust_remote_code,
            args.uds_path,
        )
    )
    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

@ -24,7 +30,7 @@ def main(args):

        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        uds_path=args.uds_path,
        max_input_tokens=args.max_input_tokens,
    )
@ -8,14 +8,14 @@ from optimum.habana.utils import to_gb_rounded

import habana_frameworks.torch as htorch

START_TS = None
DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
if "GRAPH_VISUALIZATION" in os.environ:
    for f in glob.glob(".graph_dumps/*"):
        os.remove(f)


def count_hpu_graphs():
    return len(glob.glob(".graph_dumps/*PreGraph*"))


def dbg_trace(tag, txt):

@ -25,7 +25,11 @@ def dbg_trace(tag, txt):

        START_TS = time.perf_counter()
    time_offset = time.perf_counter() - START_TS
    mem_stats = htorch.hpu.memory.memory_stats()
    mem_used = to_gb_rounded(mem_stats["InUse"])
    max_mem_used = to_gb_rounded(mem_stats["MaxInUse"])
    print(
        f"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB "
        f"mmu:{max_mem_used:.1f}GB | {tag} | {txt}",
        flush=True,
        file=open(DBG_TRACE_FILENAME, "a"),
    )
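
dbg_trace above is gated on the DBG_TRACE_FILENAME environment variable and appends one line per call with a relative timestamp, HPU graph count, and memory counters. A minimal sketch of the same pattern with a stubbed stats source, so it runs without an HPU (the helper names here are illustrative):

import os
import time

START_TS = None
TRACE_FILE = os.environ.get("DBG_TRACE_FILENAME")  # same gating variable as above


def fake_mem_stats():
    # Stand-in for htorch.hpu.memory.memory_stats(); the numbers are made up.
    return {"InUse": 3.2, "MaxInUse": 4.1}


def dbg_trace_sketch(tag, txt):
    global START_TS
    if TRACE_FILE is None:
        return
    if START_TS is None:
        START_TS = time.perf_counter()
    offset = time.perf_counter() - START_TS
    stats = fake_mem_stats()
    with open(TRACE_FILE, "a") as f:
        print(
            f"ts:{offset:.3f}s mu:{stats['InUse']:.1f}GB "
            f"mmu:{stats['MaxInUse']:.1f}GB | {tag} | {txt}",
            file=f,
            flush=True,
        )


dbg_trace_sketch("PREFILL", "bs:4 seq_len:2048")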
@ -64,7 +64,9 @@ def initialize_torch_distributed():

        backend = "hccl"
        n_hpus = torch.hpu.device_count()
        if world_size > n_hpus:
            raise ValueError(
                f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})."
            )
    else:
        try:
            import oneccl_bindings_for_pytorch  # noqa: F401
@ -63,7 +63,9 @@ class StaticWarper:

        self.static_warped_scores.copy_(local_scores)
        # Compute logprobs
        self.static_next_logprob.copy_(
            torch.log_softmax(self.static_warped_scores, -1)
        )

        self.static_scores.copy_(scores)
        self.hpu_graph.replay()

@ -78,7 +80,9 @@ def static_warper(

    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    return StaticWarper(
        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
    )


class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):

@ -95,13 +99,17 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        scores.scatter_(1, input_ids, score)
        return scores

@ -163,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):

            batch_size, vocab_size, dtype=scores.dtype, device=scores.device
        )
        token_freq.scatter_add_(
            1,
            input_ids,
            torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device),
        )
        token_freq /= input_size

@ -190,9 +200,13 @@ class HeterogeneousTemperatureLogitsWarper:

            The value used to module the logits distribution.
    """

    def __init__(
        self, temperature: List[float], dtype: torch.dtype, device: torch.device
    ):
        self.temperature = temperature
        self.temperature_tensor = torch.tensor(
            temperature, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.div_(self.temperature_tensor)

@ -231,7 +245,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):

        min_tokens_to_keep: int = 1,
    ):
        self.top_p = top_p
        self.top_p_opposite = 1 - torch.tensor(
            top_p, dtype=dtype, device=device
        ).unsqueeze(1)
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

@ -248,7 +264,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):

        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return warped_scores

@ -296,7 +314,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):

        disabled = [x == 0 for x in top_k]

        if any(disabled):
            self.top_k_disabled_mask = torch.tensor(
                disabled, dtype=torch.bool, device=device
            ).view(-1, 1)
        else:
            self.top_k_disabled_mask = None

@ -332,7 +352,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):

            self.max_top_k = max(self.top_k)

            if self.top_k_disabled_mask is not None:
                self.top_k_disabled_mask = (
                    self.top_k_disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None

@ -398,11 +420,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):

        if self.disabled_mask is not None:
            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)

        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )

        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

@ -416,7 +442,9 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):

            self.mass_tensor = self.mass_tensor[indices]

            if self.disabled_mask is not None:
                self.disabled_mask = (
                    self.disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None
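
The top-p warper above keeps, per row, the smallest set of tokens whose cumulative probability exceeds top_p, then scatters the removal mask back to vocabulary order. A compact, self-contained version of that filtering step (the function and values are illustrative and independent of the classes above):

import torch


def top_p_filter(scores: torch.Tensor, top_p: float, filter_value: float = -float("inf")) -> torch.Tensor:
    # Sort ascending so the lowest-probability tokens come first.
    sorted_scores, sorted_indices = torch.sort(scores, descending=False, dim=-1)
    cumulative_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
    # In ascending order, everything whose cumulative mass stays below 1 - top_p
    # lies outside the nucleus and is removed.
    sorted_to_remove = cumulative_probs <= (1 - top_p)
    sorted_to_remove[..., -1:] = False  # always keep the most likely token
    # Scatter the mask back from sorted order to vocabulary order.
    to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
    return scores.masked_fill(to_remove, filter_value)


logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_p_filter(logits, top_p=0.9))  # the least likely logit becomes -inf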
@ -254,7 +254,7 @@ class HeterogeneousNextTokenChooser:

        tokenizer: PreTrainedTokenizerBase,
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states: List[int],
        quantization_enabled: bool,
    ):
        warpers = []

@ -447,7 +447,9 @@ class HeterogeneousNextTokenChooser:

        next_id = next_id.item()
        self.fsm_grammar_states[grammar_state_index] = (
            self.grammar_processor.advance_at_index(
                next_id,
                past_state,
                grammar_state_index,
            )
        )
        return self

@ -544,9 +546,7 @@ def pad_next_token_chooser_parameters(

            grammar="",
            grammar_type=0,
        )
        parameters.extend([empty_parameters] * (expected_size - len(parameters)))
    return parameters

@ -701,24 +701,47 @@ def make_tokenizer_optional(tokenizer):

            padding,
            return_token_type_ids,
            truncation,
            max_length,
        ):
            assert (
                return_tensors == "pt"
            ), "inccorrect input arguments when calling TransparentTokenizer"
            assert (
                padding == "max_length" or padding == "longest"
            ), "inccorrect input arguments when calling TransparentTokenizer"
            assert (
                not return_token_type_ids
            ), "inccorrect input arguments when calling TransparentTokenizer"
            assert (
                truncation
            ), "inccorrect input arguments when calling TransparentTokenizer"

            def str_token_to_int(i):
                if i == "?":
                    return tokenizer.pad_token_id
                else:
                    return int(i)

            all_tokens = [
                [str_token_to_int(i.strip()) for i in inner_text.split(",")]
                for inner_text in text
            ]
            if padding == "longest":
                max_length = max(len(tokens) for tokens in all_tokens)
            return {
                "input_ids": torch.tensor(
                    [
                        [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
                        for tokens in all_tokens
                    ]
                ),
                "attention_mask": torch.tensor(
                    [
                        [0] * (max_length - len(tokens)) + [1] * len(tokens)
                        for tokens in all_tokens
                    ]
                ),
            }

        def decode(
            self,

@ -728,9 +751,10 @@ def make_tokenizer_optional(tokenizer):

            **kwargs,
        ) -> str:
            # I don't think this method is used anywhere and should be removed when doing refactoring
            return ",".join(str(i) for i in to_py_obj(token_ids))  # noqa: F821

    import os

    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
        tokenizer.__class__ = _
        tokenizer.is_transparent = True
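
With SKIP_TOKENIZER_IN_TGI=true, the transparent tokenizer above treats the request text as a comma-separated list of token ids, with "?" standing for padding, and builds input_ids and attention_mask directly. A small standalone sketch of that conversion, assuming a pad id of 0 purely for illustration:

import torch

PAD_TOKEN_ID = 0  # illustrative; the real value comes from the tokenizer


def encode_pretokenized(texts, padding="longest", max_length=None):
    def to_id(tok):
        return PAD_TOKEN_ID if tok == "?" else int(tok)

    all_tokens = [[to_id(t.strip()) for t in text.split(",")] for text in texts]
    if padding == "longest":
        max_length = max(len(tokens) for tokens in all_tokens)
    input_ids = [
        [PAD_TOKEN_ID] * (max_length - len(tokens)) + tokens for tokens in all_tokens
    ]
    attention_mask = [
        [0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens
    ]
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
    }


print(encode_pretokenized(["12, 55, 9", "3, 7"]))
# input_ids -> [[12, 55, 9], [0, 3, 7]], attention_mask -> [[1, 1, 1], [0, 1, 1]]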
@ -1,7 +1,7 @@

from optimum.habana.utils import get_driver_version
from packaging.version import Version

MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")


def is_driver_compatible():