fix: adjust FlashLlamaModel prefix logic

Author: drbh
Date: 2025-01-08 13:49:11 +00:00
Parent: 78004db1e6
Commit: daa397c515
2 changed files with 6 additions and 23 deletions


@@ -507,7 +507,6 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        base_model = "" if prefix.endswith("text_model") else ".model"
         # Skip fp8 quant for first and last layers
         self.layers = nn.ModuleList()
@@ -516,11 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
             self.layers.append(
                 FlashLlamaLayer(
                     index=0,
-                    prefix=(
-                        "model.layers.0"
-                        if not prefix
-                        else f"{prefix}{base_model}.layers.0"
-                    ),
+                    prefix=f"{prefix}.layers.0",
                     config=config,
                     weights=weights,
                 )
@@ -536,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaCrossLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}{base_model}.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -549,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}{base_model}.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -564,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module):
             self.layers.append(
                 FlashLlamaLayer(
                     index=last_layer_id,
-                    prefix=(
-                        f"model.layers.{last_layer_id}"
-                        if not prefix
-                        else f"{prefix}{base_model}.layers.{last_layer_id}"
-                    ),
+                    prefix=(f"{prefix}.layers.{last_layer_id}"),
                     config=config,
                     weights=weights,
                 )
             )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm",
+            prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
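
For illustration only, not part of the commit: a minimal standalone sketch of how a layer's weight prefix resolves before and after this change. The example prefix values ("model", "text_model") are assumptions, not taken from the diff; the helpers below simply reproduce the two versions of the expression.

# Sketch only: reproduces the removed and the new prefix expressions as
# standalone helpers; the prefix values used below are assumed examples.
def old_layer_prefix(prefix: str, layer_id: int) -> str:
    # Pre-change logic: fall back to "model.layers.N" when the prefix is empty,
    # and append ".model" unless the prefix ends with "text_model".
    base_model = "" if prefix.endswith("text_model") else ".model"
    return (
        f"model.layers.{layer_id}"
        if not prefix
        else f"{prefix}{base_model}.layers.{layer_id}"
    )


def new_layer_prefix(prefix: str, layer_id: int) -> str:
    # Post-change logic: the caller supplies the full base prefix.
    return f"{prefix}.layers.{layer_id}"


print(old_layer_prefix("", 0))            # model.layers.0
print(new_layer_prefix("model", 0))       # model.layers.0
print(old_layer_prefix("text_model", 0))  # text_model.layers.0
print(new_layer_prefix("text_model", 0))  # text_model.layers.0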


@@ -1288,7 +1288,7 @@ class FlashCausalLM(Model):
             weights_loader=weights_loader,
         )
-        prefix = ""
+        prefix = None
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
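
Also for illustration, an observation rather than anything stated in the diff: both the old default ("") and the new default (None) are falsy, so an if-not-prefix guard of the kind the removed modeling code used treats them identically; the two values only diverge under direct string operations such as .endswith() or f-string interpolation.

# Sketch only: behavior of the old and new defaults under a truthiness guard
# of the kind the removed code used; illustrative, not code from this repo.
for prefix in ("", None):
    print(bool(prefix))                                  # False for both
    print("model" if not prefix else f"{prefix}.model")  # "model" for both
    # The difference only shows up in direct string operations:
    # "".endswith("text_model") returns False, while None.endswith(...) raises
    # AttributeError, and f"{None}.layers.0" renders as "None.layers.0".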