diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 3cda2f4b..cff1fac8 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -110,6 +110,7 @@ impl Client { max_prefill_tokens: u32, max_total_tokens: u32, max_batch_size: Option, + model_id: &str ) -> Result> { let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true"); if !warmup_enabled { @@ -152,25 +153,76 @@ impl Client { let mut batch_counter: u64 = 0; let mut request_counter: u64 = 0; - for shape in shapes.iter() { - let (batch_size, seq_length) = shape; - let mut batches: Vec = vec![ - self.create_warmup_batch( - *shape, - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - false, - None, - ) - ]; - // if possible, create second batch in order to trigger concatenate operation - if *batch_size < max_decode_batch_size { - batches.push( + if model_id.contains("llava") { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = cmp::min(max_input_length, max_prefill_tokens - n_tokens); + + let mut inputs = String::new(); + inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + + requests.push(Request { + id: 0, + // We truncate the input on the server side to be sure that it has the correct size + inputs, + truncate, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let mut batches = Vec::new(); + batches.push(Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }); + + let request = tonic::Request::new(WarmupRequest { + batches, + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + else { + for shape in shapes.iter() { + let (batch_size, seq_length) = shape; + let mut batches: Vec = vec![ self.create_warmup_batch( - (1, *seq_length), + *shape, &mut batch_counter, &mut request_counter, max_input_length, @@ -179,56 +231,45 @@ impl Client { false, None, ) - ); + ]; + // if possible, create second batch in order to trigger concatenate operation + if *batch_size < max_decode_batch_size { + batches.push( + self.create_warmup_batch( + (1, *seq_length), + &mut batch_counter, + &mut request_counter, + max_input_length, + max_total_tokens, + seq_bucket_size, + false, + None, + ) + ); + } + + let request = tonic::Request::new(WarmupRequest { + batches, + max_input_length, + max_prefill_tokens, + max_total_tokens, + }).inject_context(); + let _response = self.stub.warmup(request).await?.into_inner(); } - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }).inject_context(); - let _response = self.stub.warmup(request).await?.into_inner(); - } + // send batches to warmup all possible decode shapes + if decode_batch_sizes.len() > 1 { + let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size { + decode_bucket_size + } else { + decode_bucket_size.div_ceil(max_prefill_batch_size) + }; + let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket; - // send batches to warmup all possible decode shapes - if decode_batch_sizes.len() > 1 { - let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size { - decode_bucket_size - } else { - decode_bucket_size.div_ceil(max_prefill_batch_size) - }; - let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket; - - let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size); - let mut batches: Vec = vec![ - self.create_warmup_batch( - (requests_send, seq_bucket_size), - &mut batch_counter, - &mut request_counter, - max_input_length, - max_total_tokens, - seq_bucket_size, - false, - Some(max_new_tokens), - ) - ]; - - let get_current_decode_batch_size = |num: u32| -> u32 { - decode_batch_sizes.iter() - .filter(|&&x| x >= num) - .min() - .copied() - .unwrap() - }; - - let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send); - while current_decode_batch_size < max_decode_batch_size { - let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send; - let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size); - batches.push( + let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size); + let mut batches: Vec = vec![ self.create_warmup_batch( - (num_requests, seq_bucket_size), + (requests_send, seq_bucket_size), &mut batch_counter, &mut request_counter, max_input_length, @@ -237,48 +278,74 @@ impl Client { false, Some(max_new_tokens), ) - ); + ]; - requests_send += num_requests; - current_decode_batch_size = get_current_decode_batch_size(requests_send); + let get_current_decode_batch_size = |num: u32| -> u32 { + decode_batch_sizes.iter() + .filter(|&&x| x >= num) + .min() + .copied() + .unwrap() + }; + + let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send); + while current_decode_batch_size < max_decode_batch_size { + let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send; + let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size); + batches.push( + self.create_warmup_batch( + (num_requests, seq_bucket_size), + &mut batch_counter, + &mut request_counter, + max_input_length, + max_total_tokens, + seq_bucket_size, + false, + Some(max_new_tokens), + ) + ); + + requests_send += num_requests; + current_decode_batch_size = get_current_decode_batch_size(requests_send); + } + + let request = tonic::Request::new(WarmupRequest { + batches, + max_input_length, + max_prefill_tokens, + max_total_tokens, + }).inject_context(); + let _response = self.stub.warmup(request).await?.into_inner(); } - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }).inject_context(); - let _response = self.stub.warmup(request).await?.into_inner(); - } - - // send batches with default params to warm up Greedy search - let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len()); - for batch_size in &prefill_batch_sizes { - greedy_shapes.push((*batch_size, seq_bucket_size.clone())); - } - for greedy_shape in greedy_shapes.iter() { - let batches: Vec = vec![ - self.create_warmup_batch( - *greedy_shape, - &mut batch_counter, - &mut request_counter, + // send batches with default params to warm up Greedy search + let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len()); + for batch_size in &prefill_batch_sizes { + greedy_shapes.push((*batch_size, seq_bucket_size.clone())); + } + for greedy_shape in greedy_shapes.iter() { + let batches: Vec = vec![ + self.create_warmup_batch( + *greedy_shape, + &mut batch_counter, + &mut request_counter, + max_input_length, + max_total_tokens, + seq_bucket_size, + true, + None, + ) + ]; + let request = tonic::Request::new(WarmupRequest { + batches, max_input_length, + max_prefill_tokens, max_total_tokens, - seq_bucket_size, - true, - None, - ) - ]; - let request = tonic::Request::new(WarmupRequest { - batches, - max_input_length, - max_prefill_tokens, - max_total_tokens, - }).inject_context(); - let _response = self.stub.warmup(request).await?.into_inner(); + }).inject_context(); + let _response = self.stub.warmup(request).await?.into_inner(); + } + Ok(None) // No support for maximum total tokens } - Ok(None) // No support for maximum total tokens } #[instrument(skip_all)] diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index e2c800dd..fdd84035 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -100,6 +100,7 @@ impl ShardedClient { max_prefill_tokens: u32, max_total_tokens: u32, max_batch_size: Option, + model_id: &str, ) -> Result> { let futures: Vec<_> = self .clients @@ -110,6 +111,7 @@ impl ShardedClient { max_prefill_tokens, max_total_tokens, max_batch_size, + model_id )) }) .collect(); diff --git a/router/src/main.rs b/router/src/main.rs index c3b8d047..4f9f0f73 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -349,6 +349,7 @@ async fn main() -> Result<(), RouterError> { max_batch_prefill_tokens, max_total_tokens as u32, max_batch_size, + &model_info.model_id ) .await .map_err(RouterError::Warmup)? diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3d3d3e1e..569b204f 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,12 @@ from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.starcoder import StarCoder +from text_generation_server.models.vlm_causal_lm import VlmCausalLM +from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, +) + + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -159,6 +165,18 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + logger.info(f"model_type = {model_type}") + if model_type == "llava_next": + logger.info(f"################model_type = {model_type}") + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=None, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 8ec6aca8..4a48ad46 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -369,6 +369,7 @@ class CausalLMBatch(Batch): input_lengths = [b.input_length for b in batches] max_input_length = max(input_lengths) offsets = [max_input_length - b.input_length for b in batches] + cur_padding = [b.right_padding for b in batches] # For prefill there is a space allocated only for first token # Need to add padding to the max total tokens before first decode diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index de9673aa..4268cc9b 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -21,17 +21,12 @@ import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN +from transformers.models.llava_next.modeling_llava_next import ( + unpad_image, +) +from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration from transformers.image_processing_utils import select_best_resolution - -from text_generation_server.models.custom_modeling.vlm import ( - load_text_model, - load_vision_model, -) -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelRowLinear, -) - +from loguru import logger def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -56,100 +51,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def unpad_image(tensor, original_size): - """ - Unpads a PyTorch tensor of a padded and resized image. - - Args: - tensor (`torch.Tensor`): - The image tensor, assumed to be of shape (num_channels, height, width). - original_size (`tuple`): - The original size of the image (height, width). - - Returns: - `torch.Tensor`: The unpadded image tensor. - """ - original_height, original_width = original_size - current_height, current_width = tensor.shape[1:] - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) - padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding : current_height - padding, :] - else: - scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) - padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding : current_width - padding] - - return unpadded_tensor - - -# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext -class LlavaNextMultiModalProjector(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.linear_1 = TensorParallelColumnLinear.load( - prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True - ) - self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = TensorParallelRowLinear.load( - prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True - ) - - def forward(self, image_features): - hidden_states = self.linear_1(image_features) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - -class LlavaNextForConditionalGeneration(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - config.vision_config.quantize = config.quantize - vision_config = config.vision_config - # Instead of selecting in hidden_states[-2]. - # Instead compute only the n -2 + 1 layers and don't pool - if config.vision_feature_layer < 0: - vision_config.num_hidden_layers += config.vision_feature_layer + 1 - else: - vision_config.num_hidden_layers = config.vision_feature_layer + 1 - self.vision_tower = load_vision_model( - prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", - config=config.vision_config, - weights=weights, - ) - - self.multi_modal_projector = LlavaNextMultiModalProjector( - prefix="multi_modal_projector", config=config, weights=weights - ) - - self.image_newline = weights.get_tensor("image_newline") - - self.vocab_size = config.text_config.vocab_size - self.config = config - config.text_config.quantize = config.quantize - config.text_config.speculator = config.speculator - self.language_model = load_text_model( - prefix="language_model" if not prefix else f"{prefix}.language_model", - config=config.text_config, - weights=weights, - ) - self.pad_token_id = ( - config.pad_token_id if config.pad_token_id is not None else -1 - ) - +class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): + def _merge_input_ids_with_image_features( self, - input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_features: torch.Tensor, + input_ids: torch.Tensor, ): """In place merges in vision_embeddings with inputs_embeds.""" mask = input_ids == self.config.image_token_index @@ -164,120 +72,215 @@ class LlavaNextForConditionalGeneration(nn.Module): def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, + input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, - # Unused for this model - pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, ): - inputs_embeds = self.language_model.embed_tokens(input_ids) - if pixel_values is not None and len(pixel_values) > 0: - # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() - # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" - # 1. Extract the input embeddings - # 2. Merge text and images - num_images, num_patches, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view( - num_images * num_patches, channels, height, width + if token_idx is not None: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - image_features = self.vision_tower(pixel_values) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) - # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] - # Already done within the clip model - selected_image_feature = image_features.last_hidden_state + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + logits = outputs[0] + + if not return_dict: + output = (logits,) + outputs[1:] + return output + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, + ): + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 + The only differences are: + - add new args token_idx + - add the process of merging images into inputs_embeds + """ + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) else: - raise RuntimeError( - f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." + + position_ids = kwargs.get("position_ids", None) + labels = kwargs.get("labels", None) + if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: + vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) + vision_feature_layer = kwargs.get("vision_feature_layer", None) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + batch_size, num_patches, num_channels, height, width = pixel_values.shape + reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) + image_features = self.vision_tower( + reshaped_pixel_values, output_hidden_states=True + ) + + selected_image_feature = image_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx].tolist(), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids) + self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None: + seq_len = input_ids.shape[1] + pad_len = seq_len - token_idx.item() + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = extended_attention_mask + attention_mask[:, -pad_len:] = 0 + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "labels": labels, + } ) - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [num_patches] * num_images - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_features - ) - - hidden_states = self.language_model.model( - inputs_embeds=inputs_embeds, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, - ) - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.language_model.lm_head(hidden_states) - return logits, speculative_logits + return model_inputs \ No newline at end of file diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 86d9b4c8..72ceca6b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -10,7 +10,12 @@ import numpy as np from loguru import logger from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizerBase, + AutoConfig, + AutoTokenizer, + GenerationConfig, +) from typing import Optional, Tuple, List, Type, Dict from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -19,6 +24,11 @@ from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.models.types import ( Batch, Tokens, @@ -686,20 +696,97 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizerBase, - num_layers: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - rank: int = 0, - world_size: int = 1, - sliding_window: Optional[int] = None, + model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], + tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, + default_dtype=torch.bfloat16, + aliases=None, + # Used for Santacoder override of config + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, + skip_special_tokens: bool = True, ): - self.num_layers = num_layers - self.num_kv_heads = num_kv_heads - self.head_size = head_size + + # Create model + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + if isinstance(generation_config.eos_token_id, (list, set)): + # TODO Huge hack + tokenizer._eos_token_ids = set(generation_config.eos_token_id) + except Exception: + pass + + config = config_class.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype) + + + + prefix = "" + model = model_class(prefix, config, weights) + + # VLM models define the config we care about in their text_config + text_config = getattr(config, "text_config", None) + if text_config is not None: + config = text_config + + + self.num_layers = config.num_hidden_layers + # Validation is done in the model itself + if num_kv_heads is None: + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround + if num_kv_heads is None: + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") + self.num_kv_heads = num_kv_heads ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 + + if head_size is None: + # Some models use GQA and different sizes for o_proj + # and q_proj, that allows for that. + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size + + self.cuda_graphs = {} + self.kv_cache = [] self.cuda_graphs = {} @@ -711,7 +798,7 @@ class FlashCausalLM(Model): device=device, rank=rank, world_size=world_size, - sliding_window=sliding_window, + sliding_window=None, ) @property diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f0db89b2..7412092a 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,29 +1,87 @@ import re import torch +import os +import time import math from PIL import Image from io import BytesIO import base64 - +import numpy from opentelemetry import trace +from loguru import logger from typing import Optional, Tuple, List, Type, Dict - +import tempfile +import copy +from text_generation_server.models import Model from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution +from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - FlashMistralBatch, +from text_generation_server.models.causal_lm import ( + CausalLMBatch, + CausalLMRequest, + round_up, + remove_kv_cache_from_output ) -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, +) + +from transformers import AutoProcessor +import text_generation_server.habana_quantization_env as hq_env +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi from text_generation_server.models.cache_manager import ( get_cache_manager, ) +from text_generation_server.utils import ( + HeterogeneousNextTokenChooser, + StoppingCriteria, + make_tokenizer_optional, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, +) +import habana_frameworks.torch as htorch +from optimum.habana.utils import HabanaProfile +from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES + +from transformers import ( + AutoTokenizer, + AutoModel, + PreTrainedTokenizerBase, + AutoConfig, +) +from optimum.habana.checkpoint_utils import ( + get_repo_root, + model_on_meta, + write_checkpoints_json, +) + +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.utils.debug import dbg_trace tracer = trace.get_tracer(__name__) IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") +BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048)) +MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) +BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 1)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128)) +PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 1)) +CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) +PREFILL_GRAPH_NUM = int(os.environ.get('PREFILL_GRAPH_NUM', 16)) +os.environ['MAX_TOTAL_TOKENS'] = str(MAX_TOTAL_TOKENS) +os.environ['BATCH_BUCKET_SIZE'] = str(BATCH_BUCKET_SIZE) +os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF) +os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE) +os.environ['LAZY_MODE'] = str(LAZY_MODE) def split(string) -> List[Dict[str, str]]: parts = [] @@ -41,30 +99,6 @@ def split(string) -> List[Dict[str, str]]: return parts - -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if not isinstance(grid_pinpoints, list): - raise ValueError("grid_pinpoints should be a list of tuples or lists") - - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - def image_text_replacement(image_input, config, image_id) -> str: if config.model_type == "idefics2": # TODO technically depends on image splitting which is not implemented. @@ -77,9 +111,7 @@ def image_text_replacement(image_input, config, image_id) -> str: elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) - from loguru import logger - logger.info(f"Found {num_features} in image of resolution {height}x{width}") return "" * num_features elif config.model_type == "paligemma": @@ -125,6 +157,7 @@ def get_number_of_features(height: int, width: int, config) -> int: image_grid_pinpoints, image_size, ) + unpadded_features, newline_features = get_unpadded_features( height, width, npatches, num_patch_height, num_patch_width ) @@ -140,27 +173,100 @@ def load_data_uri(image_uri: str) -> Image.Image: return image -class VlmCausalLMBatch(FlashMistralBatch): +class VlmCausalLMBatch(CausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super(VlmCausalLMBatch, cls).concatenate(batches) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch + def from_tokenized( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + batch_tokenized_inputs, + dtype: torch.dtype, + device: torch.device, + ) -> "VlmCausalLMBatch": + + dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') + requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]): - batch = super().filter(request_ids) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch + max_input_length = max(r.data.truncate for r in requests) + max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + # TODO: Add support for sparse batches + top_n_tokens = [r.top_n_tokens for r in pb.requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + + # TODO: by tokenizing all inputs at once we loose information on actual input lengths + # this means that we cannot shift inputs to the left after a long input sequence + # was filtered out + new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE) + parameters = [r.parameters for r in pb.requests] + # append the dummy parameters for dummy request + parameters = pad_next_token_chooser_parameters(parameters, new_bs) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + pb=parameters, + dtype=dtype, + device=device, + tokenizer=tokenizer, + quantization_enabled=hq_env.is_quantization_enabled, + ) + tokenized_inputs = batch_tokenized_inputs + input_len = tokenized_inputs["input_ids"].shape[1] + + bucket_size = max_input_length + left_padding = max_input_length - input_len + if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: + assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" + rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) + if rounded_seq_len <= max_input_length: + bucket_size = rounded_seq_len - 1 + else: + bucket_size = max_input_length - 1 + left_padding = bucket_size - input_len + + input_ids = tokenized_inputs["input_ids"] + attention_mask = tokenized_inputs["attention_mask"] + # Allocate space for first token + if left_padding > 0: + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + all_input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id + ).T.split(1, dim=1) + + # New input length after left padding + input_len = bucket_size + for r in requests: + r.input_length = input_len + r.prefix_offset = input_len - 5 + r.read_offset = input_len + r.all_input_ids = all_input_ids[r.idx] + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + htorch.core.mark_step() + + return cls( + batch_id=pb.id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_len, + ) @classmethod def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): @@ -192,16 +298,26 @@ class VlmCausalLMBatch(FlashMistralBatch): image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk['type']}") - batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) + new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE) + missing_inputs = new_bs - len(requests) + dummy_images = [] + dummy_inputs = [] + if len(batch_inputs) > 0 and len(image_inputs) > 0: + dummy_inputs = [batch_inputs[0]] * missing_inputs + dummy_images = [image_inputs[0]] * missing_inputs + + image_inputs += dummy_images batch_tokenized_inputs = tokenizer( - batch_inputs, + batch_inputs + dummy_inputs, truncation=True, max_length=max_truncation, - add_special_tokens=not config.model_type == "paligemma", - )["input_ids"] + return_tensors="pt", + padding="longest", + return_token_type_ids=False, + ) if image_inputs: image_input = image_inputs[0] new_image_inputs = { @@ -255,126 +371,626 @@ class VlmCausalLMBatch(FlashMistralBatch): return batch -class VlmCausalLM(BaseFlashMistral): +class VlmCausalLM(Model): + def __init__( + self, + model_class, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=VlmCausalLMBatch, + revision, + dtype, + trust_remote_code: bool, + **kwargs, + ): + adapt_transformers_to_gaudi() + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + self.prev_bs = 0 + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + # Create model + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") + + if hq_env.is_quantization_enabled: + htorch.core.hpu_set_env() + + if world_size > 1: + model = self.get_deepspeed_model( + model_class, model_id, dtype, revision + ) + model = self.prepare_model_for_quantization(model) + else: + get_repo_root(model_id) + + # Check support for rope scaling + model_kwargs = {} + config = AutoConfig.from_pretrained( + model_id + ) + if hasattr(config, "rope_scaling"): + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + model = model_class.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + trust_remote_code=trust_remote_code, + **model_kwargs + ) + model = self.prepare_model_for_quantization(model) + model = model.eval().to(device) + + self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + model = remove_kv_cache_from_output(model) + if self.enable_hpu_graph: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) + else: + if LAZY_MODE == 0: + # It is said that "keep_input_mutations" is safe for inference to be done + dbg_trace( + "TORCH COMPILE", f'Torch compiling of model') + model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) + + model = self.setup_quantization(model) + + if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: + raise ValueError(f"Model type {model.config.model_type} is not supported!") + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + kwargs = { + "use_cache": True, + "return_dict": True, + } + + if model.config.model_type in ["llama", "mistral"]: + kwargs["attn_softmax_bf16"] = True + kwargs["trim_logits"] = True + + if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": + kwargs["use_flash_attention"] = True + if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": + kwargs["flash_attention_recompute"] = True + + self.speculate = get_speculate() + super(VlmCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + kwargs=kwargs, + ) + + + @property def batch_type(self) -> Type[VlmCausalLMBatch]: - return VlmCausalLMBatch + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) + + def get_deepspeed_model( + self, + model_class, + model_id: str, + dtype: torch.dtype, + revision: Optional[str] = None + ) -> torch.nn.Module: + import deepspeed + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + model_kwargs = { + "revision": revision + } + + # Initialize process(es) for DeepSpeed + deepspeed.init_distributed(dist_backend="hccl") + logger.info( + "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + ) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + load_to_meta = model_on_meta(config) + + # Check support for rope scaling + if hasattr(config, "rope_scaling"): + config.rope_scaling = self.get_rope_scaling() + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + if load_to_meta: + # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = model_class.from_config(config, torch_dtype=dtype) + else: + get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) + # TODO: revisit placement on CPU when auto-injection is possible + with deepspeed.OnDevice(dtype=dtype, device="cpu"): + model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) + model = model.eval() + + # Initialize the model + ds_inference_kwargs = {"dtype": dtype} + ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} + ds_inference_kwargs["enable_cuda_graph"] = False + + if load_to_meta: + # model loaded to meta is managed differently + checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") + write_checkpoints_json(model_id, local_rank, checkpoints_json) + ds_inference_kwargs["checkpoint"] = checkpoints_json.name + model = deepspeed.init_inference(model, **ds_inference_kwargs) + + return model.module + + def get_rope_scaling(self) -> Optional[Dict]: + rope_scaling = os.getenv("ROPE_SCALING", None) + if rope_scaling is None: + return None + + rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) + return { + 'type': rope_scaling, 'factor': float(rope_factor) + } + + def setup_quantization(self, model): + if hq_env.is_quantization_enabled: + htorch.core.quantization._mark_params_as_const(model) + htorch.core.quantization._check_params_as_const(model) + htorch.core.hpu_initialize(model) + return model + + def prepare_model_for_quantization(self, model): + if hq_env.is_quantization_enabled: + if model.config.model_type == "llama": + self.patch_scoped_linear_all_reduce(model) + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(model) + return model + + def finish_quantization_measurements(self, model): + if hq_env.is_quantization_enabled: + import habana_quantization_toolkit + habana_quantization_toolkit.finish_measurements(self.model) + return model + + def patch_scoped_linear_all_reduce(self, model): + from deepspeed.module_inject.layers import LinearAllreduce + from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce + for name, module in model.named_children(): + if type(module) is LinearAllreduce: + SL = ScopedLinearAllReduce(mod=module) + setattr(model, name, SL) + self.patch_scoped_linear_all_reduce(module) + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: + if is_tokenizer_transparent(self.tokenizer): + new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) + return new_text, read_offset, len(all_input_ids) + else: + return super().decode_token(all_input_ids, prefix_offset, read_offset) def forward( - self, batch: VlmCausalLMBatch - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + self, + input_ids, + attention_mask, + position_ids, + token_idx, + past_key_values: Optional[List[Tuple]] = None, + pixel_values: Optional[List[torch.Tensor]] = None, + image_sizes: Optional[List[Tuple[int, int]]] = None, + bypass_hpu_graph: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward - if batch.speculative_ids is not None: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "token_idx": token_idx, + "pixel_values": pixel_values, + "image_sizes": image_sizes + } - speculative_ids = batch.speculative_ids + hpu_kwargs = {} + # Optimum Habana got "lazy_mode" key-val only supported for llama type of models + if self.model.config.model_type == "llama" : + hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 - B, speculative_length = speculative_ids.shape - new_length = speculative_length + 1 - new_input_ids = torch.cat( - [input_ids.unsqueeze(-1), speculative_ids], dim=1 - ).reshape(-1) - arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) - arange_int = arange.to(dtype=torch.int32) - new_position_ids = ( - position_ids.unsqueeze(-1).expand(B, new_length) + arange - ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int - ).view(-1) + if self.has_position_ids: + kwargs["position_ids"] = position_ids - # Add Copy the block tables for all members - block_tables = ( - block_tables.unsqueeze(1) - .expand(B, new_length, -1) - .reshape(B * new_length, -1) - .contiguous() - ) - max_s = max_s + speculative_length + if bypass_hpu_graph != None: + hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - input_ids = new_input_ids - position_ids = new_position_ids + kwargs.update(self.kwargs) + model_inputs = self.model.prepare_inputs_for_generation(**kwargs) + if past_key_values is not None: + return self.model.forward(**model_inputs, **hpu_kwargs) else: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices + outputs = self.model.forward(**model_inputs, **hpu_kwargs) + return outputs.logits, outputs.past_key_values - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. - max_s = min(self.max_past(), max_s) + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batches: List[VlmCausalLMBatch] + ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # Results + generations: List[Generation] = [] + prev_batches = [] + requests_to_generate = [] + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. Collect next token ids of any previously started generations + for batch_id, batch in enumerate(batches): + if batch.logits is not None: + logits = batch.logits + past = batch.past + prefill = batch.past_key_values is None + if prefill: + # no right padding for prefill + token_idx_scalar = batch.attention_mask.shape[-1] - 1 + token_idx = torch.tensor(token_idx_scalar).to(self.device) + else: + token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding + token_idx = torch.tensor(token_idx_scalar).to(self.device) - bs = input_ids.shape[0] - # Try to find an associated cuda graph - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + # Select next token + input_length = batch.input_length + if logits.shape[-2] > 1: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate + ) + else: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits.squeeze(-2), self.speculate + ) + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + prev_batches.append({ + 'next_token_ids': next_token_ids, + 'next_token_logprobs': next_token_logprobs, + }) + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append({ + 'req': req, + 'prev_req_idx': req.idx, + 'batch_id': batch_id, + 'seed': batch.next_token_chooser.seeds[req_idx], + 'do_sample': batch.next_token_chooser.do_sample[req_idx], + 'top_n_tokens': batch.top_n_tokens[req_idx], + 'top_token_ids': batch_top_token_ids[req_idx], + 'top_token_logprobs': batch_top_token_logprobs[req_idx], + 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], + }) + + htorch.core.mark_step() + + # Add new token into input_ids + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask.index_fill_(1, token_idx, 1) + + # Adjust lengths + batch.input_length += 1 + + # Update position_ids + if prefill: + batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + else: + batch.position_ids += 1 + # Update past key values + if prefill: + batch.past_key_values = past + + htorch.core.mark_step() + + # Stage 2. Prepare new batch for speculative scheduling + if len(batches) > 1: + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) else: - cuda_graph = None - if cu_seqlen_prefill is not None or cuda_graph is None: - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, + batch = batches[0] + + prefill = batch.past_key_values is None + + # Check if we need to do any bookkeeping first + if not prefill: + batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) + + scenario = 'PREFILL' if prefill else 'GENERATE' + if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs: + self.model.clear_cache() + self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE) + dbg_trace( + scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') + #assert batch.right_padding > 0, 'No more room for next token!' + + # Execute batch + if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + batch.logits, batch.past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + batch.pixel_values, + batch.image_sizes, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + ) + elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): + # Don't schedule next forward if max_new_tokens for all requests equals 1 + # - we've already generated the first and only needed token in the prefill phase + pass + else: + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + batch.logits = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(-1) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + htorch.core.mark_step() - # Replay the graph - cuda_graph["graph"].replay() + start_decode = time.time_ns() - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + # Stage 3. Finish and return previous generations + stopped = len(requests_to_generate) > 0 + for prev_batch in prev_batches: + prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() + prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() + htorch.core.mark_step() + + for req_data in requests_to_generate: + req = req_data['req'] + i = req_data['prev_req_idx'] + prev_batch_id = req_data['batch_id'] + assert len(prev_batches) > prev_batch_id + next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] + next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] + + request = req.data + input_length = req.input_length + prefix_offset = req.prefix_offset + read_offset = req.read_offset + do_sample = req_data['do_sample'] + seed = req_data['seed'] + stopping_criteria = req.stopping_criteria + all_input_ids = req.all_input_ids + next_token_id = next_token_ids_cpu[i] + next_token_logprob = next_token_logprobs[i] + top_n_tokens = req_data['top_n_tokens'] + top_token_ids = req_data['top_token_ids'] + top_token_logprobs = req_data['top_token_logprobs'] + grammar_state = req_data['grammar_state'] + + # Append next token to all tokens + all_input_ids[input_length] = next_token_id + new_input_length = input_length + 1 + + # Generated token + if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: + next_token_text = '' + else: + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[0:new_input_length, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + if is_tokenizer_transparent(self.tokenizer): + output_text = None + else: + output_text = self.decode( + all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + next_token_logprobs + prefill_token_ids = all_input_ids[0: new_input_length - 1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id], + [next_token_logprob], + [next_token_text], + [next_token_id in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single_with_past_state( + req.idx, next_token_id, grammar_state + ) + ) + + req.all_input_ids = all_input_ids + req.input_length = new_input_length + req.prefix_offset = prefix_offset + req.read_offset = read_offset + + htorch.core.mark_step() + # self.step = self.step + 1 + # if self.hb_profiler is not None: + # if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: + # self.hb_profiler.stop() + # else: + # self.hb_profiler.step() + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch if not stopped else None, (forward_ns, decode_ns) + + def batch_from_pb(self, batch): + return VlmCausalLMBatch.from_pb_processor( + batch, + self.tokenizer, + self.processor, + self.model.config, + self.dtype, + self.device ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits + + def generate_warmup_batch(self, request, seq_len, batch_size): + batch = copy.deepcopy(request.batches[0]) + for req in batch.requests: + req.truncate = seq_len + + for i in range(len(batch.requests) - batch_size): + batch.requests.pop() + + return self.batch_from_pb(batch) + + def warmup(self, request) -> None: + batches = [self.batch_from_pb(batch) for batch in request.batches] + + try: + # prefill + _, prefill_batch, _ = self.generate_token([batches[0]]) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + f"Not enough memory to handle {len(batches[0].input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" + ) from e + + global BASE_IMAGE_TOKENS, PAD_SEQUENCE_TO_MULTIPLE_OF, PREFILL_BATCH_BUCKET_SIZE, PREFILL_GRAPH_NUM + max_input_length = batches[0].input_ids.shape[1] + max_batch_size = batches[0].input_ids.shape[0] + seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF + batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE + while batch_num > PREFILL_GRAPH_NUM : + PREFILL_BATCH_BUCKET_SIZE = PREFILL_BATCH_BUCKET_SIZE * 2 + os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE) + batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE + + while seq_num * batch_num >= PREFILL_GRAPH_NUM : + PAD_SEQUENCE_TO_MULTIPLE_OF = PAD_SEQUENCE_TO_MULTIPLE_OF * 2 + os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF) + seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF + + seq_lens_list = numpy.arange(BASE_IMAGE_TOKENS + PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length + 1, PAD_SEQUENCE_TO_MULTIPLE_OF).tolist() + batch_sizes_list = numpy.arange(PREFILL_BATCH_BUCKET_SIZE, max_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE).tolist() + for seq_len in seq_lens_list : + for batch_size in batch_sizes_list : + batch = self.generate_warmup_batch(request, seq_len, batch_size) + _, prefill_batch, _ = self.generate_token([batch]) + _, decode_batch, _ = self.generate_token([prefill_batch]) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 5184731f..0b5e9e03 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -96,8 +96,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch, self.model.tokenizer, self.model.dtype, self.model.device ) - batches = [batch_from_pb(batch) for batch in request.batches] - self.model.warmup(batches) + if self.model.batch_type in VLM_BATCH_TYPES : + self.model.warmup(request) + else: + batches = [batch_from_pb(batch) for batch in request.batches] + self.model.warmup(batches) return generate_pb2.WarmupResponse()