From b8be0d1ae7713157767aa8598c0906a26195434d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Apr 2024 16:41:01 +0000 Subject: [PATCH] Update by abstracting away text model. --- .../custom_modeling/flash_llama_modeling.py | 31 +- .../models/custom_modeling/llava_next.py | 306 ++- .../models/flash_causal_lm.py | 7 +- .../models/flash_llama.py | 3 +- .../models/flash_mistral.py | 58 +- .../text_generation_server/models/idefics.py | 4 +- .../models/idefics_causal_lm.py | 1706 +++++++++++++++++ .../models/llava_next.py | 71 +- .../models/vlm_causal_lm.py | 859 +-------- server/text_generation_server/server.py | 17 +- 10 files changed, 1951 insertions(+), 1111 deletions(-) create mode 100644 server/text_generation_server/models/idefics_causal_lm.py diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3a269fc0..efe05098 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -281,9 +281,8 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" self.self_attn = FlashLlamaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -337,27 +336,36 @@ class FlashLlamaLayer(nn.Module): class FlashLlamaModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=( + "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" + ), + weights=weights, ) self.layers = nn.ModuleList( [ FlashLlamaLayer( - layer_id, - config, - weights, + prefix=( + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}.model.layers.{layer_id}" + ), + config=config, + weights=weights, ) for layer_id in range(config.num_hidden_layers) ] ) self.norm = FastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix="model.norm" if not prefix else f"{prefix}.model.norm", + weights=weights, + eps=config.rms_norm_eps, ) self.gradient_checkpointing = False @@ -406,13 +414,13 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - self.model = FlashLlamaModel(config, weights) + self.model = FlashLlamaModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, - prefix="lm_head", + prefix="lm_head" if not prefix else f"{prefix}.lm_head", weights=weights, ) @@ -426,6 +434,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 2392db8b..8de11944 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ 
b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -113,7 +113,13 @@ def load_vision_model(prefix, config, weights): def load_text_model(prefix, config, weights): - if config.model_type == "mistral": + if config.model_type == "llama": + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + ) + + return FlashLlamaForCausalLM(prefix, config, weights) + elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) @@ -124,22 +130,25 @@ def load_text_model(prefix, config, weights): class LlavaNextForConditionalGeneration(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() config.vision_config.quantize = config.quantize - self.vision_tower = load_vision_model( - prefix="vision_tower", config=config.vision_config, weights=weights - ) + # self.vision_tower = load_vision_model( + # prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", config=config.vision_config, weights=weights + # ) - self.multi_modal_projector = LlavaNextMultiModalProjector(config) + # self.multi_modal_projector = LlavaNextMultiModalProjector(config) self.image_newline = weights.get_tensor("image_newline") self.vocab_size = config.text_config.vocab_size + self.config = config config.text_config.quantize = config.quantize config.text_config.use_medusa = config.use_medusa self.language_model = load_text_model( - prefix="language_model", config=config.text_config, weights=weights + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, ) self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 @@ -257,168 +266,141 @@ class LlavaNextForConditionalGeneration(nn.Module): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, vision_feature_layer: Optional[int] = None, vision_feature_select_strategy: Optional[str] = None, ): - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer + # vision_feature_layer = ( + # vision_feature_layer + # if vision_feature_layer is not None + # else self.config.vision_feature_layer + # ) + # vision_feature_select_strategy = ( + # vision_feature_select_strategy + # if vision_feature_select_strategy is not None + # else self.config.vision_feature_select_strategy + # ) + + # if cu_seqlen_prefill is not None: + # pass + # # # 1. Extract the input embeddings + # # inputs_embeds = self.get_input_embeddings()(input_ids) + + # # # 2. 
Merge text and images + # # if pixel_values is not None and input_ids.shape[1] != 1: + # # batch_size, num_patches, num_channels, height, width = ( + # # pixel_values.shape + # # ) + # # reshaped_pixel_values = pixel_values.view( + # # batch_size * num_patches, num_channels, height, width + # # ) + # # image_features = self.vision_tower( + # # reshaped_pixel_values, output_hidden_states=True + # # ) + + # # selected_image_feature = image_features.hidden_states[ + # # vision_feature_layer + # # ] + + # # if vision_feature_select_strategy == "default": + # # selected_image_feature = selected_image_feature[:, 1:] + # # elif vision_feature_select_strategy == "full": + # # selected_image_feature = selected_image_feature + + # # image_features = self.multi_modal_projector(selected_image_feature) + + # # # split up image_features for each of the individual images + # # # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # # # if we assume each image has 5 image features (base image + 4 patches) + # # split_sizes = [image.shape[0] for image in pixel_values] + # # image_features = torch.split(image_features, split_sizes, dim=0) + + # # # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + # # height = width = ( + # # self.config.vision_config.image_size + # # // self.config.vision_config.patch_size + # # ) + + # # new_image_features = [] + # # for image_idx, image_feature in enumerate(image_features): + # # if image_feature.shape[0] > 1: + # # base_image_feature = image_feature[0] + # # image_feature = image_feature[1:] + + # # if height * width != base_image_feature.shape[0]: + # # raise ValueError( + # # "The number of patches is not consistent with the image size." + # # ) + # # num_patch_height, num_patch_width = get_anyres_image_grid_shape( + # # image_sizes[image_idx], + # # self.config.image_grid_pinpoints, + # # self.config.vision_config.image_size, + # # ) + # # image_feature = image_feature.view( + # # num_patch_height, num_patch_width, height, width, -1 + # # ) + # # image_feature = image_feature.permute( + # # 4, 0, 2, 1, 3 + # # ).contiguous() + # # image_feature = image_feature.flatten(1, 2).flatten(2, 3) + # # image_feature = unpad_image( + # # image_feature, image_sizes[image_idx] + # # ) + # # image_feature = torch.cat( + # # ( + # # image_feature, + # # self.image_newline[:, None, None].expand( + # # *image_feature.shape[:-1], 1 + # # ), + # # ), + # # dim=-1, + # # ) + # # image_feature = image_feature.flatten(1, 2).transpose(0, 1) + # # image_feature = torch.cat( + # # (base_image_feature, image_feature), dim=0 + # # ) + # # else: + # # image_feature = image_feature[0] + # # image_feature = torch.cat( + # # (image_feature, self.image_newline[None]), dim=0 + # # ) + # # new_image_features.append(image_feature) + # # image_features = torch.stack(new_image_features, dim=0) + + # # inputs_embeds, attention_mask, labels, position_ids = ( + # # self._merge_input_ids_with_image_features( + # # image_features, inputs_embeds, input_ids, attention_mask, labels + # # ) + # # ) + # # if labels is None: + # # labels = torch.full_like( + # # attention_mask, self.config.ignore_index + # # ).to(torch.long) + + logits = self.language_model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + lm_head_indices, ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else 
self.config.vision_feature_select_strategy - ) - - if inputs_embeds is None: - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - - # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1: - batch_size, num_patches, num_channels, height, width = ( - pixel_values.shape - ) - reshaped_pixel_values = pixel_values.view( - batch_size * num_patches, num_channels, height, width - ) - image_features = self.vision_tower( - reshaped_pixel_values, output_hidden_states=True - ) - - selected_image_feature = image_features.hidden_states[ - vision_feature_layer - ] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [image.shape[0] for image in pixel_values] - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute( - 4, 0, 2, 1, 3 - ).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image( - image_feature, image_sizes[image_idx] - ) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds, attention_mask, labels, position_ids = ( - self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - ) - if labels is None: - labels = torch.full_like( - attention_mask, self.config.ignore_index - ).to(torch.long) - - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif ( - past_key_values is not None - and pixel_values is not None - and input_ids.shape[1] == 1 - ): - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: 
https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) - - # Get the target length - target_seqlen = first_layer_past_key_value.shape[-1] + 1 - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat( - (attention_mask, extended_attention_mask), dim=1 - ) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - ) - - logits = outputs[0] return logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5c25f341..22c2fa6c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1043,7 +1043,12 @@ class FlashCausalLM(Model): batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids + try: + batch.input_lengths_tensor += accepted_ids + except Exception: + import ipdb + + ipdb.set_trace() batch.slot_indices += accepted_ids if prefill and prefill_logprobs: diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index a2ac759a..56768942 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -67,7 +67,8 @@ class FlashLlama(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id, revision) - model = FlashLlamaForCausalLM(config, weights) + prefix = "" + model = FlashLlamaForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( model=model, diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 2e1055b2..8806942a 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -6,7 +6,7 @@ import numpy as np from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase, AutoTokenizer +from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig from transformers.models.llama import LlamaTokenizerFast from typing import Optional, Tuple, Type @@ -301,14 +301,15 @@ class FlashMistralBatch(FlashCausalLMBatch): class BaseFlashMistral(FlashCausalLM): def __init__( self, - config_cls, model_cls, model_id: str, + config_cls=AutoConfig, revision: Optional[str] = None, quantize: Optional[str] = None, use_medusa: 
Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -317,22 +318,13 @@ class BaseFlashMistral(FlashCausalLM): else: raise NotImplementedError("FlashMistral is only available on GPU") - try: - tokenizer = LlamaTokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - except Exception: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) config = config_cls.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code @@ -341,10 +333,12 @@ class BaseFlashMistral(FlashCausalLM): config.use_medusa = use_medusa # Set context windows - if config.sliding_window is not None: + if getattr(config, "sliding_window", None) is not None: set_sliding_window( config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) ) + else: + config.sliding_window = None torch.distributed.barrier(group=self.process_group) @@ -353,17 +347,19 @@ class BaseFlashMistral(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id, revision) - model = model_cls(config, weights) + prefix = "" + model = model_cls(prefix, config, weights) self.cuda_graphs = {} torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( + num_layers, num_kv_heads, head_size = self.get_layer_config(model) + super().__init__( model=model, tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_size=head_size, dtype=dtype, device=device, rank=rank, @@ -371,6 +367,16 @@ class BaseFlashMistral(FlashCausalLM): sliding_window=config.sliding_window, ) + def get_layer_config(self, model) -> Tuple[int, int, int]: + return ( + len(model.model.layers), + model.model.num_key_value_heads, + model.model.head_size, + ) + + def max_past(self) -> int: + return self.model.max_past + @property def batch_type(self) -> Type[FlashMistralBatch]: return FlashMistralBatch @@ -485,11 +491,11 @@ class BaseFlashMistral(FlashCausalLM): max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices - if cu_seqlen_prefill is None and self.model.max_past is not None: + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. 
- max_s = min(self.model.max_past, max_s) + max_s = min(self.max_past(), max_s) bs = input_ids.shape[0] padded_bs = bs diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 1a63934f..a90a0d96 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -17,7 +17,7 @@ from transformers import LlamaTokenizerFast from text_generation_server.models.custom_modeling.idefics_modeling import ( IdeficsForVisionText2Text, ) -from text_generation_server.models.vlm_causal_lm import VlmCausalLM +from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -25,7 +25,7 @@ from text_generation_server.utils import ( ) -class IDEFICSSharded(VlmCausalLM): +class IDEFICSSharded(IdeficsCausalLM): def __init__( self, model_id: str, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py new file mode 100644 index 00000000..d64e047b --- /dev/null +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -0,0 +1,1706 @@ +import torch +import torch +import time + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizerBase, + ProcessorMixin, +) +from typing import Optional, Tuple, List, Type, Dict + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.models.vlm_causal_lm import split + +import re + +IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") + + +tracer = trace.get_tracer(__name__) + + +@dataclass +class IdeficsCausalLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + pixel_values: Optional[torch.Tensor] + image_hidden_states: Optional[torch.Tensor] + image_attention_mask: Optional[torch.Tensor] + past_key_values: Optional[List[Tuple]] + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + # Metadata used for padding + max_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + # Past metadata + keys_head_dim_last: bool = True + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor: ProcessorMixin, # Hack + dtype: torch.dtype, + device: torch.device, + ) -> "IdeficsCausalLMBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + 
max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(r.inputs) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + prompts = [] + for inp in inputs: + # Each input is encoded into a list, where each element of this input list is either a string or a URL + prompts.append(split(inp)) + + # The processor replaces the call to tokenizer, and + # a/ takes care of fetching images from the URL + # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model + tokenized_inputs = processor( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_truncation, + add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append( + input_len - 5 + ) # To decode without potential fallbacks errors + read_offsets.append( + input_len + ) # To decode without potential fallbacks errors + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + + input_ids = tokenized_inputs["input_ids"] + pixel_values = tokenized_inputs.get("pixel_values", None) + image_hidden_states = None + # Allocate maximum attention_mask + attention_mask = input_ids.new_zeros( + (pb.size, max_input_length + padding_right_offset) + ) + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + # Do the same for image_attention_mask + if pixel_values is None: + image_attention_mask = None + else: + image_attention_mask = input_ids.new_zeros( + ( + pb.size, + max_input_length + padding_right_offset, + pixel_values.size(1), + ) + ) + image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ + "image_attention_mask" + ] + + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 + position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) + all_input_ids = tokenized_inputs["input_ids"].T.split( + 1, dim=1 + ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. 
It is then transformed into a list + + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: + # It deletes requests from the batch. For instance when client lost connection + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + position_ids = self.position_ids[keep_indices] + self.attention_mask = self.attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + ] + # Do the same for pixel_values and image_attention_mask + pixel_values = self.pixel_values[keep_indices] + self.image_attention_mask = self.image_attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.image_attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + :, + ] + if self.image_hidden_states is None: + image_hidden_states = None + else: + image_hidden_states = self.image_hidden_states[keep_indices] + + # Ensure that past_key_values tensors can be updated in-place + if type(self.past_key_values[0]) == tuple: + self.past_key_values = [list(layer) for layer in self.past_key_values] + + # Update tensors in-place to allow incremental garbage collection + past_kv_length = max_input_length - 1 + 
for layer in self.past_key_values: + past_keys, past_values = layer + if len(past_keys.shape) == 3: + # Force past to be of dim [self_size, num_heads, ...] for easy indexing + past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) + past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) + if self.keys_head_dim_last: + layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] + else: + layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] + del past_keys + layer[1] = past_values[keep_indices, :, -past_kv_length:, :] + del past_values + + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.pixel_values = pixel_values + self.image_hidden_states = image_hidden_states + self.position_ids = position_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + return self + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate( + cls, batches: List["IdeficsCausalLMBatch"] + ) -> "IdeficsCausalLMBatch": + # It adds new requests to the batch + # Used for padding + total_batch_size = 0 + max_input_length = 0 + max_num_images = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + max_num_images = max(max_num_images, batch.pixel_values.size(1)) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + max_tokens = 0 + + # Batch tensors + input_ids = None + attention_mask = None + position_ids = None + pixel_values = None + image_hidden_states = None + image_attention_mask = None + past_key_values = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + # Create padded tensor + if attention_mask is None: + 
attention_mask = batch.attention_mask.new_zeros( + (total_batch_size, max_input_length + padding_right_offset), + ) + + curr_batch_max_num_images = batch.pixel_values.size(1) + if pixel_values is None: + pixel_values = batch.pixel_values.new_zeros( + (total_batch_size, max_num_images, 3, 224, 224) + ) + pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( + batch.pixel_values + ) + + if image_attention_mask is None: + image_attention_mask = batch.image_attention_mask.new_zeros( + ( + total_batch_size, + max_input_length + padding_right_offset, + max_num_images, + ) + ) + + # We need to slice the attention mask to remove padding from previous steps + # and to remove unused allocated space + left_offset = max_input_length - batch.max_input_length + batch_left_offset = ( + batch.attention_mask.shape[1] + - batch.max_input_length + - batch.padding_right_offset + ) + attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = batch.attention_mask[ + :, + batch_left_offset : -batch.padding_right_offset, + ] + image_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + :curr_batch_max_num_images, + ] = batch.image_attention_mask[ + :, batch_left_offset : -batch.padding_right_offset, : + ] + + # Create empty tensor + # position_ids is always of shape [batch_size, 1] + if position_ids is None: + position_ids = batch.position_ids.new_empty((total_batch_size, 1)) + position_ids[start_index:end_index] = batch.position_ids + + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + # And ensure that we can update tensors in-place + if type(batch.past_key_values[0]) == tuple: + batch.past_key_values = [ + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + for layer in batch.past_key_values + ] + elif len(batch.past_key_values[0][0].shape) == 3: + for layer in batch.past_key_values: + for k, t in enumerate(layer): + layer[k] = t.view(len(batch), -1, *t.shape[-2:]) + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + start_index = end_index + + first_past_kvs = batches[0].past_key_values + _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape + + padded_past_values_shape = ( + total_batch_size, + num_heads, + max_input_length - 1, + head_dim, + ) + + if batches[0].keys_head_dim_last: + padded_past_keys_shape = padded_past_values_shape + else: + # seq_length is last for BLOOM + padded_past_keys_shape = ( + total_batch_size, + num_heads, + head_dim, + max_input_length - 1, + ) + + # Iterate over attention layers + # Concatenate past key values layer by layer to allow incremental garbage collection + for j in range(len(first_past_kvs)): + padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + if batch.keys_head_dim_last: + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( + past_keys[:, :, -past_seq_len:, :] + ) + else: + # BLOOM case + 
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( + past_keys[:, :, :, -past_seq_len:] + ) + del past_keys + + start_index = end_index + + padded_past_values = first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] + ) + del past_values + + # Update values + start_index = end_index + + past_key_values.append([padded_past_keys, padded_past_values]) + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=past_key_values, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + ) + + def __len__(self): + return len(self.requests) + + +class IdeficsCausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + from text_generation_server.models.custom_modeling.idefics_modeling import ( + IdeficsForVisionText2Text, + ) + + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + model = IdeficsForVisionText2Text.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map=( + "auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None + ), + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": ""}) + + super(IdeficsCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> 
Type[IdeficsCausalLMBatch]: + return IdeficsCausalLMBatch + + def forward( + self, + input_ids, + attention_mask, + position_ids, + pixel_values, + image_hidden_states, + image_attention_mask, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_hidden_states": image_hidden_states, + "image_attention_mask": image_attention_mask, + "past_key_values": past_key_values, + "use_cache": True, + "return_dict": True, + } + if self.has_position_ids: + kwargs["position_ids"] = position_ids + + outputs, speculative_logits = self.model.forward(**kwargs) + return ( + outputs.logits, + speculative_logits, + outputs.past_key_values, + outputs.image_hidden_states, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: IdeficsCausalLMBatch + ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # slice the attention mask to the correct shape + attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + if batch.image_attention_mask is None: + image_attention_mask = None + else: + if batch.input_ids.size(1) == 1: + # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), + # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension + # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated + # token need to attend to the encoder hidden states (i.e. the vision encoder) + # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic + image_attention_mask = batch.image_attention_mask[ + :, -(batch.padding_right_offset + 1) + ].unsqueeze(1) + else: + image_attention_mask = batch.image_attention_mask[ + :, : -batch.padding_right_offset + ] + + logits, speculative_logits, past, image_hidden_states = self.forward( + input_ids=batch.input_ids, + attention_mask=attention_mask, + position_ids=batch.position_ids, + pixel_values=batch.pixel_values, + image_hidden_states=batch.image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=batch.past_key_values, + ) + # Hardcoded remove image tokens + logits[:, 32000:32001] = torch.finfo(logits.dtype).min + + start_decode = time.time_ns() + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = 
stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( + batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] + ) + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past + batch.image_hidden_states = image_hidden_states + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) + + +import time + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizerBase, + ProcessorMixin, +) +from typing import Optional, Tuple, List, Type, Dict + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import 
generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling + +import re + +tracer = trace.get_tracer(__name__) + + +@dataclass +class IdeficsCausalLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + pixel_values: Optional[torch.Tensor] + image_hidden_states: Optional[torch.Tensor] + image_attention_mask: Optional[torch.Tensor] + past_key_values: Optional[List[Tuple]] + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + # Metadata used for padding + max_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + # Past metadata + keys_head_dim_last: bool = True + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor: ProcessorMixin, # Hack + dtype: torch.dtype, + device: torch.device, + ) -> "IdeficsCausalLMBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(r.inputs) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + # TODO Check impact on idefics + prompts = [] + for inp in inputs: + # Each input is encoded into a list, where each element of this input list is either a string or a URL + prompts.append(split(inp)) + + # The processor replaces the call to tokenizer, and + # a/ takes care of fetching images from the URL + # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model + tokenized_inputs = processor( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_truncation, + # TODO Check impact on idefics + # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append( + input_len - 5 + ) # To decode without potential fallbacks errors + read_offsets.append( + input_len + ) # To decode without potential fallbacks errors + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + + input_ids = tokenized_inputs["input_ids"] + pixel_values = tokenized_inputs.get("pixel_values", None) + image_hidden_states = None + # Allocate maximum attention_mask + attention_mask = 
input_ids.new_zeros( + (pb.size, max_input_length + padding_right_offset) + ) + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + # Do the same for image_attention_mask + if pixel_values is None: + image_attention_mask = None + else: + image_attention_mask = input_ids.new_zeros( + ( + pb.size, + max_input_length + padding_right_offset, + pixel_values.size(1), + ) + ) + image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ + "image_attention_mask" + ] + + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 + position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) + all_input_ids = tokenized_inputs["input_ids"].T.split( + 1, dim=1 + ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list + + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: + # It deletes requests from the batch. 
For instance when client lost connection + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + position_ids = self.position_ids[keep_indices] + self.attention_mask = self.attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + ] + # Do the same for pixel_values and image_attention_mask + pixel_values = self.pixel_values[keep_indices] + self.image_attention_mask = self.image_attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.image_attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + :, + ] + if self.image_hidden_states is None: + image_hidden_states = None + else: + image_hidden_states = self.image_hidden_states[keep_indices] + + # Ensure that past_key_values tensors can be updated in-place + if type(self.past_key_values[0]) == tuple: + self.past_key_values = [list(layer) for layer in self.past_key_values] + + # Update tensors in-place to allow incremental garbage collection + past_kv_length = max_input_length - 1 + for layer in self.past_key_values: + past_keys, past_values = layer + if len(past_keys.shape) == 3: + # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing + past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) + past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) + if self.keys_head_dim_last: + layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] + else: + layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] + del past_keys + layer[1] = past_values[keep_indices, :, -past_kv_length:, :] + del past_values + + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.pixel_values = pixel_values + self.image_hidden_states = image_hidden_states + self.position_ids = position_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + return self + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate( + cls, batches: List["IdeficsCausalLMBatch"] + ) -> "IdeficsCausalLMBatch": + # It adds new requests to the batch + # Used for padding + total_batch_size = 0 + max_input_length = 0 + max_num_images = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + max_num_images = max(max_num_images, batch.pixel_values.size(1)) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + max_tokens = 0 + + # Batch tensors + input_ids = None + attention_mask = None + position_ids = None + pixel_values = None + image_hidden_states = None + image_attention_mask = None + past_key_values = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + # Create padded tensor + if attention_mask is None: + attention_mask = batch.attention_mask.new_zeros( + (total_batch_size, max_input_length + padding_right_offset), + ) + + curr_batch_max_num_images = 
batch.pixel_values.size(1) + if pixel_values is None: + pixel_values = batch.pixel_values.new_zeros( + (total_batch_size, max_num_images, 3, 224, 224) + ) + pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( + batch.pixel_values + ) + + if image_attention_mask is None: + image_attention_mask = batch.image_attention_mask.new_zeros( + ( + total_batch_size, + max_input_length + padding_right_offset, + max_num_images, + ) + ) + + # We need to slice the attention mask to remove padding from previous steps + # and to remove unused allocated space + left_offset = max_input_length - batch.max_input_length + batch_left_offset = ( + batch.attention_mask.shape[1] + - batch.max_input_length + - batch.padding_right_offset + ) + attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = batch.attention_mask[ + :, + batch_left_offset : -batch.padding_right_offset, + ] + image_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + :curr_batch_max_num_images, + ] = batch.image_attention_mask[ + :, batch_left_offset : -batch.padding_right_offset, : + ] + + # Create empty tensor + # position_ids is always of shape [batch_size, 1] + if position_ids is None: + position_ids = batch.position_ids.new_empty((total_batch_size, 1)) + position_ids[start_index:end_index] = batch.position_ids + + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + # And ensure that we can update tensors in-place + if type(batch.past_key_values[0]) == tuple: + batch.past_key_values = [ + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + for layer in batch.past_key_values + ] + elif len(batch.past_key_values[0][0].shape) == 3: + for layer in batch.past_key_values: + for k, t in enumerate(layer): + layer[k] = t.view(len(batch), -1, *t.shape[-2:]) + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + start_index = end_index + + first_past_kvs = batches[0].past_key_values + _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape + + padded_past_values_shape = ( + total_batch_size, + num_heads, + max_input_length - 1, + head_dim, + ) + + if batches[0].keys_head_dim_last: + padded_past_keys_shape = padded_past_values_shape + else: + # seq_length is last for BLOOM + padded_past_keys_shape = ( + total_batch_size, + num_heads, + head_dim, + max_input_length - 1, + ) + + # Iterate over attention layers + # Concatenate past key values layer by layer to allow incremental garbage collection + for j in range(len(first_past_kvs)): + padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + if batch.keys_head_dim_last: + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( + past_keys[:, :, -past_seq_len:, :] + ) + else: + # BLOOM case + padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( + past_keys[:, :, :, -past_seq_len:] + ) + del past_keys + + start_index = end_index + + padded_past_values = 
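The attention-mask copy a little earlier in concatenate() right-aligns each batch's prompt region inside the merged, padded mask, using left_offset for the target slice and batch_left_offset for the source slice. A standalone sketch of that geometry with hypothetical sizes (the all-ones source mask is only there to make the copied window visible; the real batch mask has its own zeros):

import torch

# Merged batch geometry (hypothetical)
max_input_length, padding_right_offset = 7, 3
attention_mask = torch.zeros(2, max_input_length + padding_right_offset)

# One contributing batch whose prompts are shorter than the merged maximum
batch_max_input_length, batch_padding_right_offset = 5, 2
batch_mask = torch.ones(2, batch_max_input_length + batch_padding_right_offset)

left_offset = max_input_length - batch_max_input_length
batch_left_offset = (
    batch_mask.shape[1] - batch_max_input_length - batch_padding_right_offset
)

# Copy the valid prompt positions so they end right where decoding will start
attention_mask[:, left_offset:-padding_right_offset] = batch_mask[
    :, batch_left_offset : -batch_padding_right_offset
]
print(attention_mask[0])  # tensor([0., 0., 1., 1., 1., 1., 1., 0., 0., 0.])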
first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] + ) + del past_values + + # Update values + start_index = end_index + + past_key_values.append([padded_past_keys, padded_past_values]) + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=past_key_values, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + ) + + def __len__(self): + return len(self.requests) + + +class IdeficsCausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + from text_generation_server.models.custom_modeling.idefics_modeling import ( + IdeficsForVisionText2Text, + ) + + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + model = IdeficsForVisionText2Text.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map=( + "auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None + ), + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": ""}) + + super(IdeficsCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> Type[IdeficsCausalLMBatch]: + return IdeficsCausalLMBatch + + def forward( + self, + input_ids, + attention_mask, + position_ids, + pixel_values, + image_hidden_states, + 
image_attention_mask, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_hidden_states": image_hidden_states, + "image_attention_mask": image_attention_mask, + "past_key_values": past_key_values, + "use_cache": True, + "return_dict": True, + } + if self.has_position_ids: + kwargs["position_ids"] = position_ids + + outputs, speculative_logits = self.model.forward(**kwargs) + return ( + outputs.logits, + speculative_logits, + outputs.past_key_values, + outputs.image_hidden_states, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: IdeficsCausalLMBatch + ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # slice the attention mask to the correct shape + attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + if batch.image_attention_mask is None: + image_attention_mask = None + else: + if batch.input_ids.size(1) == 1: + # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), + # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension + # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated + # token need to attend to the encoder hidden states (i.e. the vision encoder) + # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic + image_attention_mask = batch.image_attention_mask[ + :, -(batch.padding_right_offset + 1) + ].unsqueeze(1) + else: + image_attention_mask = batch.image_attention_mask[ + :, : -batch.padding_right_offset + ] + + logits, speculative_logits, past, image_hidden_states = self.forward( + input_ids=batch.input_ids, + attention_mask=attention_mask, + position_ids=batch.position_ids, + pixel_values=batch.pixel_values, + image_hidden_states=batch.image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=batch.past_key_values, + ) + # Hardcoded remove image tokens + logits[:, 32000:32001] = torch.finfo(logits.dtype).min + + start_decode = time.time_ns() + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust 
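generate_token above hard-masks the image sentinel token id (32000) so it can never be produced during decoding. A tiny standalone illustration of why writing the dtype minimum into a logit column removes that token from the distribution (the vocab size here is hypothetical):

import torch

image_token_id = 32000           # sentinel id masked out in generate_token
logits = torch.randn(2, 32002)   # hypothetical [batch, vocab] logits

# The dtype minimum underflows to probability 0 after softmax, so neither
# greedy decoding nor sampling can ever pick the image token
logits[:, image_token_id : image_token_id + 1] = torch.finfo(logits.dtype).min
print(torch.softmax(logits, -1)[:, image_token_id])  # tensor([0., 0.])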
sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( + batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] + ) + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past + batch.image_hidden_states = image_hidden_states + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py index 9e35d0fc..0ae1b46d 100644 --- a/server/text_generation_server/models/llava_next.py +++ b/server/text_generation_server/models/llava_next.py @@ -1,24 +1,15 @@ import torch -import torch.distributed -from typing import List, Optional, Tuple +from typing import Optional from transformers import ( - AutoTokenizer, - AutoConfig, AutoProcessor, ) from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) -# from transformers import AutoConfig, AutoTokenizer, 
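Stepping back to the decode bookkeeping at the end of IdeficsCausalLM.generate_token above: the attention mask is allocated once with padding_right_offset extra slots, and each decode step switches the next slot on and shrinks the offset. A standalone sketch with hypothetical sizes showing that right-padding window being consumed:

import torch

prompt_len, padding_right_offset = 4, 3   # hypothetical request
attention_mask = torch.zeros(1, prompt_len + padding_right_offset)
attention_mask[:, :prompt_len] = 1        # prompt positions

for step in range(3):                     # three decode steps
    # Mark the slot of the token generated at this step as attended
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1
    print(step, attention_mask[0].tolist())
# 0 [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
# 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
# 2 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]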
AutoProcessor from text_generation_server.models.vlm_causal_lm import VlmCausalLM -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) class LlavaNext(VlmCausalLM): @@ -31,59 +22,15 @@ class LlavaNext(VlmCausalLM): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - # 9b seems to work correctly enough in float16, but 80b seems - # to be really saturating for f16. - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - self.device, self.dtype = device, dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.use_medusa = use_medusa - config.vision_config.quantize = quantize - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) self.processor = AutoProcessor.from_pretrained( - model_id, + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + super().__init__( + model_cls=LlavaNextForConditionalGeneration, + model_id=model_id, revision=revision, - padding_side="left", - truncation_side="left", + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, trust_remote_code=trust_remote_code, ) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - ) - - model = LlavaNextForConditionalGeneration(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(VlmCausalLM, self).__init__( - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index a7cfc3e4..00088a4d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,869 +1,50 @@ -import torch -import time +import re -from dataclasses import dataclass from opentelemetry import trace -from transformers import ( - AutoProcessor, - AutoTokenizer, - PreTrainedTokenizerBase, - ProcessorMixin, -) from typing import Optional, Tuple, List, Type, Dict -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, +from text_generation_server.models.flash_mistral import ( + BaseFlashMistral, + FlashMistralBatch, ) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -import re +tracer = trace.get_tracer(__name__) IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") -def split(string): +def split(string) -> List[Dict[str, str]]: parts = [] cursor = 0 for pattern in IMAGES.finditer(string): start = pattern.start() if start != cursor: - parts.append(string[cursor:start]) + parts.append({"type": "text", "content": string[cursor:start]}) - parts.append(pattern.group(1)) + parts.append({"type": "image", "content": pattern.group(1)}) cursor = pattern.end() if 
cursor != len(string): - parts.append(string[cursor:]) + parts.append({"type": "text", "content": string[cursor:]}) return parts -tracer = trace.get_tracer(__name__) +class VlmCausalLMBatch(FlashMistralBatch): + pass -@dataclass -class VlmCausalLMBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - pixel_values: Optional[torch.Tensor] - image_hidden_states: Optional[torch.Tensor] - image_attention_mask: Optional[torch.Tensor] - past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor: ProcessorMixin, # Hack - dtype: torch.dtype, - device: torch.device, - ) -> "VlmCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.inputs) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - # TODO Check impact on idefics - # prompts = [] - # for inp in inputs: - # # Each input is encoded into a list, where each element of this input list is either a string or a URL - # prompts.append(split(inp)) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - tokenized_inputs = processor( - inputs, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - # TODO Check impact on idefics - # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append( - input_len - 5 - ) # To decode without potential fallbacks errors - read_offsets.append( - input_len - ) # To decode without potential fallbacks errors - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - pixel_values = tokenized_inputs.get("pixel_values", None) - image_hidden_states = None - # 
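The split() helper added near the top of this vlm_causal_lm.py hunk turns a prompt containing markdown image syntax into an ordered list of typed chunks. A consolidated, runnable copy with a quick usage example (the prompt text and URL are made up):

import re
from typing import Dict, List

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

def split(string) -> List[Dict[str, str]]:
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append({"type": "text", "content": string[cursor:start]})
        parts.append({"type": "image", "content": pattern.group(1)})
        cursor = pattern.end()
    if cursor != len(string):
        parts.append({"type": "text", "content": string[cursor:]})
    return parts

print(split("What is in ![img](https://example.com/cat.png) this picture?"))
# [{'type': 'text', 'content': 'What is in '},
#  {'type': 'image', 'content': 'https://example.com/cat.png'},
#  {'type': 'text', 'content': ' this picture?'}]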
Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - # Do the same for image_attention_mask - if pixel_values is None: - image_attention_mask = None - else: - image_attention_mask = input_ids.new_zeros( - ( - pb.size, - max_input_length + padding_right_offset, - pixel_values.size(1), - ) - ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ - "image_attention_mask" - ] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split( - 1, dim=1 - ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list - - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["VlmCausalLMBatch"]: - # It deletes requests from the batch. 
For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - # Do the same for pixel_values and image_attention_mask - pixel_values = self.pixel_values[keep_indices] - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] - if self.image_hidden_states is None: - image_hidden_states = None - else: - image_hidden_states = self.image_hidden_states[keep_indices] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.pixel_values = pixel_values - self.image_hidden_states = image_hidden_states - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["VlmCausalLMBatch"]) -> "VlmCausalLMBatch": - # It adds new requests to the batch - # Used for padding - total_batch_size = 0 - max_input_length = 0 - max_num_images = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - max_num_images = max(max_num_images, batch.pixel_values.size(1)) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - pixel_values = None - image_hidden_states = None - image_attention_mask = None - past_key_values = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - curr_batch_max_num_images = batch.pixel_values.size(1) - if 
pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) - ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values - ) - - if image_attention_mask is None: - image_attention_mask = batch.image_attention_mask.new_zeros( - ( - total_batch_size, - max_input_length + padding_right_offset, - max_num_images, - ) - ) - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] - ) - else: - # BLOOM case - padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] - ) - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros( 
- padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values - - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) - - def __len__(self): - return len(self.requests) - - -class VlmCausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - from text_generation_server.models.custom_modeling.idefics_modeling import ( - VlmForVisionText2Text, - ) - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.bfloat16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = VlmForVisionText2Text.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": ""}) - - super(VlmCausalLM, self).__init__( - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - +class VlmCausalLM(BaseFlashMistral): @property - def batch_type(self) -> Type[VlmCausalLMBatch]: - return VlmCausalLMBatch + def batch_type(self) -> Type[FlashMistralBatch]: + return FlashMistralBatch - def forward( - self, - input_ids, - attention_mask, - position_ids, - 
pixel_values, - image_hidden_states, - image_attention_mask, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "image_hidden_states": image_hidden_states, - "image_attention_mask": image_attention_mask, - "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, - } - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - outputs, speculative_logits = self.model.forward(**kwargs) + def get_layer_config(self, model) -> Tuple[int, int, int]: return ( - outputs.logits, - speculative_logits, - outputs.past_key_values, - outputs.image_hidden_states, + len(model.language_model.model.layers), + model.language_model.model.num_key_value_heads, + model.language_model.model.head_size, ) - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: VlmCausalLMBatch - ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - if batch.image_attention_mask is None: - image_attention_mask = None - else: - if batch.input_ids.size(1) == 1: - # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), - # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension - # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated - # token need to attend to the encoder hidden states (i.e. the vision encoder) - # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[ - :, -(batch.padding_right_offset + 1) - ].unsqueeze(1) - else: - image_attention_mask = batch.image_attention_mask[ - :, : -batch.padding_right_offset - ] - - logits, speculative_logits, past, image_hidden_states = self.forward( - input_ids=batch.input_ids, - attention_mask=attention_mask, - position_ids=batch.position_ids, - pixel_values=batch.pixel_values, - image_hidden_states=batch.image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=batch.past_key_values, - ) - # Hardcoded remove image tokens - logits[:, 32000:32001] = torch.finfo(logits.dtype).min - - start_decode = time.time_ns() - - # Results - generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # 
Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( - next_token_id_squeezed.item() - ) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - batch.position_ids = batch.position_ids[:, -1:] + 1 - - # Update past key values - batch.past_key_values = past - batch.image_hidden_states = image_hidden_states - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) + def max_past(self) -> Optional[int]: + return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 308fb1e2..beca36b9 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -13,9 +13,10 @@ from typing import List, Optional from text_generation_server.cache 
import Cache from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model +from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch +from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): @@ -78,9 +79,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): except ImportError: pass - if ( - self.model.batch_type == VlmCausalLMBatch - ): # Hack, i would rather use kwargs in the `from_pb` call + if self.model.batch_type in { + IdeficsCausalLMBatch, + VlmCausalLMBatch, + }: # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, @@ -100,9 +102,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Prefill(self, request, context): start = time.time_ns() - if ( - self.model.batch_type == VlmCausalLMBatch - ): # Hack, i would rather use kwargs in the `from_pb` call + if self.model.batch_type in { + IdeficsCausalLMBatch, + VlmCausalLMBatch, + }: # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer,
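The two server.py hunks above special-case the multimodal batch types so the processor gets forwarded into from_pb alongside the tokenizer. A toy, self-contained sketch of that dispatch (the classes below are stand-ins, not the real TGI batch types):

class TextBatch:
    """Stand-in for a text-only batch type: from_pb takes no processor."""
    @classmethod
    def from_pb(cls, pb, tokenizer, dtype, device):
        return cls()

class MultimodalBatch:
    """Stand-in for IdeficsCausalLMBatch / VlmCausalLMBatch."""
    @classmethod
    def from_pb(cls, pb, tokenizer, processor, dtype, device):
        return cls()

MULTIMODAL_BATCH_TYPES = {MultimodalBatch}

def build_batch(batch_type, pb, tokenizer, processor, dtype, device):
    # Mirrors the `if self.model.batch_type in {...}` check in the hunks above
    if batch_type in MULTIMODAL_BATCH_TYPES:
        return batch_type.from_pb(pb, tokenizer, processor, dtype, device)
    return batch_type.from_pb(pb, tokenizer, dtype, device)

print(type(build_batch(MultimodalBatch, None, None, None, None, None)).__name__)
# MultimodalBatch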