From 6c350f2f75a05bd2b43b9814ebbe49d3a932798b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Apr 2024 15:27:29 +0000 Subject: [PATCH] Working for TP, Llama + Mistral Still unsolved: - Rust parameter validation (to calculate number of tokens). - Integration test. - Validate other text heads. - Quantization. --- .../models/custom_modeling/clip.py | 7 +- .../custom_modeling/flash_llama_modeling.py | 16 ++-- .../custom_modeling/flash_mistral_modeling.py | 39 ++------- .../models/custom_modeling/llava_next.py | 9 +- .../models/vlm_causal_lm.py | 86 +++++++++++++++++-- server/text_generation_server/server.py | 2 + 6 files changed, 108 insertions(+), 51 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 994611f0..c4917733 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -123,6 +123,7 @@ class CLIPAttention(nn.Module): f" {self.num_heads})." ) self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() self.scale = self.head_size**-0.5 self.dropout = config.attention_dropout @@ -137,7 +138,7 @@ class CLIPAttention(nn.Module): config, prefix=f"{prefix}.out_proj", weights=weights, - bias=False, + bias=True, ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -155,7 +156,7 @@ class CLIPAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + bsz, tgt_len, _ = hidden_states.size() # get query proj @@ -225,7 +226,7 @@ class CLIPAttention(nn.Module): attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size) attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c5e1c06a..4cf0fcf2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -342,12 +342,6 @@ class FlashLlamaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=( - "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" - ), - weights=weights, - ) self.layers = nn.ModuleList( [ FlashLlamaLayer( @@ -384,6 +378,8 @@ class FlashLlamaModel(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -417,6 +413,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix, config, weights): super().__init__() + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) self.model = FlashLlamaModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, @@ -447,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module): slots, input_lengths, 
max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 352907b8..cab72f63 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -348,9 +348,6 @@ class MistralModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) self.layers = nn.ModuleList( [ MistralLayer( @@ -373,34 +370,7 @@ class MistralModel(torch.nn.Module): def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - return self.with_hidden_states( - hidden_states, - position_ids, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - true_max_s, - prefill_cache_indices, - ) - - def with_hidden_states( - self, - hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -411,6 +381,7 @@ class MistralModel(torch.nn.Module): true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], ): + hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -441,6 +412,9 @@ class FlashMistralForCausalLM(torch.nn.Module): def __init__(self, prefix, config, weights): super().__init__() + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.model.embed_tokens", weights=weights + ) self.model = MistralModel( prefix="model" if not prefix else f"{prefix}.model", config=config, @@ -480,8 +454,9 @@ class FlashMistralForCausalLM(torch.nn.Module): # kernel requires the true values input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( - input_ids, + inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index c9f2268d..ed21a52b 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -142,7 +142,10 @@ class LlavaNextForConditionalGeneration(nn.Module): vision_config = config.vision_config # Instead of selecting in hidden_states[-2]. 
# Instead compute only the n -2 + 1 layers and don't pool - vision_config.num_hidden_layers += config.vision_feature_layer + 1 + if config.vision_feature_layer < 0: + vision_config.num_hidden_layers += config.vision_feature_layer + 1 + else: + vision_config.num_hidden_layers = config.vision_feature_layer + 1 self.vision_tower = load_vision_model( prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", config=config.vision_config, @@ -195,7 +198,7 @@ class LlavaNextForConditionalGeneration(nn.Module): pixel_values: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.language_model.model.embed_tokens(input_ids) + inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" @@ -290,6 +293,8 @@ class LlavaNextForConditionalGeneration(nn.Module): slots=slots, input_lengths=input_lengths, max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index c6f88013..c732b89e 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -5,6 +5,7 @@ from opentelemetry import trace from typing import Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase +from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_mistral import ( BaseFlashMistral, @@ -36,14 +37,73 @@ def split(string) -> List[Dict[str, str]]: return parts +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def get_number_of_features(height: int, width: int, config) -> int: + # From config + # Hardcoded for CLIP for now + # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + image_grid_pinpoints = config.image_grid_pinpoints + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + + assert image_size % patch_size == 0 + + npatches = image_size // patch_size + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + [height, width], + image_grid_pinpoints, + image_size, + ) + import math + + height_of_patch = math.ceil(height / width * npatches) + + unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width + # They are only added after width + newline_features = height_of_patch * num_patch_width + # The base patch covers the entire image + base_features = npatches**2 + return unpadded_features + newline_features + base_features + if height == 640 and width == 640: + return 2928 + return 2634 + + +# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}" +# assert get_number_of_features(640, 640) == 2928 + + class VlmCausalLMBatch(FlashMistralBatch): pixel_values: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor): + def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): batch_inputs = [] - images = [] + image_inputs = [] max_truncation = 0 for r in requests: chunks = split(r.inputs) @@ -52,8 +112,13 @@ class VlmCausalLMBatch(FlashMistralBatch): if chunk["type"] == "text": full_text += chunk["content"] elif chunk["type"] == "image": - full_text += "" * 2928 - images.append(chunk["content"]) + image = chunk["content"] + image = processor.image_processor.fetch_images(image) + image_input = processor.image_processor(image, return_tensors="pt") + height, width = image_input["image_sizes"][0] + num_features = get_number_of_features(height, width, config) + full_text += "" * num_features + image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk['type']}") @@ -63,9 +128,13 @@ class VlmCausalLMBatch(FlashMistralBatch): batch_tokenized_inputs = tokenizer( batch_inputs, truncation=True, max_length=max_truncation )["input_ids"] - images = processor.image_processor.fetch_images(images) - if images: - image_inputs = processor.image_processor(images, return_tensors="pt") + if image_inputs: + image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]), + } else: image_inputs = None return batch_tokenized_inputs, image_inputs @@ -76,11 +145,12 @@ class VlmCausalLMBatch(FlashMistralBatch): pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, processor, + config, dtype: torch.dtype, device: torch.device, ) -> "VlmCausalLMBatch": batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor + pb.requests, tokenizer, processor, config ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) if image_inputs is not None: diff --git a/server/text_generation_server/server.py 
b/server/text_generation_server/server.py index cb9b302d..495c2c0c 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -87,6 +87,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.processor, + self.model.model.config, self.model.dtype, self.model.device, ) @@ -110,6 +111,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.processor, + self.model.model.config, self.model.dtype, self.model.device, )
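
Note for reviewers: the per-image token count produced by get_number_of_features above is the quantity the Rust-side parameter validation will eventually have to mirror (the first unsolved item in the commit message). The snippet below is a minimal, standalone sketch of that calculation for reference. It assumes the stock LLaVA-Next vision tower (CLIP-ViT-L/14-336, i.e. image_size=336, patch_size=14) and the default image_grid_pinpoints, and it re-implements transformers' select_best_resolution locally so it runs without the server; the helper names are illustrative, not the ones the server uses.

import math
from typing import List, Tuple

# Assumed defaults for the stock LLaVA-Next / CLIP-ViT-L/14-336 vision tower.
IMAGE_SIZE = 336
PATCH_SIZE = 14
IMAGE_GRID_PINPOINTS: List[List[int]] = [
    [336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008],
]


def select_best_resolution(
    original_size: Tuple[int, int], possible_resolutions: List[List[int]]
) -> Tuple[int, int]:
    # Local re-implementation of transformers' helper: pick the candidate
    # resolution that keeps the most effective pixels while wasting the least.
    original_height, original_width = original_size
    best_fit, max_effective, min_wasted = None, 0, float("inf")
    for height, width in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        down_w, down_h = int(original_width * scale), int(original_height * scale)
        effective = min(down_w * down_h, original_width * original_height)
        wasted = height * width - effective
        if effective > max_effective or (
            effective == max_effective and wasted < min_wasted
        ):
            max_effective, min_wasted, best_fit = effective, wasted, (height, width)
    return best_fit


def number_of_image_tokens(height: int, width: int) -> int:
    # 24x24 vision patches per 336x336 tile for CLIP-ViT-L/14-336.
    npatches = IMAGE_SIZE // PATCH_SIZE
    # How many 336x336 tiles the chosen "anyres" resolution splits into.
    best_h, best_w = select_best_resolution((height, width), IMAGE_GRID_PINPOINTS)
    num_patch_height = best_h // IMAGE_SIZE
    num_patch_width = best_w // IMAGE_SIZE
    # Feature rows that survive after cropping the aspect-ratio padding.
    height_of_patch = math.ceil(height / width * npatches)
    unpadded = npatches * height_of_patch * num_patch_height * num_patch_width
    # One newline token is appended per feature row.
    newline = height_of_patch * num_patch_width
    # The base (downscaled) image always contributes a full 24x24 grid.
    base = npatches**2
    return unpadded + newline + base


# Expected values from the commented-out asserts in the diff above.
assert number_of_image_tokens(889, 1024) == 2634
assert number_of_image_tokens(640, 640) == 2928

Running this yields 2634 tokens for an 889x1024 image and 2928 for a 640x640 one, the two values the diff checks for and that were previously hardcoded in batch_tokenized_inputs.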