mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

feat: improve prefix for idefics3

parent c9573ddf28
commit 53b2bea6b9
@@ -515,7 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"),
+                prefix=f"{prefix}.layers.0" if prefix else "model.layers.0",
                 config=config,
                 weights=weights,
             )
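
Note: this hunk and the ones that follow only flip the conditional so the non-empty prefix is the primary branch; the resolved weight names are identical before and after. A minimal sketch of the resolution rule, with illustrative prefix values that do not come from the diff:

# Sketch of the prefix rule applied throughout FlashLlamaModel.
# The example prefixes ("" and "text_model") are illustrative only.
def layer_prefix(prefix: str, layer_id: int) -> str:
    return f"{prefix}.layers.{layer_id}" if prefix else f"model.layers.{layer_id}"

assert layer_prefix("", 0) == "model.layers.0"
assert layer_prefix("text_model", 0) == "text_model.layers.0"
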
@@ -532,9 +532,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaCrossLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -545,9 +545,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaLayer(
                         index=layer_id,
                         prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.layers.{layer_id}"
+                            f"{prefix}.layers.{layer_id}"
+                            if prefix
+                            else f"model.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -560,9 +560,9 @@ class FlashLlamaModel(torch.nn.Module):
             FlashLlamaLayer(
                 index=last_layer_id,
                 prefix=(
-                    f"model.layers.{last_layer_id}"
-                    if not prefix
-                    else f"{prefix}.layers.{last_layer_id}"
+                    f"{prefix}.layers.{last_layer_id}"
+                    if prefix
+                    else f"model.layers.{last_layer_id}"
                 ),
                 config=config,
                 weights=weights,
@@ -570,7 +570,7 @@ class FlashLlamaModel(torch.nn.Module):
             )

         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.norm",
+            prefix=f"{prefix}.norm" if prefix else "model.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
@@ -630,11 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()

+        if config.model_type == "mllama_text_model":
+            prefix = f"{prefix}.model"
+
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
-                prefix=(
-                    "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens"
-                ),
+                prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"),
                 weights=weights,
             )
         self.model = FlashLlamaModel(prefix, config, weights)
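
For mllama_text_model checkpoints the diff extends the prefix with an extra ".model" segment before it is handed to FlashLlamaModel, which suggests the text-model weights live under that extra segment. A rough illustration of the resulting weight names, assuming a hypothetical incoming prefix of "text_model":

# Illustrative only: the incoming prefix "text_model" is an assumption,
# not something shown in the diff.
prefix = "text_model"
model_type = "mllama_text_model"

if model_type == "mllama_text_model":
    prefix = f"{prefix}.model"

# Embeddings and layers are then looked up under the extended prefix.
print(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens")  # text_model.model.embed_tokens
print(f"{prefix}.layers.0" if prefix else "model.layers.0")          # text_model.model.layers.0
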
@@ -648,6 +649,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier

+        if config.model_type == "mllama_text_model":
+            prefix = prefix.replace(".model", "")
+            suffix = f"{prefix}.{suffix}"
+
+        if config.model_type == "granite":
+            suffix = f"{prefix}.{suffix}"
+
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
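
The lm_head is loaded outside that ".model" scope, so the segment added in __init__ is stripped again before the head suffix is built (granite gets a prefix-qualified suffix as well). A small sketch of the suffix computation, reusing the hypothetical "text_model" prefix and assuming the usual "lm_head" suffix:

# Illustrative values; only the string handling mirrors the diff.
model_type = "mllama_text_model"  # hypothetical config.model_type
prefix = "text_model.model"       # prefix as extended in __init__ above
suffix = "lm_head"                # assumed head suffix

if model_type == "mllama_text_model":
    prefix = prefix.replace(".model", "")
    suffix = f"{prefix}.{suffix}"

print(suffix)  # text_model.lm_head
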
@@ -167,10 +167,6 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
-    if config.model_type == "idefics3":
-        return text.replace(
-            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
-        )
     return text


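
image_text_replacement_fixup collapses doubled "fake" image tokens that appear when image placeholders end up adjacent; the idefics3 branch was an exact copy of the one kept above and is dropped here. A minimal sketch of what the remaining branch does, with a placeholder value standing in for the real IDEFICS2_FAKE_TOKEN constant:

# Placeholder value; the real IDEFICS2_FAKE_TOKEN constant is defined in the module.
FAKE = "<fake_token_around_image>"

def fixup(text: str) -> str:
    # Collapse back-to-back fake tokens produced by adjacent images.
    return text.replace(f"{FAKE}{FAKE}", FAKE)

assert fixup(f"a{FAKE}{FAKE}b") == f"a{FAKE}b"
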
@@ -290,8 +286,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None

-        batch_inputs = []
-        max_truncation = 0
+        batch_tokenized_inputs = []
+        max_length = 0
         image_id = 0
         for r in requests:
             full_text = ""
@@ -306,16 +302,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 image_id += 1

             full_text = image_text_replacement_fixup(config, full_text)
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=not config.model_type == "paligemma",
-        )["input_ids"]
+            input_ids = tokenizer(
+                full_text,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"]
+            max_length = max(max_length, len(input_ids))
+            batch_tokenized_inputs.append(input_ids)

         return batch_tokenized_inputs, image_inputs

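
Tokenization moves from a single batched tokenizer call (one shared truncation length and a model-type-based special-tokens flag) to one call per request, so each request's own truncate and add_special_tokens settings are honoured and the longest tokenized length is tracked. A simplified sketch of the new loop, with a hypothetical Request dataclass standing in for the real request objects:

from dataclasses import dataclass
from typing import List

@dataclass
class Request:
    # Hypothetical stand-in for the real gRPC request object.
    text: str
    truncate: int
    add_special_tokens: bool

def tokenize_per_request(tokenizer, requests: List[Request]):
    batch_tokenized_inputs = []
    max_length = 0
    for r in requests:
        # Each request is truncated to its own limit and keeps its own
        # add_special_tokens flag instead of sharing batch-wide settings.
        input_ids = tokenizer(
            r.text,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
        )["input_ids"]
        max_length = max(max_length, len(input_ids))
        batch_tokenized_inputs.append(input_ids)
    return batch_tokenized_inputs, max_length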