feat: improve prefix for idefics3

drbh 2024-12-18 03:25:22 +00:00
parent c9573ddf28
commit 53b2bea6b9
2 changed files with 32 additions and 30 deletions


@@ -515,7 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"),
+                prefix=f"{prefix}.layers.0" if prefix else "model.layers.0",
                 config=config,
                 weights=weights,
             )
@@ -532,9 +532,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaCrossLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -545,9 +545,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -560,9 +560,9 @@ class FlashLlamaModel(torch.nn.Module):
             FlashLlamaLayer(
                 index=last_layer_id,
                 prefix=(
-                    f"model.layers.{last_layer_id}"
-                    if not prefix
-                    else f"{prefix}.layers.{last_layer_id}"
+                    f"{prefix}.layers.{last_layer_id}"
+                    if prefix
+                    else f"model.layers.{last_layer_id}"
                 ),
                 config=config,
                 weights=weights,
@@ -570,7 +570,7 @@ class FlashLlamaModel(torch.nn.Module):
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.norm",
+            prefix=f"{prefix}.norm" if prefix else "model.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
@@ -630,11 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
+        if config.model_type == "mllama_text_model":
+            prefix = f"{prefix}.model"
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
-                prefix=(
-                    "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens"
-                ),
+                prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"),
                 weights=weights,
             )
         self.model = FlashLlamaModel(prefix, config, weights)
@@ -648,6 +649,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
+        if config.model_type == "mllama_text_model":
+            prefix = prefix.replace(".model", "")
+            suffix = f"{prefix}.{suffix}"
+        if config.model_type == "granite":
+            suffix = f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
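The recurring change in this file is the prefix expression: every weight lookup now reads f"{prefix}.<name>" if prefix else "model.<name>", i.e. honor the caller's prefix when one is given (as the mllama/idefics3 wrappers do) and fall back to the bare model.* names otherwise. A minimal sketch of that convention; resolve and the example prefixes below are illustrative only, not TGI identifiers:

    def resolve(prefix: str, name: str) -> str:
        # Prefer the wrapping model's prefix; fall back to the standalone
        # "model.*" weight names when no prefix is passed.
        return f"{prefix}.{name}" if prefix else f"model.{name}"

    assert resolve("", "layers.0") == "model.layers.0"
    assert resolve("text_model.model", "norm") == "text_model.model.norm"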


@@ -167,10 +167,6 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
-    if config.model_type == "idefics3":
-        return text.replace(
-            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
-        )
     return text
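With the idefics3 branch removed, only idefics2 still needs the doubled-token cleanup. A sketch of the resulting helper, assuming IDEFICS2_FAKE_TOKEN is the "<fake_token_around_image>" marker defined earlier in this module:

    def image_text_replacement_fixup(config, text: str) -> str:
        # Adjacent image placeholders leave the fake token doubled for
        # idefics2, so collapse it back to a single token; idefics3 prompts
        # no longer produce the duplicate after this change.
        if config.model_type == "idefics2":
            return text.replace(
                f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
            )
        return text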
@@ -290,8 +286,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None
-        batch_inputs = []
-        max_truncation = 0
+        batch_tokenized_inputs = []
+        max_length = 0
         image_id = 0
         for r in requests:
             full_text = ""
@@ -306,16 +302,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     image_id += 1
             full_text = image_text_replacement_fixup(config, full_text)
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=not config.model_type == "paligemma",
-        )["input_ids"]
+            input_ids = tokenizer(
+                full_text,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"]
+            max_length = max(max_length, len(input_ids))
+            batch_tokenized_inputs.append(input_ids)
         return batch_tokenized_inputs, image_inputs
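The batch is no longer tokenized in one call with a shared max_truncation and a model-type-based special-tokens rule; each request is encoded on its own so its truncate limit and add_special_tokens flag are respected. A standalone sketch of the new loop, with Request as a hypothetical stand-in for the gRPC request objects and any tokenizer usable in place of gpt2:

    from dataclasses import dataclass
    from transformers import AutoTokenizer

    @dataclass
    class Request:
        full_text: str
        truncate: int
        add_special_tokens: bool

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    requests = [
        Request("Describe the image.", truncate=64, add_special_tokens=True),
        Request("What is shown here?", truncate=32, add_special_tokens=False),
    ]

    batch_tokenized_inputs = []
    max_length = 0
    for r in requests:
        # Each request keeps its own truncation length and special-token choice.
        input_ids = tokenizer(
            r.full_text,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
        )["input_ids"]
        max_length = max(max_length, len(input_ids))
        batch_tokenized_inputs.append(input_ids)

This gives up the single batched tokenizer call, but it lets requests with different truncate values and add_special_tokens settings coexist in one batch without over-truncating any of them.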