diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ab830b58b..54f57bcc4 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -206,6 +206,9 @@ try: from text_generation_server.models.transformers_flash_causal_lm import ( TransformersFlashCausalLM, ) + from text_generation_server.models.transformers_flash_causal_vlm import ( + TransformersFlashCausalVLM, + ) except ImportError: FLASH_TRANSFORMERS_BACKEND = False @@ -1483,17 +1486,32 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == QWEN2_VL: - return VlmCausalLM( - model_id=model_id, - model_class=Qwen2VLForConditionalGeneration, - revision=revision, + # return VlmCausalLM( + # model_id=model_id, + # model_class=Qwen2VLForConditionalGeneration, + # revision=revision, + # quantize=quantize, + # speculator=speculator, + # dtype=dtype, + # default_dtype=torch.bfloat16, + # kv_cache_dtype=kv_cache_dtype, + # trust_remote_code=trust_remote_code, + # lora_adapter_ids=lora_adapter_ids, + # ) + from transformers import Qwen2VLForConditionalGeneration, AutoProcessor + + # CUDA_LAUNCH_BLOCKING=1 ATTENTION=flashinfer PREFIX_CACHING=False MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve Qwen/Qwen2-VL-2B-Instruct + # cargo run --bin text-generation-router --release -- --tokenizer-name Qwen/Qwen2-VL-2B-Instruct --max-batch-prefill-tokens 256 + return TransformersFlashCausalVLM.fallback( + model_id, + revision, quantize=quantize, speculator=speculator, dtype=dtype, - default_dtype=torch.bfloat16, - kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + processor_class=AutoProcessor, + model_class=Qwen2VLForConditionalGeneration, ) if model_type == QWEN2_5_VL: return VlmCausalLM( diff --git a/server/text_generation_server/models/transformers_flash_causal_vlm.py b/server/text_generation_server/models/transformers_flash_causal_vlm.py new file mode 100644 index 000000000..4468e2762 --- /dev/null +++ b/server/text_generation_server/models/transformers_flash_causal_vlm.py @@ -0,0 +1,368 @@ +import math +from typing import List, Optional, Dict, Tuple, Any, ContextManager +from collections import namedtuple + +import torch +from opentelemetry import trace +from transformers import AutoTokenizer, AutoModelForCausalLM +import transformers.modeling_utils +from contextlib import nullcontext + +from text_generation_server.models.flash_causal_lm import FlashCausalLM +from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch +from text_generation_server.utils import initialize_torch_distributed + +from text_generation_server.layers.attention import paged_attention, attention, Seqlen +from text_generation_server.layers.attention.kv_cache import KVScales, KVCache +from text_generation_server.models.globals import ATTENTION +from text_generation_server.models.metadata_kernels import block_tables_to_ragged + + +tracer = trace.get_tracer(__name__) + + +def tgi_flash_attention_forward( + module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers + kv_cache: List[KVCache], + kv_head_mapping: torch.Tensor, + slots: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + seqlen: Seqlen, + block_tables: torch.Tensor, + max_s: int, + kv_scales: KVScales, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, # This is needed to "absorb" other args passed by Transformers modeling +): + kv_cache = kv_cache[module.layer_idx] + query_states = query_states.transpose(1, 2).squeeze(dim=0) + key_states = key_states.transpose(1, 2).squeeze(dim=0) + value_states = value_states.transpose(1, 2).squeeze(dim=0) + + # Take care of updating the cache in-place + kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales) + + _, num_heads, head_dim = query_states.shape + softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + sliding_window = -1 if sliding_window is None else sliding_window + + if cu_seqlen_prefill is not None: + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=softmax_scale, + window_size_left=sliding_window, + softcap=softcap, + ) + else: + attn_output = paged_attention( + query_states, + kv_cache, + kv_head_mapping, + softmax_scale, + block_tables, + seqlen, + max_s, + kv_scales=kv_scales, + softcap=softcap, + ) + + attn_output = attn_output.view(-1, num_heads * head_dim) + + return attn_output, None + + +transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward + +# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, +# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache +# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due +# to internal constraints it was not (yet?) possible to circumvent +REPLICATED_ATTENTION_MODELS = [ + "olmo2", + "phi3", +] + + +class TransformersFlashCausalVLM(VlmCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + kv_cache_dtype: Optional[torch.dtype] = None, + processor_class=None, + processor_kwargs=None, + model_class=AutoModelForCausalLM, + ): + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") + + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + dtype = default_dtype if dtype is None else dtype + else: + raise ValueError( + "Flash `Transformers` modeling backend is not available on cpu." + ) + + # Initialize processor + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = VlmCausalLMBatch + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + model = model_class.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + # attn_implementation="tgi", + # TODO: prefer custom implementation + attn_implementation="sdpa", + device_map=device if world_size == 1 else None, + tp_plan="auto" if world_size > 1 else None, + ) + + torch.distributed.barrier(group=self.process_group) + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None and isinstance( + model.config.eos_token_id, int + ): + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + # VLM models define the config we care about in their text_config + text_config = getattr(model.config, "text_config", None) + if text_config is not None: + config = text_config + else: + config = model.config + + self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads + self.num_kv_heads = getattr(config, "num_key_value_heads", self.num_heads) + self.head_size = config.hidden_size // config.num_attention_heads + + # Skip it for models in the exception list + if model.config.model_type not in REPLICATED_ATTENTION_MODELS: + self.num_heads = self.num_heads // self.process_group.size() + self.num_kv_heads = ( + self.num_kv_heads // self.process_group.size() + if self.num_kv_heads > 1 + else self.num_kv_heads + ) + + self.cuda_graphs = {} + self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_prefill_state, + create_decode_state, + create_prefill_with_paged_kv_state, + ) + + self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) + + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + + self.num_groups = self.num_heads // self.num_kv_heads + + # Those will never change and will be used in the forwards + self.kv_head_mapping = torch.arange( + 0, self.num_kv_heads, dtype=torch.int32, device=device + ).repeat_interleave(self.num_groups) + # This means no scale + self.kv_scales = KVScales( + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device), + ) + + # Skip FlashCausalLM init. + super(FlashCausalLM, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + sliding_window=getattr(config, "sliding_window", None), + support_chunking=False, + ) + + # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code + # We first copy the original model.forward because we still need it in the monkey patch + self.model.original_forward = self.model.forward + self.model.forward = self._model_forward + + self.model.get_position_ids = self._get_position_ids + + text_model = namedtuple("TextModel", ["max_past"]) + + # model has text_model.max_past + self.model.text_model = text_model(max_past=1024) + + torch.distributed.barrier(group=self.process_group) + + @classmethod + def fallback( + cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + processor_class=None, + processor_kwargs=None, + model_class=AutoModelForCausalLM, + ): + return cls( + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + processor_class=processor_class, + processor_kwargs=processor_kwargs, + model_class=model_class, + ) + + def warmup(self, batch, max_input_tokens, max_total_tokens): + return super().warmup(batch, max_input_tokens, max_total_tokens) + # max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + # 5000, 5000, 5001 + # ) + # return max_supported_total_tokens, max_input_tokens, max_total_tokens + + # TODO: may need to be updated + def _get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor): + # return a 4D tensor of shape (batch_size, num_heads, num_kv_heads, head_size) + # return torch.arange(input_ids.shape[-1], device=input_ids.device).view( + # 1, 1, 1, -1 + # ) + return torch.arange(input_ids.shape[-1], device=input_ids.device) + + def _model_forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[KVCache], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + lm_head_indices: Optional[torch.Tensor], + pixel_values=None, + pixel_attention_mask=None, + image_sizes=None, + image_grid_thw=None, + prefill_cache_indices=None, # not used, but passed to match original signature + adapter_data=None, # not supported, but passed to match original signature + ): + print("\n----") + # TODO: adjust the values to the correct ones + # initalized to None as they are expected by the original forward + attention_mask = None + past_key_values = None + inputs_embeds = None + labels = None + use_cache = None + output_attentions = None + output_hidden_states = None + return_dict = None + pixel_values_videos = None + video_grid_thw = None + rope_deltas = None + cache_position = None + + if pixel_values is not None: + print("pixel_values shape", pixel_values.shape) + + logits = self.model.original_forward( + input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers + attention_mask=attention_mask, + position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ).logits.squeeze(dim=0) + + print("logits shape", logits.shape) + + # dummy logic to avoid errors (must be correct size (batch_size, seq_len, vocab_size)) + # vocab_size = (logits.shape[0], 151936) + vocab_size = (1, 151936) + dummy_logits = torch.zeros(vocab_size, device=input_ids.device) + + # add the logits to the dummy_logits to the first index only + dummy_logits[0, :logits.shape[1]] = logits[0] + + print("dummy_logits shape", dummy_logits.shape) + + return dummy_logits, None