From 53b2bea6b99d0dfcf1c088853adcf45dfadfe171 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 18 Dec 2024 03:25:22 +0000
Subject: [PATCH] feat: improve prefix for idefics3

---
 .../custom_modeling/flash_llama_modeling.py | 36 +++++++++++--------
 .../models/vlm_causal_lm.py                 | 26 ++++++--------
 2 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 63207938..d2c4f751 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -515,7 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"),
+                prefix=f"{prefix}.layers.0" if prefix else "model.layers.0",
                 config=config,
                 weights=weights,
             )
@@ -532,9 +532,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaCrossLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -545,9 +545,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -560,9 +560,9 @@ class FlashLlamaModel(torch.nn.Module):
             FlashLlamaLayer(
                 index=last_layer_id,
                 prefix=(
-                    f"model.layers.{last_layer_id}"
-                    if not prefix
-                    else f"{prefix}.layers.{last_layer_id}"
+                    f"{prefix}.layers.{last_layer_id}"
+                    if prefix
+                    else f"model.layers.{last_layer_id}"
                 ),
                 config=config,
                 weights=weights,
@@ -570,7 +570,7 @@ class FlashLlamaModel(torch.nn.Module):
         )

         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.norm",
+            prefix=f"{prefix}.norm" if prefix else "model.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
@@ -630,11 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()

+        if config.model_type == "mllama_text_model":
+            prefix = f"{prefix}.model"
+
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
-                prefix=(
-                    "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens"
-                ),
+                prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"),
                 weights=weights,
             )
         self.model = FlashLlamaModel(prefix, config, weights)
@@ -648,6 +649,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier

+        if config.model_type == "mllama_text_model":
+            prefix = prefix.replace(".model", "")
+            suffix = f"{prefix}.{suffix}"
+
+        if config.model_type == "granite":
+            suffix = f"{prefix}.{suffix}"
+
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 0548fbc6..306da497 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -167,10 +167,6 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
-    if config.model_type == "idefics3":
-        return text.replace(
-            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
-        )
     return text


@@ -290,8 +286,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None

-        batch_inputs = []
-        max_truncation = 0
+        batch_tokenized_inputs = []
+        max_length = 0
         image_id = 0
         for r in requests:
             full_text = ""
@@ -306,16 +302,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     image_id += 1

             full_text = image_text_replacement_fixup(config, full_text)
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=not config.model_type == "paligemma",
-        )["input_ids"]
+            input_ids = tokenizer(
+                full_text,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"]
+            max_length = max(max_length, len(input_ids))
+            batch_tokenized_inputs.append(input_ids)

         return batch_tokenized_inputs, image_inputs