diff --git a/.gitignore b/.gitignore index 248001fda..2a7714994 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ server/fbgemmm .venv/ # Gaudi auto-generated files -hl-smi_log*.txt \ No newline at end of file +hl-smi_log*.txt diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 0af462487..ba5d6ec3e 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -112,4 +112,4 @@ COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] -CMD ["--json-output"] \ No newline at end of file +CMD ["--json-output"] diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile index 507dfedb8..8162972d2 100644 --- a/backends/gaudi/Makefile +++ b/backends/gaudi/Makefile @@ -47,4 +47,4 @@ local-dev-install: install-dependencies bash -c 'source "$$HOME/.cargo/env" && \ make install-server && \ make install-router && \ - make install-launcher' \ No newline at end of file + make install-launcher' diff --git a/backends/gaudi/README.md b/backends/gaudi/README.md index 695fa41ab..c2bff110c 100644 --- a/backends/gaudi/README.md +++ b/backends/gaudi/README.md @@ -16,7 +16,7 @@ make -C backends/gaudi image Option 2: From the Gaudi backend directory: ```bash cd backends/gaudi -make image +make image ``` You can now run the server with the following command: @@ -50,7 +50,7 @@ make -C backends/gaudi local-dev-install Add rust to path: ```bash -. "$HOME/.cargo/env" +. "$HOME/.cargo/env" ``` Option 1: Run the server (sharded model): diff --git a/backends/gaudi/server/dill-0.3.7-patch.sh b/backends/gaudi/server/dill-0.3.7-patch.sh index ad8c8be58..5efd6c54b 100644 --- a/backends/gaudi/server/dill-0.3.7-patch.sh +++ b/backends/gaudi/server/dill-0.3.7-patch.sh @@ -39,7 +39,7 @@ index d0cf543..f6eb662 100644 - self._main = _main_module + self._main = _main_module.module self._ignore = settings['ignore'] if _ignore is None else _ignore - + def load(self): #NOTE: if settings change, need to update attributes obj = StockUnpickler.load(self) - if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): diff --git a/backends/gaudi/server/dill-0.3.8-patch.sh b/backends/gaudi/server/dill-0.3.8-patch.sh index da263960f..414790e7b 100644 --- a/backends/gaudi/server/dill-0.3.8-patch.sh +++ b/backends/gaudi/server/dill-0.3.8-patch.sh @@ -39,7 +39,7 @@ index d42432f..1d251e6 100644 - self._main = _main_module + self._main = _main_module.module self._ignore = settings['ignore'] if _ignore is None else _ignore - + def load(self): #NOTE: if settings change, need to update attributes obj = StockUnpickler.load(self) - if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): @@ -88,4 +88,4 @@ EOF git apply dill-0.3.8.patch python -m pip install . 
popd -rm -fr dill \ No newline at end of file +rm -fr dill diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index f9b4caa98..700f763e9 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -113,12 +113,14 @@ def serve( logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) if sharded: - tgi_file = Path(__file__).resolve().parent / "tgi_service.py" + tgi_file = Path(__file__).resolve().parent / "tgi_service.py" num_shard = int(os.getenv("WORLD_SIZE", "1")) logger.info("CLI SHARDED = {}".format(num_shard)) import subprocess - cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" + cmd = ( + f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" + ) cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}" @@ -130,6 +132,7 @@ def serve( with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: do_terminate = False current_handler = signal.getsignal(signal.SIGTERM) + def terminate_handler(sig, frame): nonlocal do_terminate do_terminate = True diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py index e942fdcf0..b03b7e266 100644 --- a/backends/gaudi/server/text_generation_server/habana_quantization_env.py +++ b/backends/gaudi/server/text_generation_server/habana_quantization_env.py @@ -17,7 +17,9 @@ if is_quantization_enabled: def patch_scoped_linear_all_reduce(model): from deepspeed.module_inject.layers import LinearAllreduce - from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce + from optimum.habana.transformers.models.modeling_all_models import ( + ScopedLinearAllReduce, + ) for name, module in model.named_children(): if type(module) is LinearAllreduce: @@ -36,7 +38,13 @@ def setup_quantization(model): def prepare_model_for_quantization(model): if is_quantization_enabled: - if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]: + if model.config.model_type in [ + "llama", + "falcon", + "qwen2", + "starcoder2", + "gemma", + ]: patch_scoped_linear_all_reduce(model) from neural_compressor.torch.quantization import FP8Config, convert diff --git a/backends/gaudi/server/text_generation_server/interceptor.py b/backends/gaudi/server/text_generation_server/interceptor.py index 05339282b..47f33cd0b 100644 --- a/backends/gaudi/server/text_generation_server/interceptor.py +++ b/backends/gaudi/server/text_generation_server/interceptor.py @@ -24,7 +24,7 @@ class ExceptionInterceptor(AsyncServerInterceptor): response = method(request_or_iterator, context) return await response except Exception as err: - trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else '' + trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else "" method_name = method_name.split("/")[-1] logger.exception(f"Method {method_name} encountered an error.") @@ -36,7 +36,8 @@ class ExceptionInterceptor(AsyncServerInterceptor): torch.cuda.empty_cache() from .utils.debug import dbg_trace - dbg_trace('EXCEPTION', traceback.format_exc()) + + dbg_trace("EXCEPTION", traceback.format_exc()) await context.abort_with_status( rpc_status.to_status( 
status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace) diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 502e4d8c5..651b71ecf 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -8,6 +8,7 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path from typing import List, Dict + # Needed to properly setup habana_frameworks from text_generation_server.utils.speculate import get_speculate, set_speculate @@ -16,10 +17,12 @@ from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.starcoder import StarCoder from text_generation_server.models.vlm_causal_lm import VlmCausalLM -#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM + +# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) + # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch # from text_generation_server.models.custom_modeling.mllama import ( # MllamaForConditionalGeneration, diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py index 3844d89f9..21195d6a2 100644 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py @@ -59,7 +59,12 @@ CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8)) PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2)) -MAX_BATCH_SIZE = int(os.environ.get('MAX_BATCH_SIZE')) if os.environ.get('MAX_BATCH_SIZE') is not None else None +MAX_BATCH_SIZE = ( + int(os.environ.get("MAX_BATCH_SIZE")) + if os.environ.get("MAX_BATCH_SIZE") is not None + else None +) + def torch_compile_for_eager(func): if LAZY_MODE == 1: @@ -564,9 +569,9 @@ class CausalLMBatch(Batch): bucket_size = max_input_length left_padding = max_input_length - input_len if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: - assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, ( - "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" - ) + assert ( + PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length + ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) if rounded_seq_len <= max_input_length: bucket_size = rounded_seq_len - 1 @@ -1080,9 +1085,9 @@ class CausalLM(Model): batch.position_ids, token_idx, batch.past_key_values, - bypass_hpu_graph=prefill and self.limit_hpu_graph - if self.enable_hpu_graph - else None, + bypass_hpu_graph=( + prefill and self.limit_hpu_graph if self.enable_hpu_graph else None + ), ) elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): # Don't schedule next forward if max_new_tokens for all requests equals 1 @@ -1099,9 +1104,9 @@ class CausalLM(Model): batch.position_ids, token_idx, batch.past_key_values, - bypass_hpu_graph=prefill and self.limit_hpu_graph - if self.enable_hpu_graph - else None, + bypass_hpu_graph=( + prefill and 
self.limit_hpu_graph if self.enable_hpu_graph else None + ), ) if self.model.config.model_type in ["gpt_bigcode"]: batch.logits, batch.past = logits @@ -1289,13 +1294,17 @@ class CausalLM(Model): return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) - def warmup(self, request: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]: - assert MAX_BATCH_SIZE is not None, "MAX_BATCH_SIZE is not set, it should be set in the launcher" + def warmup( + self, request: generate_pb2.WarmupRequest + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: + assert ( + MAX_BATCH_SIZE is not None + ), "MAX_BATCH_SIZE is not set, it should be set in the launcher" MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}") logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}") MAX_TOTAL_TOKENS = request.max_total_tokens - + batch = self.batch_type.from_pb( request.batch, self.tokenizer, self.dtype, self.device ) @@ -1313,7 +1322,14 @@ class CausalLM(Model): # Warmup prefill batch_size max_input_tokens = request.max_input_tokens - prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)] + prefill_batch_size_list = [ + batch + for batch in range( + PREFILL_BATCH_BUCKET_SIZE, + max_prefill_batch_size, + PREFILL_BATCH_BUCKET_SIZE, + ) + ] prefill_batch_size_list.append(max_prefill_batch_size) prefill_seqlen_list = [ seq @@ -1349,7 +1365,7 @@ class CausalLM(Model): f"Prefill sequence length list:{prefill_seqlen_list}\n" f"Memory stats: {mem_stats} " ) - + max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE) decode_batch_size_list = [ @@ -1398,8 +1414,8 @@ class CausalLM(Model): f"Decode batch size list:{decode_batch_size_list}\n" f"Memory stats: {mem_stats} " ) - - max_input_tokens=max_input_tokens - max_total_tokens=MAX_TOTAL_TOKENS - + + max_input_tokens = max_input_tokens + max_total_tokens = MAX_TOTAL_TOKENS + return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py index 6d6675ad2..d2fbff545 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py @@ -22,9 +22,10 @@ import torch.utils.checkpoint from transformers.models.llava_next.modeling_llava_next import ( unpad_image, ) -from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration +from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration from transformers.image_processing_utils import select_best_resolution + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 
@@ -49,7 +50,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): - + def _merge_input_ids_with_image_features( self, inputs_embeds: torch.Tensor, @@ -89,11 +90,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): ): if token_idx is not None: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -120,175 +129,214 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): return outputs def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, + ): + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 + The only differences are: + - add new args token_idx + - add the process of merging images into inputs_embeds + """ + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) + else: + use_flash_attention = kwargs.get("use_flash_attention", False) + flash_attention_recompute = kwargs.get("flash_attention_recompute", False) + + position_ids = kwargs.get("position_ids", None) + labels = kwargs.get("labels", None) + if ( + past_key_values is None + and pixel_values is not None + and input_ids.shape[1] != 1 + ): + vision_feature_select_strategy = kwargs.get( + "vision_feature_select_strategy", None + ) + vision_feature_layer = kwargs.get("vision_feature_layer", None) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else 
self.config.vision_feature_layer ) - else: - use_flash_attention = kwargs.get("use_flash_attention", False) - flash_attention_recompute = kwargs.get("flash_attention_recompute", False) - - position_ids = kwargs.get("position_ids", None) - labels = kwargs.get("labels", None) - if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: - vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. Merge text and images - batch_size, num_patches, num_channels, height, width = pixel_values.shape - reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) - image_features = self.vision_tower( - reshaped_pixel_values, - output_hidden_states=True, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + batch_size, num_patches, num_channels, height, width = ( + pixel_values.shape + ) + reshaped_pixel_values = pixel_values.view( + batch_size * num_patches, num_channels, height, width + ) + image_features = self.vision_tower( + reshaped_pixel_values, + output_hidden_states=True, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) - selected_image_feature = image_features.hidden_states[vision_feature_layer] + selected_image_feature = image_features.hidden_states[ + vision_feature_layer + ] - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature - image_features = self.multi_modal_projector(selected_image_feature) + image_features = self.multi_modal_projector(selected_image_feature) - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [image.shape[0] for image in pixel_values] - image_features = torch.split(image_features, split_sizes, dim=0) + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // 
self.config.vision_config.patch_size + ) - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] - if height * width != base_image_feature.shape[0]: - raise ValueError("The number of patches is not consistent with the image size.") - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx].tolist(), - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." ) - - image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx].tolist(), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids) - self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + inputs_embeds = self._merge_input_ids_with_image_features( + inputs_embeds, image_features, input_ids + ) + self.image_offset = ( + image_features.shape[1] - 1 + ) # image_token has occupied 1 token position. 
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None: + seq_len = input_ids.shape[1] + pad_len = seq_len - token_idx.item() + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where( + first_layer_past_key_value.float().sum(-2) == 0 ) - return model_inputs + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + 
attention_mask = extended_attention_mask + attention_mask[:, -pad_len:] = 0 + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = ( + torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + ) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "labels": labels, + "use_flash_attention": use_flash_attention, + "flash_attention_recompute": flash_attention_recompute, + } + ) + + return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index a48ea6de5..30a5d3da4 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -58,6 +58,7 @@ def set_model_id(model_id: str): global MODEL_ID MODEL_ID = model_id + # NOTE: eventually we should move this into the router and pass back the # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py index c691f15d2..4fda22713 100644 --- a/backends/gaudi/server/text_generation_server/models/model.py +++ b/backends/gaudi/server/text_generation_server/models/model.py @@ -12,6 +12,7 @@ from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights from text_generation_server.pb import generate_pb2 + BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -93,7 +94,9 @@ class Model(ABC): ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]: raise NotImplementedError - def warmup(self, batch: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]: + def warmup( + self, batch: generate_pb2.WarmupRequest + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: self.generate_token(batch) return None, None, None diff --git a/backends/gaudi/server/text_generation_server/models/starcoder.py b/backends/gaudi/server/text_generation_server/models/starcoder.py index bb13503cc..6c6ca2cf9 100644 --- a/backends/gaudi/server/text_generation_server/models/starcoder.py +++ b/backends/gaudi/server/text_generation_server/models/starcoder.py @@ -13,7 +13,7 @@ class StarCoderCausalLMBatch(CausalLMBatch): def detach_kv_cache(self): past_keys = [] past_values = [] - last_dim = int(self.past_key_values[0].size(dim=-1)/2) + last_dim = int(self.past_key_values[0].size(dim=-1) / 2) for key_value in self.past_key_values: past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) @@ -23,7 +23,9 @@ class StarCoderCausalLMBatch(CausalLMBatch): def attach_kv_cache(self, past_keys, past_values): self.past_key_values = [ - torch.cat((key, value), 
dim=-1) for key, value in zip(past_keys, past_values)] + torch.cat((key, value), dim=-1) + for key, value in zip(past_keys, past_values) + ] class StarCoder(CausalLM): @@ -42,4 +44,4 @@ class StarCoder(CausalLM): @property def batch_type(self) -> Type[CausalLMBatch]: - return StarCoderCausalLMBatch \ No newline at end of file + return StarCoderCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index c5d9eda55..181bc51af 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -66,23 +66,26 @@ IDEFICS2_IMAGE_TOKEN = "" IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") -BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048)) -MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) -MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) +BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048)) +MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) +MAX_BATCH_TOTAL_TOKENS = int(os.environ.get("MAX_BATCH_TOTAL_TOKENS", 131072)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256)) CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) +LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = [] DECODE_WARMUP_BATCH_SIZE_LIST = [] -def round_up(warmup_list:list, num) : + + +def round_up(warmup_list: list, num): i = 0 for i in warmup_list: - if num <= i : + if num <= i: break return i + def split(string) -> List[Dict[str, str]]: parts = [] cursor = 0 @@ -99,6 +102,7 @@ def split(string) -> List[Dict[str, str]]: return parts + def image_text_replacement(processor, image_input, config, image_id: int) -> str: if config.model_type == "idefics2": image_seq_len = 64 @@ -196,14 +200,19 @@ class VlmCausalLMBatch(CausalLMBatch): is_warmup: bool = False, ) -> "VlmCausalLMBatch": - dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') - requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] + dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") + requests = [ + CausalLMRequest.from_pb(idx, req, tokenizer) + for idx, req in enumerate(pb.requests) + ] max_input_length = max(r.data.truncate for r in requests) max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) # TODO: Add support for sparse batches top_n_tokens = [r.top_n_tokens for r in pb.requests] - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) # TODO: by tokenizing all inputs at once we loose information on actual input lengths # this means that we cannot shift inputs to the left after a long input sequence @@ -226,7 +235,7 @@ class VlmCausalLMBatch(CausalLMBatch): bucket_size = max_input_length left_padding = max_input_length - input_len if is_warmup is False: - if input_len < max_input_length : + if input_len < max_input_length: rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) if rounded_seq_len <= max_input_length: bucket_size = rounded_seq_len - 1 @@ -276,10 +285,14 @@ class 
VlmCausalLMBatch(CausalLMBatch): input_length=input_len, ) - @classmethod def batch_tokenized_inputs( - cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup + cls, + requests: Iterable[generate_pb2.Request], + tokenizer, + processor, + config, + is_warmup, ): # Process images first. We need all of them so that the processor # can make the image splits the same size. And we need the final @@ -345,24 +358,24 @@ class VlmCausalLMBatch(CausalLMBatch): ) if missing_inputs > 0 and image_inputs is not None: - dummy_shape = list(image_inputs['pixel_values'].shape) + dummy_shape = list(image_inputs["pixel_values"].shape) dummy_shape[0] = missing_inputs dummy_images = torch.rand(dummy_shape) new_image_inputs = { "pixel_values": torch.cat( - (image_inputs['pixel_values'], dummy_images), dim=0 + (image_inputs["pixel_values"], dummy_images), dim=0 ), } if "pixel_attention_mask" in image_inputs: - dummy_shape = list(image_inputs['pixel_attention_mask'].shape) + dummy_shape = list(image_inputs["pixel_attention_mask"].shape) dummy_shape[0] = missing_inputs dummy_attention = torch.zeros(dummy_shape) new_image_inputs["pixel_attention_mask"] = torch.cat( (image_inputs["pixel_attention_mask"], dummy_attention), dim=0 ) if "image_sizes" in image_inputs: - dummy_shape = list(list(image_inputs['image_sizes'])[0]) - dummy_shape = missing_inputs*[dummy_shape] + dummy_shape = list(list(image_inputs["image_sizes"])[0]) + dummy_shape = missing_inputs * [dummy_shape] dummy_sizes = torch.IntTensor(dummy_shape) new_image_inputs["image_sizes"] = torch.cat( (image_inputs["image_sizes"], dummy_sizes), dim=0 @@ -406,19 +419,27 @@ class VlmCausalLMBatch(CausalLMBatch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch": + def concatenate( + cls, + batches: List["CausalLMBatch"], + pad_token_id: int = 0, + is_warmup: bool = False, + ) -> "CausalLMBatch": return cls.recombine(batches, pad_token_id, is_warmup) - - @classmethod - def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch": + def recombine( + cls, + batches: List["VlmCausalLMBatch"], + pad_token_id: int, + is_warmup: bool = False, + ) -> "VlmCausalLMBatch": if not all(b.past_key_values is not None for b in batches): raise ValueError("KV cache not allocated! Cannot recombine before prefill!") total_requests = sum(len(b) for b in batches) new_bs = total_requests - if is_warmup is False : + if is_warmup is False: new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) batch_id = batches[0].batch_id device = batches[0].input_ids.device @@ -431,31 +452,39 @@ class VlmCausalLMBatch(CausalLMBatch): # For prefill there is a space allocated only for first token # Need to add padding to the max total tokens before first decode - moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] + moves_needed = [ + total_requests - len(b) if b.batch_size == new_bs else total_requests + for b in batches + ] dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = (batches[dst_batch_idx].batch_size < new_bs) + reshape = batches[dst_batch_idx].batch_size < new_bs # TODO: Add support for changing max seq len, i.e. 
due to output length bucketing # FIXME: max_seq_len for non optimized code if len(batches) > 1: - scenario = 'CONCAT' + scenario = "CONCAT" elif reshape: - scenario = 'RESHAPE' + scenario = "RESHAPE" elif cur_padding[dst_batch_idx] <= 0: - scenario = 'SHIFT' - offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] + scenario = "SHIFT" + offsets = [ + biggest_single_chunk(b.max_input_length - max_input_length) + for b in batches + ] max_input_length = max_input_length + offsets[dst_batch_idx] else: # Nothing to do return batches[0] dbg_trace( - scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' - f' reqs:{[len(b) for b in batches]}' - f' offsets:{offsets}' - f' input_lengths:{input_lengths}' - f' cur_padding:{cur_padding}' - f' dst_batch:{dst_batch_idx}') + scenario, + f"bs:{[b.batch_size for b in batches]}->{new_bs}" + f" reqs:{[len(b) for b in batches]}" + f" offsets:{offsets}" + f" input_lengths:{input_lengths}" + f" cur_padding:{cur_padding}" + f" dst_batch:{dst_batch_idx}", + ) grouped_requests = [[req for req in batch.requests] for batch in batches] flat_requests = list(itertools.chain(*grouped_requests)) @@ -466,10 +495,14 @@ class VlmCausalLMBatch(CausalLMBatch): batches[i].realign(target_bs, offsets[i], pad_token_id) batches[i].split_kv_cache_if_needed(i == dst_batch_idx) batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) + batches[dst_batch_idx].move_data( + [batches[i] for i in range(len(batches)) if i != dst_batch_idx] + ) top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) parameters = [r.data.parameters for r in flat_requests] # append the dummy parameters for dummy requests @@ -480,7 +513,9 @@ class VlmCausalLMBatch(CausalLMBatch): fsm_grammar_states = [0] * batch_size for batch in batches: for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] + fsm_grammar_states[req.idx] = ( + batch.next_token_chooser.fsm_grammar_states[i] + ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( parameters, @@ -513,6 +548,7 @@ class VlmCausalLMBatch(CausalLMBatch): input_length=input_length, ) + class VlmCausalLM(Model): def __init__( self, @@ -561,18 +597,14 @@ class VlmCausalLM(Model): htorch.core.hpu_set_env() if world_size > 1: - model = self.get_deepspeed_model( - model_class, model_id, dtype, revision - ) + model = self.get_deepspeed_model(model_class, model_id, dtype, revision) model = hq_env.prepare_model_for_quantization(model) else: get_repo_root(model_id) # Check support for rope scaling model_kwargs = {} - config = AutoConfig.from_pretrained( - model_id - ) + config = AutoConfig.from_pretrained(model_id) if hasattr(config, "rope_scaling"): model_kwargs["rope_scaling"] = self.get_rope_scaling() @@ -581,23 +613,29 @@ class VlmCausalLM(Model): revision=revision, torch_dtype=dtype, trust_remote_code=trust_remote_code, - **model_kwargs + **model_kwargs, ) model = hq_env.prepare_model_for_quantization(model) model = model.eval().to(device) - self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + self.enable_hpu_graph = ( + os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + ) self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", 
"false").lower() == "true" model = remove_kv_cache_from_output(model) if self.enable_hpu_graph: from habana_frameworks.torch.hpu import wrap_in_hpu_graph + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) else: if LAZY_MODE == 0: # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace( - "TORCH COMPILE", 'Torch compiling of model') - model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) + dbg_trace("TORCH COMPILE", "Torch compiling of model") + model.model = torch.compile( + model.model, + backend="hpu_backend", + options={"keep_input_mutations": True}, + ) model = hq_env.setup_quantization(model) @@ -647,11 +685,15 @@ class VlmCausalLM(Model): ) # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] + ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_warmup_steps = ( + int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 + ) + self.profiling_steps = ( + int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 + ) self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) if self.profiling_steps > 0: self.hb_profiler = HabanaProfile( @@ -659,14 +701,13 @@ class VlmCausalLM(Model): warmup=self.profiling_warmup_steps, active=self.profiling_steps, output_dir=output_dir, - record_shapes=record_shapes + record_shapes=record_shapes, ) self.hb_profiler.start() else: self.hb_profiler = None self.step = 0 - @property def batch_type(self) -> Type[VlmCausalLMBatch]: return self.batch_class @@ -679,20 +720,20 @@ class VlmCausalLM(Model): model_class, model_id: str, dtype: torch.dtype, - revision: Optional[str] = None + revision: Optional[str] = None, ) -> torch.nn.Module: import deepspeed from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = { - "revision": revision - } + model_kwargs = {"revision": revision} # Initialize process(es) for DeepSpeed deepspeed.init_distributed(dist_backend="hccl") logger.info( - "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format( + world_size, rank, local_rank + ) ) config = AutoConfig.from_pretrained(model_id, **model_kwargs) load_to_meta = model_on_meta(config) @@ -710,14 +751,18 @@ class VlmCausalLM(Model): get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) # TODO: revisit placement on CPU when auto-injection is possible with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) + model = model_class.from_pretrained( + model_id, torch_dtype=dtype, **model_kwargs + ) model = model.eval() # Initialize the model ds_inference_kwargs = {"dtype": dtype} ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} ds_inference_kwargs["enable_cuda_graph"] = False - ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config) + ds_inference_kwargs["injection_policy"] = get_ds_injection_policy( + model.language_model.config + ) if load_to_meta: # model loaded to meta is managed differently @@ -734,12 +779,12 @@ class VlmCausalLM(Model): return None rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return { - 'type': rope_scaling, 'factor': float(rope_factor) - } + return {"type": rope_scaling, "factor": float(rope_factor)} def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) def decode_token( self, @@ -748,7 +793,9 @@ class VlmCausalLM(Model): read_offset: int = 0, ) -> Tuple[str, int, int]: if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) + new_text = self.tokenizer.decode( + all_input_ids[read_offset:], skip_special_tokens=False + ) return new_text, read_offset, len(all_input_ids) else: return super().decode_token(all_input_ids, prefix_offset, read_offset) @@ -776,7 +823,7 @@ class VlmCausalLM(Model): hpu_kwargs = {} # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama" : + if self.model.config.model_type == "llama": hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 if self.has_position_ids: @@ -814,18 +861,26 @@ class VlmCausalLM(Model): token_idx_scalar = batch.attention_mask.shape[-1] - 1 token_idx = torch.tensor(token_idx_scalar).to(self.device) else: - token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding + token_idx_scalar = ( + batch.attention_mask.shape[-1] - batch.right_padding + ) token_idx = torch.tensor(token_idx_scalar).to(self.device) # Select next token input_length = batch.input_length if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( - batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate + next_token_ids, next_token_logprobs, logprobs, _, _ = ( + batch.next_token_chooser( + batch.input_ids, + logits[:, input_length - 1 : input_length, :].squeeze(-2), + self.speculate, + ) ) else: - next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate + next_token_ids, next_token_logprobs, logprobs, _, _ = ( + batch.next_token_chooser( + batch.input_ids, logits.squeeze(-2), self.speculate + ) ) # Speculation is not active for causal accepted_ids = torch.ones_like(batch.input_ids)[:, 0] @@ -836,23 
+891,29 @@ class VlmCausalLM(Model): accepted_ids, ) - prev_batches.append({ - 'next_token_ids': next_token_ids, - 'next_token_logprobs': next_token_logprobs, - }) + prev_batches.append( + { + "next_token_ids": next_token_ids, + "next_token_logprobs": next_token_logprobs, + } + ) for req_idx, req in enumerate(batch.requests): - requests_to_generate.append({ - 'req': req, - 'prev_req_idx': req.idx, - 'batch_id': batch_id, - 'seed': batch.next_token_chooser.seeds[req_idx], - 'do_sample': batch.next_token_chooser.do_sample[req_idx], - 'top_n_tokens': batch.top_n_tokens[req_idx], - 'top_token_ids': batch_top_token_ids[req_idx], - 'top_token_logprobs': batch_top_token_logprobs[req_idx], - 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], - }) + requests_to_generate.append( + { + "req": req, + "prev_req_idx": req.idx, + "batch_id": batch_id, + "seed": batch.next_token_chooser.seeds[req_idx], + "do_sample": batch.next_token_chooser.do_sample[req_idx], + "top_n_tokens": batch.top_n_tokens[req_idx], + "top_token_ids": batch_top_token_ids[req_idx], + "top_token_logprobs": batch_top_token_logprobs[req_idx], + "grammar_state": batch.next_token_chooser.fsm_grammar_states[ + req.idx + ], + } + ) htorch.core.mark_step() @@ -867,7 +928,9 @@ class VlmCausalLM(Model): # Update position_ids if prefill: - batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + batch.position_ids = ( + torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + ) else: batch.position_ids += 1 # Update past key values @@ -878,7 +941,9 @@ class VlmCausalLM(Model): # Stage 2. Prepare new batch for speculative scheduling if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) + batch = self.batch_type.concatenate( + batches, self.tokenizer.pad_token_id, is_warmup + ) else: batch = batches[0] @@ -886,15 +951,24 @@ class VlmCausalLM(Model): # Check if we need to do any bookkeeping first if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) + batch = batch.__class__.recombine( + [batch], self.tokenizer.pad_token_id, is_warmup + ) - scenario = 'PREFILL' if prefill else 'GENERATE' - if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: + scenario = "PREFILL" if prefill else "GENERATE" + if ( + self.enable_hpu_graph + and self.limit_hpu_graph + and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) + != self.prev_bs + ): self.model.clear_cache() self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) dbg_trace( - scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') - #assert batch.right_padding > 0, 'No more room for next token!' + scenario, + f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", + ) + # assert batch.right_padding > 0, 'No more room for next token!' 
# Execute batch if prefill: @@ -908,21 +982,27 @@ class VlmCausalLM(Model): batch.past_key_values, batch.pixel_values, batch.image_sizes, - bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + bypass_hpu_graph=( + prefill and self.limit_hpu_graph if self.enable_hpu_graph else None + ), ) elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): # Don't schedule next forward if max_new_tokens for all requests equals 1 # - we've already generated the first and only needed token in the prefill phase pass else: - token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + token_idx = torch.tensor( + batch.attention_mask.shape[-1] - batch.right_padding + ).to(self.device) batch.logits = self.forward( batch.input_ids, batch.attention_mask, batch.position_ids, token_idx, batch.past_key_values, - bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + bypass_hpu_graph=( + prefill and self.limit_hpu_graph if self.enable_hpu_graph else None + ), ) htorch.core.mark_step() @@ -932,40 +1012,45 @@ class VlmCausalLM(Model): # Stage 3. Finish and return previous generations stopped = len(requests_to_generate) > 0 for prev_batch in prev_batches: - prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() - prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() + prev_batch["next_token_logprobs"] = prev_batch[ + "next_token_logprobs" + ].tolist() + prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() htorch.core.mark_step() for req_data in requests_to_generate: - req = req_data['req'] - i = req_data['prev_req_idx'] - prev_batch_id = req_data['batch_id'] + req = req_data["req"] + i = req_data["prev_req_idx"] + prev_batch_id = req_data["batch_id"] assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] - next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] + next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] + next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] request = req.data input_length = req.input_length prefix_offset = req.prefix_offset read_offset = req.read_offset - do_sample = req_data['do_sample'] - seed = req_data['seed'] + do_sample = req_data["do_sample"] + seed = req_data["seed"] stopping_criteria = req.stopping_criteria all_input_ids = req.all_input_ids next_token_id = next_token_ids_cpu[i] next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data['top_n_tokens'] - top_token_ids = req_data['top_token_ids'] - top_token_logprobs = req_data['top_token_logprobs'] - grammar_state = req_data['grammar_state'] + top_n_tokens = req_data["top_n_tokens"] + top_token_ids = req_data["top_token_ids"] + top_token_logprobs = req_data["top_token_logprobs"] + grammar_state = req_data["grammar_state"] # Append next token to all tokens all_input_ids[input_length] = next_token_id new_input_length = input_length + 1 # Generated token - if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: - next_token_text = '' + if ( + is_tokenizer_transparent(self.tokenizer) + and len(stopping_criteria.stop_sequence_criterias) == 0 + ): + next_token_text = "" else: next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids[0:new_input_length, 0], prefix_offset, read_offset @@ -989,7 +1074,11 @@ class VlmCausalLM(Model): output_text = None else: output_text = 
self.decode( - all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] + all_input_ids[ + new_input_length + - stopping_criteria.current_tokens : new_input_length, + 0, + ] ) generated_text = GeneratedText( output_text, @@ -1004,7 +1093,7 @@ class VlmCausalLM(Model): if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0: new_input_length - 1] + prefill_token_ids = all_input_ids[0 : new_input_length - 1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, clean_up_tokenization_spaces=False, @@ -1073,7 +1162,12 @@ class VlmCausalLM(Model): htorch.core.mark_step() self.step = self.step + 1 if self.hb_profiler is not None: - if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: + if ( + self.step + > self.profiling_wait_steps + + self.profiling_warmup_steps + + self.profiling_steps + ): self.hb_profiler.stop() else: self.hb_profiler.step() @@ -1090,7 +1184,7 @@ class VlmCausalLM(Model): self.model.config, self.dtype, self.device, - is_warmup + is_warmup, ) def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): @@ -1117,14 +1211,14 @@ class VlmCausalLM(Model): ) global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - max_input_length = batch.input_ids.shape[1] + max_input_length = batch.input_ids.shape[1] max_prefill_batch_size = batch.input_ids.shape[0] PREFILL_WARMUP_BATCH_SIZE_LIST = [] batch_size = 1 while batch_size <= max_prefill_batch_size: PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) batch_size = batch_size * 2 - if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size : + if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size: PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) seq_len = BASE_IMAGE_TOKENS @@ -1132,19 +1226,21 @@ class VlmCausalLM(Model): i = 0 while seq_len <= max_input_length: PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) + seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i) i += 1 if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) - #Prefill and decode warmup + # Prefill and decode warmup DECODE_WARMUP_BATCH_SIZE_LIST = [] prefill_batch = None decode_batch = None try: - for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : - for seq_len in PREFILL_WARMUP_SEQLEN_LIST : - batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup) + for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST: + for seq_len in PREFILL_WARMUP_SEQLEN_LIST: + batch = self.generate_warmup_batch( + request, seq_len, batch_size, is_warmup + ) _, prefill_batch, _ = self.generate_token([batch], is_warmup) _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) @@ -1161,21 +1257,29 @@ class VlmCausalLM(Model): mem_stats = get_hpu_memory_stats(self.device) logger.info( - f"\nFollowing prefill and decode warmup successfully.\n" - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats} " - ) + f"\nFollowing prefill and decode warmup successfully.\n" + f"Prefill batch size 
list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Memory stats: {mem_stats} " + ) max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) batch_size = max_prefill_batch_size * 2 # Decode warmup with bigger batch_size try: - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size: + if ( + DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size + and batch_size <= max_decode_batch_size + ): batches = [] - for i in range(int(batch_size/max_prefill_batch_size)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + for i in range(int(batch_size / max_prefill_batch_size)): + batch = self.generate_warmup_batch( + request, + PREFILL_WARMUP_SEQLEN_LIST[0], + DECODE_WARMUP_BATCH_SIZE_LIST[-1], + is_warmup, + ) _, prefill_batch, _ = self.generate_token([batch], is_warmup) batches.append(prefill_batch) while batch_size <= max_decode_batch_size: @@ -1184,17 +1288,24 @@ class VlmCausalLM(Model): batch_size = batch_size * 2 batches.clear() - for i in range(int(batch_size/max_prefill_batch_size)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + for i in range(int(batch_size / max_prefill_batch_size)): + batch = self.generate_warmup_batch( + request, + PREFILL_WARMUP_SEQLEN_LIST[0], + DECODE_WARMUP_BATCH_SIZE_LIST[-1], + is_warmup, + ) _, prefill_batch, _ = self.generate_token([batch], is_warmup) batches.append(prefill_batch) batches.clear() if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: - max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2 + max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2 batch_size = max_decode_batch_size - for i in range(int(max_decode_batch_size / 2)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup) + for i in range(int(max_decode_batch_size / 2)): + batch = self.generate_warmup_batch( + request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup + ) _, prefill_batch, _ = self.generate_token([batch], is_warmup) batches.append(prefill_batch) _, decode_batch, _ = self.generate_token(batches, is_warmup) @@ -1211,9 +1322,9 @@ class VlmCausalLM(Model): mem_stats = get_hpu_memory_stats(self.device) logger.info( - f"\nFollowing decode warmup successfully.\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats}" - ) + f"\nFollowing decode warmup successfully.\n" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Memory stats: {mem_stats}" + ) - return MAX_BATCH_TOTAL_TOKENS \ No newline at end of file + return MAX_BATCH_TOTAL_TOKENS diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 61b0f27fc..674a8aed1 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -38,7 +38,10 @@ try: except (ImportError, NotImplementedError): # These imports can fail on CPU/Non flash. 
VLM_BATCH_TYPES = set() -from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION +from text_generation_server.utils.version import ( + is_driver_compatible, + MIN_TGI_GAUDI_SYNAPSE_VERSION, +) class SignalHandler: @@ -72,7 +75,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): # Force inference mode for the lifetime of TextGenerationService # self._inference_mode_raii_guard = torch._C._InferenceMode(True) - async def Info(self, request, context): return self.model.info @@ -101,7 +103,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - max_supported_total_tokens, max_input_tokens, max_total_tokens = self.model.warmup(request) + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(request) + ) # W/A for the skip tokenizer path # We need to call make_tokenizer_optional after the warmup, @@ -194,7 +198,9 @@ def serve( trust_remote_code: bool = False, ): if not is_driver_compatible(): - logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures") + logger.warning( + f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures" + ) unix_socket_template = "unix://{}-{}" adapter_to_index = {} @@ -204,14 +210,19 @@ def serve( rank = int(os.environ["RANK"]) logger.info("Server:server_inner: rank ={}".format(rank)) server_urls = [ - unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"])) + unix_socket_template.format(uds_path, rank) + for rank in range(int(os.environ["WORLD_SIZE"])) ] local_url = server_urls[int(os.environ["RANK"])] else: local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] - logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url)) + logger.info( + "Server:server_inner: data type = {}, local_url = {}".format( + dtype, local_url + ) + ) if dtype == "bfloat16" or None: data_type = torch.bfloat16 else: diff --git a/backends/gaudi/server/text_generation_server/tgi_service.py b/backends/gaudi/server/text_generation_server/tgi_service.py index 714b3566b..18e88a7eb 100644 --- a/backends/gaudi/server/text_generation_server/tgi_service.py +++ b/backends/gaudi/server/text_generation_server/tgi_service.py @@ -10,7 +10,13 @@ def main(args): logger.info("TGIService: starting tgi service .... 
") logger.info( "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format( - args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path + args.model_id, + args.revision, + args.sharded, + args.speculate, + args.dtype, + args.trust_remote_code, + args.uds_path, ) ) lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) @@ -24,7 +30,7 @@ def main(args): dtype=args.dtype, trust_remote_code=args.trust_remote_code, uds_path=args.uds_path, - max_input_tokens=args.max_input_tokens + max_input_tokens=args.max_input_tokens, ) diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py index ef8d437b7..8bbcad6a3 100644 --- a/backends/gaudi/server/text_generation_server/utils/debug.py +++ b/backends/gaudi/server/text_generation_server/utils/debug.py @@ -8,14 +8,14 @@ from optimum.habana.utils import to_gb_rounded import habana_frameworks.torch as htorch START_TS = None -DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME') -if 'GRAPH_VISUALIZATION' in os.environ: - for f in glob.glob('.graph_dumps/*'): +DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME") +if "GRAPH_VISUALIZATION" in os.environ: + for f in glob.glob(".graph_dumps/*"): os.remove(f) def count_hpu_graphs(): - return len(glob.glob('.graph_dumps/*PreGraph*')) + return len(glob.glob(".graph_dumps/*PreGraph*")) def dbg_trace(tag, txt): @@ -25,7 +25,11 @@ def dbg_trace(tag, txt): START_TS = time.perf_counter() time_offset = time.perf_counter() - START_TS mem_stats = htorch.hpu.memory.memory_stats() - mem_used = to_gb_rounded(mem_stats['InUse']) - max_mem_used = to_gb_rounded(mem_stats['MaxInUse']) - print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB ' - f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a')) + mem_used = to_gb_rounded(mem_stats["InUse"]) + max_mem_used = to_gb_rounded(mem_stats["MaxInUse"]) + print( + f"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB " + f"mmu:{max_mem_used:.1f}GB | {tag} | {txt}", + flush=True, + file=open(DBG_TRACE_FILENAME, "a"), + ) diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py index cf8acaccf..0e9b97fb2 100644 --- a/backends/gaudi/server/text_generation_server/utils/dist.py +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -64,10 +64,12 @@ def initialize_torch_distributed(): backend = "hccl" n_hpus = torch.hpu.device_count() if world_size > n_hpus: - raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") + raise ValueError( + f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})." 
+ ) else: try: - import oneccl_bindings_for_pytorch # noqa: F401 + import oneccl_bindings_for_pytorch # noqa: F401 backend = "ccl" if os.getenv("CCL_WORKER_COUNT", None) is None: diff --git a/backends/gaudi/server/text_generation_server/utils/logits_process.py b/backends/gaudi/server/text_generation_server/utils/logits_process.py index 104fc2f09..472f2dcb0 100644 --- a/backends/gaudi/server/text_generation_server/utils/logits_process.py +++ b/backends/gaudi/server/text_generation_server/utils/logits_process.py @@ -63,7 +63,9 @@ class StaticWarper: self.static_warped_scores.copy_(local_scores) # Compute logprobs - self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1)) + self.static_next_logprob.copy_( + torch.log_softmax(self.static_warped_scores, -1) + ) self.static_scores.copy_(scores) self.hpu_graph.replay() @@ -78,7 +80,9 @@ def static_warper( top_p: Optional[float], typical_p: Optional[float], ) -> StaticWarper: - return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + return StaticWarper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): @@ -95,13 +99,17 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): self.penalty = penalty - self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) + self.penalty_tensor = torch.tensor( + penalty, dtype=dtype, device=device + ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) + score = torch.where( + score < 0, score * self.penalty_tensor, score / self.penalty_tensor + ) scores.scatter_(1, input_ids, score) return scores @@ -163,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): batch_size, vocab_size, dtype=scores.dtype, device=scores.device ) token_freq.scatter_add_( - 1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device) + 1, + input_ids, + torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device), ) token_freq /= input_size @@ -190,9 +200,13 @@ class HeterogeneousTemperatureLogitsWarper: The value used to module the logits distribution. 
""" - def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): + def __init__( + self, temperature: List[float], dtype: torch.dtype, device: torch.device + ): self.temperature = temperature - self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) + self.temperature_tensor = torch.tensor( + temperature, dtype=dtype, device=device + ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: scores.div_(self.temperature_tensor) @@ -231,7 +245,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): min_tokens_to_keep: int = 1, ): self.top_p = top_p - self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) + self.top_p_opposite = 1 - torch.tensor( + top_p, dtype=dtype, device=device + ).unsqueeze(1) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @@ -248,7 +264,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) return warped_scores @@ -296,7 +314,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): disabled = [x == 0 for x in top_k] if any(disabled): - self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1) + self.top_k_disabled_mask = torch.tensor( + disabled, dtype=torch.bool, device=device + ).view(-1, 1) else: self.top_k_disabled_mask = None @@ -332,7 +352,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): self.max_top_k = max(self.top_k) if self.top_k_disabled_mask is not None: - self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None + self.top_k_disabled_mask = ( + self.top_k_disabled_mask[indices] if any(disabled) else None + ) return self return None @@ -398,11 +420,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): if self.disabled_mask is not None: last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) - sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather( + 1, last_ind.view(-1, 1) + ) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) @@ -416,7 +442,9 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): self.mass_tensor = self.mass_tensor[indices] if self.disabled_mask is not None: - self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None + self.disabled_mask = ( + self.disabled_mask[indices] if any(disabled) else None + ) return self return None diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py index 56f5b8c73..b0282c42e 100644 --- a/backends/gaudi/server/text_generation_server/utils/tokens.py +++ 
b/backends/gaudi/server/text_generation_server/utils/tokens.py @@ -254,7 +254,7 @@ class HeterogeneousNextTokenChooser: tokenizer: PreTrainedTokenizerBase, grammars: List[str], grammar_types: List[int], -        fsm_grammar_states:List[int], +        fsm_grammar_states: List[int], quantization_enabled: bool, ): warpers = [] @@ -447,7 +447,9 @@ class HeterogeneousNextTokenChooser: next_id = next_id.item() self.fsm_grammar_states[grammar_state_index] = ( self.grammar_processor.advance_at_index( -                    next_id, past_state, grammar_state_index, +                    next_id, +                    past_state, +                    grammar_state_index, ) ) return self @@ -544,9 +546,7 @@ def pad_next_token_chooser_parameters( grammar="", grammar_type=0, ) -    parameters.extend( -        [empty_parameters] * (expected_size - len(parameters)) -    ) +    parameters.extend([empty_parameters] * (expected_size - len(parameters))) return parameters @@ -701,24 +701,47 @@ def make_tokenizer_optional(tokenizer): padding, return_token_type_ids, truncation, -            max_length +            max_length, ): -            assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer" -            assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer" -            assert not return_token_type_ids, "inccorrect input arguments when calling TransparentTokenizer" -            assert truncation, "inccorrect input arguments when calling TransparentTokenizer" +            assert ( +                return_tensors == "pt" +            ), "incorrect input arguments when calling TransparentTokenizer" +            assert ( +                padding == "max_length" or padding == "longest" +            ), "incorrect input arguments when calling TransparentTokenizer" +            assert ( +                not return_token_type_ids +            ), "incorrect input arguments when calling TransparentTokenizer" +            assert ( +                truncation +            ), "incorrect input arguments when calling TransparentTokenizer" def str_token_to_int(i): -                if i == '?': +                if i == "?": return tokenizer.pad_token_id else: return int(i) -            all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')] -                          for inner_text in text] + +            all_tokens = [ +                [str_token_to_int(i.strip()) for i in inner_text.split(",")] +                for inner_text in text +            ] if padding == "longest": max_length = max(len(tokens) for tokens in all_tokens) -            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]), -                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])} +            return { +                "input_ids": torch.tensor( +                    [ +                        [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens +                        for tokens in all_tokens +                    ] +                ), +                "attention_mask": torch.tensor( +                    [ +                        [0] * (max_length - len(tokens)) + [1] * len(tokens) +                        for tokens in all_tokens +                    ] +                ), +            } def decode( self, @@ -728,9 +751,10 @@ def make_tokenizer_optional(tokenizer): **kwargs, ) -> str: # I don't think this method is used anywhere and should be removed when doing refactoring -        return ','.join(str(i) for i in to_py_obj(token_ids))  # noqa: F821 +        return ",".join(str(i) for i in to_py_obj(token_ids))  # noqa: F821 import os + if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true": tokenizer.__class__ = _ tokenizer.is_transparent = True diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py index a72a9ea7b..84c916bf3 100644 --- a/backends/gaudi/server/text_generation_server/utils/version.py +++ b/backends/gaudi/server/text_generation_server/utils/version.py @@ -1,7 +1,7 @@ from
optimum.habana.utils import get_driver_version from packaging.version import Version -MIN_TGI_GAUDI_SYNAPSE_VERSION=Version("1.16.0") +MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0") def is_driver_compatible(): @@ -9,4 +9,4 @@ def is_driver_compatible(): if driver_version is not None: if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION: return False -    return True \ No newline at end of file +    return True diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index 58080bd1e..d7ae11d54 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -68,4 +68,4 @@ fn hl_smi() -> Option<String> { let hl_smi = String::from_utf8(output.stdout).ok()?; let output = hl_smi.replace('\n', "\n "); Some(output.trim().to_string()) -} \ No newline at end of file +}