fix: adjust FlashLlamaModel prefix logic

drbh 2025-01-08 13:49:11 +00:00
parent 78004db1e6
commit daa397c515
2 changed files with 6 additions and 23 deletions

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -507,7 +507,6 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        base_model = "" if prefix.endswith("text_model") else ".model"
 
         # Skip fp8 quant for first and last layers
         self.layers = nn.ModuleList()
@@ -516,11 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
             self.layers.append(
                 FlashLlamaLayer(
                     index=0,
-                    prefix=(
-                        "model.layers.0"
-                        if not prefix
-                        else f"{prefix}{base_model}.layers.0"
-                    ),
+                    prefix=f"{prefix}.layers.0",
                     config=config,
                     weights=weights,
                 )
@@ -536,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaCrossLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}{base_model}.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -549,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}{base_model}.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -564,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module):
             self.layers.append(
                 FlashLlamaLayer(
                     index=last_layer_id,
-                    prefix=(
-                        f"model.layers.{last_layer_id}"
-                        if not prefix
-                        else f"{prefix}{base_model}.layers.{last_layer_id}"
-                    ),
+                    prefix=(f"{prefix}.layers.{last_layer_id}"),
                     config=config,
                     weights=weights,
                 )
             )
 
         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm",
+            prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
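
Taken together, these hunks remove FlashLlamaModel's own guesswork about the weight-name base (the deleted base_model line) and make it compose names directly from whatever prefix the caller passes in. A minimal standalone sketch of the before/after name resolution; the helper names below are illustrative, not functions from the repository:

    def old_layer_prefix(prefix: str, layer_id: int) -> str:
        # Pre-commit: FlashLlamaModel inferred the ".model" segment itself.
        base_model = "" if prefix.endswith("text_model") else ".model"
        if not prefix:
            return f"model.layers.{layer_id}"
        return f"{prefix}{base_model}.layers.{layer_id}"

    def new_layer_prefix(prefix: str, layer_id: int) -> str:
        # Post-commit: the caller supplies the full base prefix.
        return f"{prefix}.layers.{layer_id}"

    print(old_layer_prefix("", 0))            # model.layers.0
    print(old_layer_prefix("text_model", 0))  # text_model.layers.0
    print(new_layer_prefix("model", 0))       # model.layers.0
    print(new_layer_prefix("text_model", 0))  # text_model.layers.0

The same simplification applies to self.norm, which now loads from f"{prefix}.norm" instead of special-casing the empty prefix.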

server/text_generation_server/models/flash_causal_lm.py

@@ -1288,7 +1288,7 @@ class FlashCausalLM(Model):
             weights_loader=weights_loader,
         )
 
-        prefix = ""
+        prefix = None
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
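
One caveat worth noting: with prefix = None, FlashLlamaModel's f-strings would render the literal string "None", so this change relies on the model class between FlashCausalLM and FlashLlamaModel normalizing the missing prefix to the default "model" base before the inner model sees it. That wrapper is not part of this diff; a hedged sketch of the assumed normalization:

    from typing import Optional

    # Hypothetical helper illustrating the caller-side normalization this
    # change depends on; the real wrapper code is not shown in this commit.
    def resolve_base_prefix(prefix: Optional[str]) -> str:
        if prefix is None:
            # Standalone checkpoints: "model.layers.N", "model.norm".
            return "model"
        # Composite models (e.g. a "text_model" submodule) pass their base through.
        return prefix

None is also a cleaner sentinel than "" here: the old code keyed on string truthiness (if not prefix), which conflated "no prefix given" with "empty prefix string".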