Mirror of https://github.com/huggingface/text-generation-inference.git
fix: adjust FlashLlamaModel prefix logic
commit daa397c515
parent 78004db1e6
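FlashLlamaModel previously rebuilt weight names from an `if not prefix` fallback plus a `base_model` infix (`""` for prefixes ending in `text_model`, `".model"` otherwise); this commit deletes that branching so the module always derives names as `f"{prefix}.layers.N"` and `f"{prefix}.norm"`, leaving prefix resolution to the caller. A minimal sketch of the naming before and after, reconstructed from the diff below; the concrete prefixes are hypothetical and for illustration only:

    # Sketch of the prefix logic this commit replaces, reconstructed
    # from the deleted lines below. Example prefixes are hypothetical.
    def old_layer_prefix(prefix: str, layer_id: int) -> str:
        # Removed behavior: empty prefix falls back to "model.*", and a
        # ".model" infix is spliced in unless the prefix ends with "text_model".
        if not prefix:
            return f"model.layers.{layer_id}"
        base_model = "" if prefix.endswith("text_model") else ".model"
        return f"{prefix}{base_model}.layers.{layer_id}"

    def new_layer_prefix(prefix: str, layer_id: int) -> str:
        # New behavior: the caller passes a fully resolved prefix.
        return f"{prefix}.layers.{layer_id}"

    assert old_layer_prefix("", 0) == "model.layers.0"
    assert old_layer_prefix("text_model", 0) == "text_model.layers.0"
    assert new_layer_prefix("model", 0) == "model.layers.0"
    assert new_layer_prefix("text_model", 0) == "text_model.layers.0"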
@@ -507,7 +507,6 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        base_model = "" if prefix.endswith("text_model") else ".model"

         # Skip fp8 quant for first and last layers
         self.layers = nn.ModuleList()
@@ -516,11 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=(
-                    "model.layers.0"
-                    if not prefix
-                    else f"{prefix}{base_model}.layers.0"
-                ),
+                prefix=f"{prefix}.layers.0",
                 config=config,
                 weights=weights,
             )
@@ -536,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaCrossLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}{base_model}.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -549,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}{base_model}.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -564,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=last_layer_id,
-                prefix=(
-                    f"model.layers.{last_layer_id}"
-                    if not prefix
-                    else f"{prefix}{base_model}.layers.{last_layer_id}"
-                ),
+                prefix=(f"{prefix}.layers.{last_layer_id}"),
                 config=config,
                 weights=weights,
             )
         )

         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm",
+            prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
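The final `FastRMSNorm` follows the same rule as the layers, so a caller that passes the root namespace reproduces exactly the names the old empty-prefix branch generated. A quick check; the concrete value "model" is an assumption for illustration, not taken from this diff:

    # Assumed prefix value: the old empty-prefix branch produced
    # "model.layers.0" and "model.norm"; passing prefix="model" through
    # the new f-strings yields the same tensor names.
    prefix = "model"
    assert f"{prefix}.layers.0" == "model.layers.0"
    assert f"{prefix}.norm" == "model.norm"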
@@ -1288,7 +1288,7 @@ class FlashCausalLM(Model):
             weights_loader=weights_loader,
         )

-        prefix = ""
+        prefix = None
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)

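With `FlashCausalLM` now handing `None` to `model_class`, the default namespace has to be resolved inside the model class itself before it constructs `FlashLlamaModel`. A hedged sketch of that resolution; this helper is hypothetical and not part of the commit:

    # Hypothetical helper, not from this diff: map a None/empty prefix
    # to the root "model" namespace, leaving explicit prefixes untouched.
    def resolve_prefix(prefix):
        return prefix or "model"

    assert resolve_prefix(None) == "model"
    assert resolve_prefix("text_model") == "text_model"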