diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ab830b58..2b6ea31a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -206,9 +206,15 @@ try: from text_generation_server.models.transformers_flash_causal_lm import ( TransformersFlashCausalLM, ) -except ImportError: + from text_generation_server.models.transformers_flash_vlm import ( + TransformersFlashVlmCausalLM, + ) +except ImportError as e: + log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}") FLASH_TRANSFORMERS_BACKEND = False +FLASH_ATTENTION = False + class ModelType(enum.Enum): DEEPSEEK_V2 = { @@ -1173,12 +1179,13 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) elif FLASH_TRANSFORMERS_BACKEND: - return TransformersFlashCausalLM.fallback( + return TransformersFlashVlmCausalLM.fallback( model_id, + # AutoModelForConditionalGeneration, revision, quantize=quantize, speculator=speculator, - dtype=dtype, + dtype=torch.bfloat16, trust_remote_code=trust_remote_code, ) elif sharded: @@ -1483,6 +1490,15 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == QWEN2_VL: + return TransformersFlashVlmCausalLM.fallback( + model_id, + # AutoModelForConditionalGeneration, + revision, + quantize=quantize, + speculator=speculator, + dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + ) return VlmCausalLM( model_id=model_id, model_class=Qwen2VLForConditionalGeneration, @@ -1563,23 +1579,33 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: - if FLASH_ATTENTION: - return VlmCausalLM( - model_id=model_id, - model_class=PaliGemmaForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - kv_cache_dtype=kv_cache_dtype, - # Works better for these models - default_dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, - ) - else: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + # return TransformersFlashVlmCausalLM.fallback( + # model_id, + # # AutoModelForConditionalGeneration, + # revision, + # quantize=quantize, + # speculator=speculator, + # dtype=torch.bfloat16, + # trust_remote_code=trust_remote_code, + # batch_class=PaliGemmaBatch, + # ) + # if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, + ) + # else: + # raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d3a83e27..c7c5a374 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1344,9 +1344,6 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch - def max_past(self) -> int: - return getattr(self.model, "max_past", None) - def init_kv_cache( self, num_blocks: int, @@ -1792,12 +1789,6 @@ class FlashCausalLM(Model): max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. - max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) if sorted_padded_bs: diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 19696372..4cea5a59 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -6,6 +6,7 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} + ATTENTION = os.environ["ATTENTION"] # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in { diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 8773bfd3..58bc4b2d 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -36,10 +36,12 @@ def tgi_flash_attention_forward( softcap: Optional[float] = None, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): + # from pdb import set_trace; set_trace() kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) key_states = key_states.transpose(1, 2).squeeze(dim=0) value_states = value_states.transpose(1, 2).squeeze(dim=0) + # from pdb import set_trace; set_trace() # Take care of updating the cache in-place kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales) @@ -47,6 +49,7 @@ def tgi_flash_attention_forward( _, num_heads, head_dim = query_states.shape softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale sliding_window = -1 if sliding_window is None else sliding_window + # from pdb import set_trace; set_trace() if cu_seqlen_prefill is not None: attn_output = attention( @@ -72,6 +75,7 @@ def tgi_flash_attention_forward( max_s, kv_scales=kv_scales, softcap=softcap, + window_size_left=sliding_window, ) attn_output = attn_output.view(-1, num_heads * head_dim) @@ -104,6 +108,7 @@ class TransformersFlashCausalLM(FlashCausalLM): tokenizer_class=AutoTokenizer, kv_cache_dtype: Optional[torch.dtype] = None, ): + # # from pdb import set_trace; set_trace() self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() @@ -157,7 +162,14 @@ class TransformersFlashCausalLM(FlashCausalLM): self.num_layers = model.config.num_hidden_layers self.num_heads = model.config.num_attention_heads self.num_kv_heads = model.config.num_key_value_heads - self.head_size = model.config.hidden_size // model.config.num_attention_heads + # Some models use GQA and different sizes for o_proj + # and q_proj, that allows for that. + if hasattr(model.config, "head_dim"): + self.head_size = model.config.head_dim + else: + self.head_size = ( + model.config.hidden_size // model.config.num_attention_heads + ) # Skip it for models in the exception list if model.config.model_type not in REPLICATED_ATTENTION_MODELS: @@ -254,6 +266,7 @@ class TransformersFlashCausalLM(FlashCausalLM): prefill_cache_indices=None, # not used, but passed to match original signature adapter_data=None, # not supported, but passed to match original signature ): + # from pdb import set_trace; set_trace() # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py new file mode 100644 index 00000000..d4945ef0 --- /dev/null +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -0,0 +1,467 @@ +import math +from typing import List, Optional + +import torch +from opentelemetry import trace +from transformers import AutoTokenizer, AutoProcessor +import transformers.modeling_utils + +from text_generation_server.models.flash_causal_lm import FlashCausalLM +from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch +from text_generation_server.utils import initialize_torch_distributed + +from text_generation_server.layers.attention import paged_attention, attention, Seqlen +from text_generation_server.layers.attention.kv_cache import KVScales, KVCache +from text_generation_server.models.globals import ATTENTION +import torch.nn.functional as F + +tracer = trace.get_tracer(__name__) + + +def tgi_flash_attention_forward( + module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers + kv_cache: List[KVCache], + kv_head_mapping: torch.Tensor, + slots: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + seqlen: Seqlen, + block_tables: torch.Tensor, + max_s: int, + kv_scales: KVScales, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, # This is needed to "absorb" other args passed by Transformers modeling +): + from loguru import logger + + logger.info("Using TGI Flash Attention") + # from pdb import set_trace; set_trace() + kv_cache = kv_cache[module.layer_idx] + query_states = query_states.transpose(1, 2).squeeze(dim=0) + key_states = key_states.transpose(1, 2).squeeze(dim=0) + value_states = value_states.transpose(1, 2).squeeze(dim=0) + # from pdb import set_trace; set_trace() + + # Take care of updating the cache in-place + kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales) + + _, num_heads, head_dim = query_states.shape + softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + sliding_window = -1 if sliding_window is None else sliding_window + # if module.layer_idx == 0: + # from pdb import set_trace; set_trace() + + if cu_seqlen_prefill is not None: + attention_mask = None + if attention_mask is None: + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=softmax_scale, + window_size_left=sliding_window, + softcap=softcap, + ) + else: + from loguru import logger + + logger.info("uSING FLASH ATTENTION") + lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1] + max_length = max(lengths) + attention_mask = attention_mask[:, :, :, :max_length] + enable_gqa = query_states.shape[1] != key_states.shape[1] + # Split tensors using vectorized split + query_list = torch.split(query_states, lengths.tolist(), dim=0) + key_list = torch.split(key_states, lengths.tolist(), dim=0) + value_list = torch.split(value_states, lengths.tolist(), dim=0) + + padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True) + padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True) + padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True) + + padded_query = padded_query.transpose(1, 2).contiguous() + padded_key = padded_key.transpose(1, 2).contiguous() + padded_value = padded_value.transpose(1, 2).contiguous() + + # Compute attention + attn_output = F.scaled_dot_product_attention( + padded_query, + padded_key, + padded_value, + attn_mask=attention_mask, + scale=softmax_scale, + enable_gqa=enable_gqa, + ) + + attn_output = attn_output.transpose( + 1, 2 + ) # [batch_size, seq_len, num_heads, head_dim] + max_seq_len = padded_query.size(2) + seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze( + 0 + ) + lengths_tensor = torch.tensor( + lengths, device=padded_query.device + ).unsqueeze(1) + mask = seq_range < lengths_tensor # [batch, max_seq_len] + attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim] + + else: + attn_output = paged_attention( + query_states, + kv_cache, + kv_head_mapping, + softmax_scale, + block_tables, + seqlen, + max_s, + kv_scales=kv_scales, + softcap=softcap, + window_size_left=sliding_window, + ) + + attn_output = attn_output.view(-1, num_heads * head_dim) + + return attn_output, None + + +transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward +transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = ( + transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"] +) +# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"] +# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"] +transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = ( + tgi_flash_attention_forward +) +transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ + "tgi" +] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ + "eager" +] + +# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, +# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache +# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due +# to internal constraints it was not (yet?) possible to circumvent +REPLICATED_ATTENTION_MODELS = [ + "olmo2", + "phi3", +] + + +class TransformersFlashVlmCausalLM(VlmCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + processor_class=AutoProcessor, + processor_kwargs=None, + kv_cache_dtype: Optional[torch.dtype] = None, + batch_class=VlmCausalLMBatch, + ): + # # from pdb import set_trace; set_trace() + self.batch_class = VlmCausalLMBatch + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") + + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + dtype = default_dtype if dtype is None else dtype + else: + raise ValueError( + "Flash `Transformers` modeling backend is not available on cpu." + ) + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + if processor_kwargs is None: + processor_kwargs = {} + # processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + from transformers import Qwen2VLForConditionalGeneration + + model = Qwen2VLForConditionalGeneration.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + attn_implementation="tgi", + device_map=device if world_size == 1 else None, + tp_plan="auto" if world_size > 1 else None, + ) + + torch.distributed.barrier(group=self.process_group) + self.config = model.config + config = model.config + + # VLM models define the config we care about in their text_config + text_config = getattr(model.config, "text_config", None) + if text_config is not None: + config = text_config + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None and isinstance( + model.config.eos_token_id, int + ): + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + # Some models use GQA and different sizes for o_proj + # and q_proj, that allows for that. + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = config.hidden_size // config.num_attention_heads + + # Skip it for models in the exception list + if config.model_type not in REPLICATED_ATTENTION_MODELS: + self.num_heads = self.num_heads // self.process_group.size() + self.num_kv_heads = ( + self.num_kv_heads // self.process_group.size() + if self.num_kv_heads > 1 + else self.num_kv_heads + ) + + self.cuda_graphs = {} + self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_prefill_state, + create_decode_state, + create_prefill_with_paged_kv_state, + ) + + self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) + + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + + self.num_groups = self.num_heads // self.num_kv_heads + + # Those will never change and will be used in the forwards + self.kv_head_mapping = torch.arange( + 0, self.num_kv_heads, dtype=torch.int32, device=device + ).repeat_interleave(self.num_groups) + # This means no scale + self.kv_scales = KVScales( + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device), + ) + + # Skip FlashCausalLM init. + super(FlashCausalLM, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code + # We first copy the original model.forward because we still need it in the monkey patch + self.model.original_forward = self.model.forward + self.model.forward = self._model_forward + self.model.get_position_ids = self.get_position_ids + + torch.distributed.barrier(group=self.process_group) + + def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor): + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.config.vision_config.spatial_merge_size + vision_start_token_id = self.config.vision_start_token_id + vision_end_token_id = self.config.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + # import ipdb + + # ipdb.set_trace() + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + @classmethod + def fallback( + cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + batch_class: Optional[type] = VlmCausalLMBatch, + ): + return cls( + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + batch_class=batch_class, + ) + + def _model_forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[KVCache], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + lm_head_indices: Optional[torch.Tensor], + prefill_cache_indices=None, # not used, but passed to match original signature + adapter_data=None, # not supported, but passed to match original signature + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + ): + # from pdb import set_trace; set_trace() + # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers + logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 + + # This is equivalent to `self.model.forward`, see the monkey patch in __init__ + logits = ( + self.model.original_forward( + input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers + position_ids=position_ids.transpose(0, 1).unsqueeze( + 1 + ), # expand dim to fit Transformers + past_key_values=None, # we use self.kv_cache instead of transformers cache object + use_cache=False, # we use self.kv_cache instead of transformers cache object + logits_to_keep=logits_to_keep, + return_dict=True, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + kv_head_mapping=self.kv_head_mapping, + kv_scales=self.kv_scales, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + .logits.squeeze(dim=0)[lm_head_indices] + .unsqueeze(0) + ) + + # from pdb import set_trace; set_trace() + + return logits, None diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 9111fdc0..adb14c6a 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -372,9 +372,6 @@ class VlmCausalLM(FlashCausalLM): def batch_type(self) -> Type[VlmCausalLMBatch]: return self.batch_class - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) - def forward( self, batch: VlmCausalLMBatch, @@ -442,12 +439,6 @@ class VlmCausalLM(FlashCausalLM): ) batch.position_ids = position_ids - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. - max_s = min(self.max_past(), max_s) - # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])