diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 7525940a..4a96d27a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -660,7 +660,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
 
-        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
 
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py
index 81e03943..580398cb 100644
--- a/server/text_generation_server/models/custom_modeling/idefics3.py
+++ b/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -19,7 +19,6 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import math
 
 from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 8937c1c5..db78341d 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -30,7 +30,7 @@ IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
 
 
 # copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
-def get_image_prompt_string(
+def _prompt_split_image(
     *,
     image_seq_len: int,
     image_rows: int,
@@ -97,7 +97,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
             ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
             / (config.scale_factor**2)
         )
-        image_str = get_image_prompt_string(
+        image_str = _prompt_split_image(
             image_seq_len=image_seq_len,
             image_rows=n_rows,
             image_cols=n_cols,
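
Note on the rename in vlm_causal_lm.py: per the `# copied from:` comment, the helper mirrors `_prompt_split_image` in the cited transformers file (processing_idefics3.py#L44-L60), so the new name matches upstream. As a rough sketch of the prompt string it builds, based on that upstream source and the keyword-only signature visible in the diff; treat the body as an assumption, not the exact TGI code:

    def _prompt_split_image(
        *,
        image_seq_len: int,
        image_rows: int,
        image_cols: int,
        fake_token_around_image: str,
        image_token: str,
        global_img_token: str,
    ) -> str:
        # Sketch after the upstream transformers helper this diff cites.
        # One <row_i_col_j> marker plus `image_seq_len` image tokens per
        # tile, emitted row by row.
        text = ""
        for n_h in range(image_rows):
            for n_w in range(image_cols):
                text += (
                    fake_token_around_image
                    + f"<row_{n_h + 1}_col_{n_w + 1}>"
                    + image_token * image_seq_len
                )
            text += "\n"
        # Followed by the downscaled global image.
        text += (
            "\n"
            + fake_token_around_image
            + global_img_token
            + image_token * image_seq_len
            + fake_token_around_image
        )
        return text

For scale, the `image_seq_len` computed in `image_text_replacement` is `(image_size // patch_size) ** 2 / scale_factor ** 2`; with illustrative Idefics3-style values (image_size=364, patch_size=14, scale_factor=2, not read from this diff) that is (364 // 14) ** 2 // 4 = 169 image tokens per tile.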