From fc02d99e57efea8b577e284ddb407122461e1105 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 15 Aug 2023 11:10:52 +0000
Subject: [PATCH] Much better code with correct RMS + rotary

Next:
- precompute cos/sin
- Fix KV layout
- Refactor everything for flash (long).
---
 .../custom_modeling/idefics_modeling.py          | 243 ++++++++++++------
 server/text_generation_server/models/idefics.py |  28 --
 .../models/idefics_causal_lm.py                  |  50 ++--
 server/text_generation_server/server.py         |  17 ++
 4 files changed, 204 insertions(+), 134 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index cbed601b..caf11fb0 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -28,7 +28,8 @@ from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
+from dataclasses import dataclass
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PretrainedConfig
 from transformers.utils import (
     add_start_docstrings,
@@ -47,6 +47,15 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     FastLinear,
 )
+import dropout_layer_norm
+
+@dataclass
+class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
+    image_hidden_states: Optional[torch.FloatTensor] = None
+
+@dataclass
+class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
+    image_hidden_states: Optional[torch.FloatTensor] = None
 
 
 # logger = logging.get_logger(__name__)
@@ -279,10 +288,30 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
 
 # this was adapted from LlamaRMSNorm
+# class IdeficsRMSNorm(nn.Module):
+#     def __init__(self, prefix, weights, eps=1e-6):
+#         """
+#         IdeficsRMSNorm is equivalent to T5LayerNorm
+#         """
+#         super().__init__()
+#
+#         weight = weights.get_tensor(f"{prefix}.weight")
+#         self.weight = nn.Parameter(weight)
+#         self.variance_epsilon = eps
+#
+#     def forward(self, hidden_states):
+#         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+#         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+#
+#         # convert into half-precision if necessary
+#         if self.weight.dtype in [torch.float16, torch.bfloat16]:
+#             hidden_states = hidden_states.to(self.weight.dtype)
+#
+#         return self.weight * hidden_states
 class IdeficsRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
-        IdeficsRMSNorm is equivalent to T5LayerNorm
+        LlamaRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
 
@@ -290,15 +319,56 @@ class IdeficsRMSNorm(nn.Module):
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 8192:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(
+                variance + self.variance_epsilon
+            )
 
-        return self.weight * hidden_states
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+
+            return self.weight * hidden_states
+        else:
+            # faster post attention rms norm
+            unwrap = False
+            if len(hidden_states.shape) > 2:
+                unwrap = True
+                shape = hidden_states.shape
+                hidden_states = hidden_states.reshape(-1, shape[-1])
+
+            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                None,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.variance_epsilon,
+                1.0,
+                0,
+                None,
+                False,
+                True,  # Activate RMSNorm
+            )
+            if res is None:
+                res = hidden_states
+
+            if unwrap:
+                normed_hidden_states = normed_hidden_states.view(*shape)
+
+            return normed_hidden_states
 
 
 # this was adapted from LlamaRotaryEmbedding
@@ -341,14 +411,14 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+#     gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+#     gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+#     cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+#     sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+#     q_embed = (q * cos) + (rotate_half(q) * sin)
+#     k_embed = (k * cos) + (rotate_half(k) * sin)
+#     return q_embed, k_embed
 
 
 # this was adapted from LlamaMLP
@@ -438,11 +508,9 @@ class IdeficsAttention(nn.Module):
         self.o_proj = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
         )
-        # self.rotary_emb = PositionRotaryEmbedding.load(
-        #     prefix=f"{prefix}.rotary_emb", weights=weights
-        # )
-        self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device) #TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash
-
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config, dim=self.head_dim, base=10000.0, device=weights.device
+        )
         self.qk_layer_norms = qk_layer_norms
         if self.qk_layer_norms:
             self.q_layer_norm = IdeficsRMSNorm(
@@ -470,11 +538,29 @@ class IdeficsAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
-        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)  # .transpose(1, 2)
         if not is_cross_attention:
-            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)  # .transpose(1, 2)
+            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)  # .transpose(1, 2)
+            kv_seq_len = q_len
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+            max_s = max(kv_seq_len, q_len)
+            cos, sin = self.rotary_emb.get_cos_sin(
+                position_ids.view(-1), max_s, hidden_states.dtype
+            )
+
+            shape = query_states.shape
+            query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape)
+
+            shape = key_states.shape
+            key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape)
+
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
         else:
+            query_states = query_states.transpose(1, 2)
             _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
             key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = (
@@ -484,14 +570,6 @@ class IdeficsAttention(nn.Module):
             kv_seq_len = key_states.shape[-2]
             if past_key_value is not None:
                 kv_seq_len += past_key_value[0].shape[-2]
-        if not is_cross_attention:
-            cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-            # cos, sin = self.rotary_emb.get_cos_sin(
-            #     position_ids=torch.arange(),
-            #     max_s=max(kv_seq_len, q_len),
-            #     dtype=hidden_states.dtype,
-            # )
         # [bsz, nh, t, hd]
 
         if past_key_value is not None:
@@ -951,13 +1029,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
+        image_hidden_states: Optional[torch.FloatTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPastImage]:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -999,30 +1078,37 @@ class IdeficsModel(IdeficsPreTrainedModel):
             position_ids = position_ids.view(-1, seq_length).long()
 
         no_images = False
-        if pixel_values is None and image_embeddings is None:
-            raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
 
-        elif pixel_values is not None and image_embeddings is not None:
-            raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
+        if image_hidden_states is None:
+            if pixel_values is None and image_embeddings is None:
+                raise ValueError("Either pixel_values or image_embeddings has to be not-None.")
 
-        elif pixel_values is not None:
-            no_images = len(torch.nonzero(pixel_values)) == 0
-            pixel_values = pixel_values.to(dtype=self.dtype, device=device)  # fp16 compatibility
-            batch_size, num_images = pixel_values.shape[:2]
-            pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
+            elif pixel_values is not None and image_embeddings is not None:
+                raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
 
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
+            elif pixel_values is not None:
+                no_images = len(torch.nonzero(pixel_values)) == 0
+                pixel_values = pixel_values.to(dtype=self.dtype, device=device)  # fp16 compatibility
+                batch_size, num_images = pixel_values.shape[:2]
+                pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
 
-        elif image_embeddings is not None:
-            batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
-            image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
-            image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
+                # Get sequence from the vision encoder
+                image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
 
-        if self.config.use_resampler:
-            image_hidden_states = self.perceiver_resampler(image_hidden_states)
-        image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
-        image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
+            elif image_embeddings is not None:
+                batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
+                image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
+                image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
+
+            if self.config.use_resampler:
+                image_hidden_states = self.perceiver_resampler(image_hidden_states)
+            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+            image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
+        else:
+            no_images = True
+            num_images = pixel_values.shape[1]
+            image_seq_len = image_hidden_states.shape[1] // num_images
 
         # # Hack to use the model in full language modeling mode
         # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
         # Make image_attention_mask compatible with hidden states
@@ -1030,15 +1116,19 @@ class IdeficsModel(IdeficsPreTrainedModel):
         image_attention_mask = image_attention_mask.unsqueeze(-1)
         image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
         image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
+        image_batch_size, image_sequence_length, _ = image_hidden_states.size()
+        image_hidden_shape = (image_batch_size, image_sequence_length)
+        if image_attention_mask is None:
+            image_attention_mask = torch.ones(image_hidden_shape, device=device)
+        image_attention_mask = self.invert_attention_mask(image_attention_mask)
 
-        if image_hidden_states is not None:
-            image_batch_size, image_sequence_length, _ = image_hidden_states.size()
-            image_hidden_shape = (image_batch_size, image_sequence_length)
-            if image_attention_mask is None:
-                image_attention_mask = torch.ones(image_hidden_shape, device=device)
-            image_attention_mask = self.invert_attention_mask(image_attention_mask)
-        else:
-            image_attention_mask = None
+        # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
+        #     raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
+
+        # if image_hidden_states is not None:
+        # else:
+        #     image_attention_mask = None
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1170,11 +1260,12 @@ class IdeficsModel(IdeficsPreTrainedModel):
         next_cache = next_decoder_cache if use_cache else None
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+        return BaseModelOutputWithPastImage(
            last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            image_hidden_states=image_hidden_states,
         )
 
 
@@ -1204,13 +1295,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
+        image_hidden_states: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, CausalLMOutputWithPastImage]:
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1244,11 +1336,6 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {input_ids.size()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {attention_mask.size()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {position_ids.size()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {pixel_values.size()=} {pixel_values.sum()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {image_attention_mask.size()=} {image_attention_mask.sum()=}")
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1257,6 +1344,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
             inputs_embeds=inputs_embeds,
             pixel_values=pixel_values,
             image_embeddings=image_embeddings,
+            image_hidden_states=image_hidden_states,
             image_attention_mask=image_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -1268,29 +1356,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         logits = self.lm_head(hidden_states)
 
         loss = None
-        # if labels is not None:
-        #     # Shift so that tokens < n predict n
-        #     if attention_mask is not None:
-        #         shift_attention_mask = attention_mask[..., 1:]
-        #         shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
-        #         shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
-        #     else:
-        #         shift_logits = logits[..., :-1, :].contiguous()
-        #         shift_labels = labels[..., 1:].contiguous()
-        #     # Flatten the tokens
-        #     loss_fct = CrossEntropyLoss()
-        #     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-        #     if not return_dict:
-        #         output = (logits,) + outputs[1:]
-        #         return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
+        return CausalLMOutputWithPastImage(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
         )
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index b30c5f80..b507da19 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -82,31 +82,3 @@ class IDEFICSSharded(IdeficsCausalLM):
             rank=rank,
             world_size=world_size,
         )
-
-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        position_ids,
-        pixel_values: Optional = None,
-        image_attention_mask: Optional = None,
-        past_key_values: Optional = None,
-    ) -> Tuple[
-        torch.Tensor,
-        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
-    ]:
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            pixel_values=pixel_values,
-            image_attention_mask=image_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-        return (
-            outputs.logits,
-            outputs.past_key_values,
-        )
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 74afc0a4..7245bc65 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -64,6 +64,7 @@ class IdeficsCausalLMBatch(Batch):
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
     pixel_values: Optional[torch.Tensor]
+    image_hidden_states: Optional[torch.Tensor]
    image_attention_mask: Optional[torch.Tensor]
     past_key_values: Optional[List[Tuple]]
 
@@ -118,7 +119,7 @@ class IdeficsCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
+            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@@ -135,7 +136,7 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
+            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
             if isinstance(inp, str):
                 prompts.append([inp])
             elif isinstance(inp, list):
@@ -168,7 +169,7 @@ class IdeficsCausalLMBatch(Batch):
             max_length=max_truncation,
             add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
-        from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
+        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
         # from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})  # {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
         for _ in pb.requests:
@@ -181,6 +182,7 @@ class IdeficsCausalLMBatch(Batch):
 
         input_ids = tokenized_inputs["input_ids"]
         pixel_values = tokenized_inputs["pixel_values"]
+        image_hidden_states = None
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
             (pb.size, max_input_length + padding_right_offset)
@@ -192,7 +194,7 @@ class IdeficsCausalLMBatch(Batch):
             (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
         )
         # image_attention_mask = tokenized_inputs["image_attention_mask"]
-        from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
+        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
 
@@ -209,6 +211,7 @@ class IdeficsCausalLMBatch(Batch):
             attention_mask=attention_mask,
             position_ids=position_ids,
             pixel_values=pixel_values,
+            image_hidden_states=image_hidden_states,
             image_attention_mask=image_attention_mask,
             past_key_values=None,
             all_input_ids=list(all_input_ids),
@@ -224,7 +227,7 @@ class IdeficsCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
-        from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
+        # from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -336,7 +339,7 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
-        from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
+        # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
@@ -433,8 +436,8 @@ class IdeficsCausalLMBatch(Batch):
                 :,
                 batch_left_offset : -batch.padding_right_offset,
             ]
-            from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
-            from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
+            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
+            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
             image_attention_mask[
                 start_index:end_index,
                 left_offset:-padding_right_offset,
@@ -648,6 +651,7 @@ class IdeficsCausalLM(Model):
         attention_mask,
         position_ids,
         pixel_values,
+        image_hidden_states,
         image_attention_mask,
         past_key_values: Optional = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -656,6 +660,7 @@ class IdeficsCausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": pixel_values,
+            "image_hidden_states": image_hidden_states,
             "image_attention_mask": image_attention_mask,
             "past_key_values": past_key_values,
             "use_cache": True,
@@ -665,13 +670,13 @@ class IdeficsCausalLM(Model):
             kwargs["position_ids"] = position_ids
 
         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: IdeficsCausalLMBatch
     ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
-        from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
+        # from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:
@@ -683,19 +688,21 @@ class IdeficsCausalLM(Model):
             image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1)  # TODO: verify that index. i have a doubt whether there is +1 hanging around
         else:
             image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
+        # image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
 
-        logits, past = self.forward(
+        logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
             pixel_values=batch.pixel_values,
+            image_hidden_states=batch.image_hidden_states,
             image_attention_mask=image_attention_mask,
             past_key_values=batch.past_key_values,
         )
@@ -806,7 +813,7 @@ class IdeficsCausalLM(Model):
 
                 # Update values
                 batch.input_ids[i, 0] = next_token_id
-                from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
+                # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
                 batch.all_input_ids[i] = all_input_ids
                 batch.input_lengths[i] = new_input_length
                 batch.prefix_offsets[i] = prefix_offset
@@ -819,7 +826,7 @@ class IdeficsCausalLM(Model):
 
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
 
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
@@ -832,6 +839,7 @@ class IdeficsCausalLM(Model):
 
         # Update past key values
         batch.past_key_values = past
+        batch.image_hidden_states = image_hidden_states
 
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
         return generations, batch
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index e7ed8875..9507d331 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
+PROFILE = False
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
@@ -27,6 +28,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             # Force inference mode for the lifetime of TextGenerationService
             self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
+        if PROFILE:
+            self.prof = torch.profiler.profile(
+                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
+                on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
+                record_shapes=True,
+                with_stack=True
+            )
+            self.prof.start()
+
     async def Info(self, request, context):
         return self.model.info
 
@@ -80,6 +90,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )
 
         generations, next_batch = self.model.generate_token(batch)
+        if PROFILE:
+            self.prof.step()
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -107,7 +119,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
 
         generations, next_batch = self.model.generate_token(batch)
+        if PROFILE:
+            self.prof.step()
         self.cache.set(next_batch)
+        if next_batch is None:
+            if PROFILE:
+                self.prof.stop()
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
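
A note on the IdeficsRMSNorm rewrite: dropout_layer_norm.dropout_add_ln_fwd is the fused kernel from the flash-attention dropout_layer_norm extension; invoked with dropout 0.0, no bias, and the final flag set, it performs the optional residual add and the RMS normalization in a single pass, and the plain-PyTorch branch is kept only for hidden sizes the kernel presumably does not support (> 8192). Below is a minimal pure-PyTorch reference of what that call is expected to compute, handy for checking the fused output against; the function name and signature are illustrative, not part of the kernel's API.

    import torch

    def rms_norm_reference(hidden_states, weight, residual=None, eps=1e-6):
        # Optional fused residual add, mirroring dropout_add_ln_fwd's contract.
        if residual is not None:
            hidden_states = hidden_states + residual
        dtype = hidden_states.dtype
        # Accumulate in fp32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        # RMSNorm: scale by the learned weight, no mean subtraction, no bias.
        return weight * hidden_states.to(dtype)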
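A note on the rotary change: the patch retires the gather-based apply_rotary_pos_emb (kept above as a comment) in favor of PositionRotaryEmbedding, which precomputes cos/sin per flattened position id and applies the rotation on [num_tokens, num_heads, head_dim] tensors, so no per-batch torch.gather is needed. The sketch below shows the standard rotate-half (GPT-NeoX-style, non-interleaved) formulation this is assumed to implement; the helper names are illustrative, not TGI's API.

    import torch

    def rotate_half(x):
        # Split the head dim into two halves and rotate them.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def get_cos_sin(position_ids, inv_freq, dtype):
        # position_ids: [num_tokens]; inv_freq: [head_dim // 2], the usual
        # 1 / 10000**(2i / head_dim) frequencies. Precomputable once per step.
        freqs = torch.outer(position_ids.float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype)[:, None, :], emb.sin().to(dtype)[:, None, :]

    def apply_rotary(x, cos, sin):
        # x: [num_tokens, num_heads, head_dim]; cos/sin broadcast over heads,
        # indexed by absolute position rather than gathered per batch entry.
        return (x * cos) + (rotate_half(x) * sin)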
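A note on threading image_hidden_states through the batch: the vision tower (and the perceiver resampler when config.use_resampler is set) now runs only when image_hidden_states is None, i.e. at prefill; every subsequent decode step passes the cached states back in and skips the image encoder entirely, exactly like past_key_values skips re-encoding the prompt. The loop below sketches that contract; the field names mirror the patch, but the driver function itself is illustrative.

    def generation_step(model, batch):
        # First call: batch.image_hidden_states is None, so the model encodes
        # batch.pixel_values and returns the resulting vision states.
        # Later calls: the cached states short-circuit the vision tower.
        logits, past, image_hidden_states = model.forward(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            position_ids=batch.position_ids,
            pixel_values=batch.pixel_values,
            image_hidden_states=batch.image_hidden_states,
            image_attention_mask=batch.image_attention_mask,
            past_key_values=batch.past_key_values,
        )
        # Cache both the KV cache and the vision states for the next step.
        batch.past_key_values = past
        batch.image_hidden_states = image_hidden_states
        return logits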
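A note on the PROFILE switch in server.py: torch.profiler.profile with a schedule only records during the active windows, and the schedule advances one state per prof.step() call, which is why both Prefill and Decode call self.prof.step() after generate_token. A standalone sketch of the same pattern, with the workload function as a placeholder:

    import torch

    def run_one_generation_step():
        # Placeholder for a real decode step such as model.generate_token(batch).
        torch.randn(64, 64) @ torch.randn(64, 64)

    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./log/idefics"),
        record_shapes=True,
        with_stack=True,
    )
    prof.start()
    for _ in range(12):  # covers 2 cycles of (1 wait + 1 warmup + 3 active)
        run_one_generation_step()
        prof.step()  # advance the wait/warmup/active schedule
    prof.stop()

The resulting traces land in ./log/idefics and can be inspected with TensorBoard's profile plugin (torch-tb-profiler).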