From 8188deac224ffed2d057d429d0f28d0a3a6f5744 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 7 Oct 2024 15:08:30 +0200
Subject: [PATCH] fix vlm and seq2seq

---
 .../test_grammar_response_format_llama.py    |  1 -
 .../layers/attention/flash_attn_triton.py    |  1 -
 .../models/custom_modeling/mllama.py         | 11 +++++++---
 .../models/seq2seq_lm.py                     |  2 +-
 .../models/vlm_causal_lm.py                  |  5 ++---
 .../text_generation_server/utils/adapter.py  | 21 +++++++++++--------
 6 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py
index 25bf9d98..eb3268ce 100644
--- a/integration-tests/models/test_grammar_response_format_llama.py
+++ b/integration-tests/models/test_grammar_response_format_llama.py
@@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
-
     class Weather(BaseModel):
         unit: str
         temperature: List[int]
diff --git a/server/text_generation_server/layers/attention/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py
index 3a6f9a73..fd180f0f 100644
--- a/server/text_generation_server/layers/attention/flash_attn_triton.py
+++ b/server/text_generation_server/layers/attention/flash_attn_triton.py
@@ -699,7 +699,6 @@ def check_args(
 
 
 class _attention(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index 6e091a74..be0a4b5d 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
         aspect_ratio_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
-            pixel_values.shape
-        )
+        (
+            batch_size,
+            num_concurrent_media,
+            num_tiles,
+            num_channels,
+            height,
+            width,
+        ) = pixel_values.shape
 
         pixel_values = pixel_values.reshape(
             batch_size * num_concurrent_media * num_tiles, num_channels, height, width
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 0a1d0824..42cd572a 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self.input_ids),
+            current_tokens=len(self.decoder_input_ids),
         )
 
     @classmethod
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 1a578d7b..7484e448 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -295,7 +295,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             postfix_lengths = batch.postfix_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
             speculative_ids = batch.speculative_ids
@@ -338,7 +338,7 @@ class VlmCausalLM(FlashCausalLM):
             slots = batch.slots[batch.slot_indices]
             postfix_lengths = batch.postfix_lengths_tensor
             prefix_lengths_tensor = batch.prefix_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -347,7 +347,6 @@ class VlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py
index 2b61f9bb..09254b68 100644
--- a/server/text_generation_server/utils/adapter.py
+++ b/server/text_generation_server/utils/adapter.py
@@ -120,15 +120,18 @@ def _load_and_merge(
         if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")
 
-        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
-            load_module_map(
-                model_id,
-                adapter.revision,
-                adapter.id,
-                adapter.path,
-                weight_names,
-                trust_remote_code,
-            )
+        (
+            module_map,
+            adapter_config,
+            adapter_weight_names,
+            adapter_tokenizer,
+        ) = load_module_map(
+            model_id,
+            adapter.revision,
+            adapter.id,
+            adapter.path,
+            weight_names,
+            trust_remote_code,
         )
 
         adapters_to_merge.append((module_map, adapter_config))