Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: improve prefix for idefics3

parent c9573ddf28
commit 53b2bea6b9
@@ -515,7 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"),
+                prefix=f"{prefix}.layers.0" if prefix else "model.layers.0",
                 config=config,
                 weights=weights,
             )
@@ -532,9 +532,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaCrossLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -545,9 +545,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -560,9 +560,9 @@ class FlashLlamaModel(torch.nn.Module):
             FlashLlamaLayer(
                 index=last_layer_id,
                 prefix=(
-                    f"model.layers.{last_layer_id}"
-                    if not prefix
-                    else f"{prefix}.layers.{last_layer_id}"
+                    f"{prefix}.layers.{last_layer_id}"
+                    if prefix
+                    else f"model.layers.{last_layer_id}"
                 ),
                 config=config,
                 weights=weights,
@@ -570,7 +570,7 @@ class FlashLlamaModel(torch.nn.Module):
             )

         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.norm",
+            prefix=f"{prefix}.norm" if prefix else "model.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
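
Note (illustration, not part of the diff): the rewritten ternaries above all follow the same rule: when a parent prefix is supplied, layer and norm weights are resolved under it, otherwise they fall back to the bare "model.*" namespace. A minimal sketch, where "text_model" is a hypothetical parent prefix:

# Minimal sketch of the prefix resolution used above (not TGI code).
def layer_prefix(prefix: str, layer_id: int) -> str:
    return f"{prefix}.layers.{layer_id}" if prefix else f"model.layers.{layer_id}"

assert layer_prefix("", 0) == "model.layers.0"
assert layer_prefix("text_model", 0) == "text_model.layers.0"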
@@ -630,11 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()

+        if config.model_type == "mllama_text_model":
+            prefix = f"{prefix}.model"
+
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
-                prefix=(
-                    "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens"
-                ),
+                prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"),
                 weights=weights,
             )
         self.model = FlashLlamaModel(prefix, config, weights)
@@ -648,6 +649,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier

+        if config.model_type == "mllama_text_model":
+            prefix = prefix.replace(".model", "")
+            suffix = f"{prefix}.{suffix}"
+
+        if config.model_type == "granite":
+            suffix = f"{prefix}.{suffix}"
+
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
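
Note (illustration, not part of the diff): for "mllama_text_model" the embedding and decoder weights are loaded under "<prefix>.model.*", and the ".model" segment is stripped again before resolving the head. A minimal sketch of the two rewrites above; the parent prefix and the default head suffix "lm_head" are assumptions here:

# Minimal sketch of the mllama_text_model prefix handling above (not TGI code).
prefix = "text_model"                  # hypothetical parent prefix
suffix = "lm_head"                     # assumed default head suffix

prefix = f"{prefix}.model"             # embed_tokens / layers / norm resolve under "text_model.model.*"
# ... model weights are loaded with this prefix ...

prefix = prefix.replace(".model", "")  # back to "text_model" for the head
suffix = f"{prefix}.{suffix}"          # -> "text_model.lm_head"
print(suffix)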
@@ -167,10 +167,6 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
-    if config.model_type == "idefics3":
-        return text.replace(
-            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
-        )
     return text

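
Note (illustration, not part of the diff): image_text_replacement_fixup collapses a doubled fake image token in idefics2 prompts; the removed branch applied the same replacement for idefics3, so after this change idefics3 text passes through unchanged. A standalone sketch of the retained idefics2 behavior (the token string is illustrative):

# Standalone sketch of the duplicate fake-token fixup kept for idefics2 (not TGI code).
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"  # illustrative token string

def fixup(text: str) -> str:
    return text.replace(
        f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
    )

assert fixup(f"a{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}b") == f"a{IDEFICS2_FAKE_TOKEN}b"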
@@ -290,8 +286,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None

-        batch_inputs = []
-        max_truncation = 0
+        batch_tokenized_inputs = []
+        max_length = 0
         image_id = 0
         for r in requests:
             full_text = ""
@@ -306,16 +302,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     image_id += 1

             full_text = image_text_replacement_fixup(config, full_text)
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=not config.model_type == "paligemma",
-        )["input_ids"]
+            input_ids = tokenizer(
+                full_text,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"]
+            max_length = max(max_length, len(input_ids))
+            batch_tokenized_inputs.append(input_ids)

         return batch_tokenized_inputs, image_inputs
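
Note (illustration, not part of the diff): the single batched tokenizer call keyed on a global max_truncation is replaced by a per-request call that honors each request's own truncate length and add_special_tokens flag. A minimal, self-contained sketch of the new loop; the tokenizer and request objects below are stand-ins, not the real TGI types:

# Self-contained sketch of per-request tokenization (not TGI code).
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Request:
    full_text: str
    truncate: int
    add_special_tokens: bool


def toy_tokenizer(text, truncation=True, max_length=None, add_special_tokens=True) -> Dict[str, List[int]]:
    # Whitespace "tokenizer" standing in for a Hugging Face tokenizer call.
    ids = list(range(1, len(text.split()) + 1))
    if truncation and max_length is not None:
        ids = ids[:max_length]
    if add_special_tokens:
        ids = [0] + ids  # pretend BOS token
    return {"input_ids": ids}


def tokenize_requests(tokenizer: Callable, requests: List[Request]):
    batch_tokenized_inputs = []
    max_length = 0
    for r in requests:
        input_ids = tokenizer(
            r.full_text,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
        )["input_ids"]
        max_length = max(max_length, len(input_ids))
        batch_tokenized_inputs.append(input_ids)
    return batch_tokenized_inputs, max_length


print(tokenize_requests(toy_tokenizer, [Request("a b c d", 3, True), Request("x y", 8, False)]))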