fix: small syntax tweak

This commit is contained in:
drbh 2024-08-08 02:10:03 +00:00
parent e01e1b7ca6
commit f98aaeeb27

View File

@ -99,7 +99,7 @@ class OPTLearnedPositionalEmbedding(nn.Module):
self.offset = 2 self.offset = 2
self.weight = nn.Parameter( self.weight = nn.Parameter(
weights.get_tensor( weights.get_tensor(
f"{prefix and prefix + '.'}decoder.embed_positions.weight" f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
) )
) )
@ -317,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"{prefix and prefix + '.'}decoder.layers.{layer_id}" prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -439,7 +439,7 @@ class OPTDecoder(OPTPreTrainedModel):
self.max_target_positions = config.max_position_embeddings self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
prefix = prefix and prefix + "." prefix = prefix + "." if prefix else ""
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}decoder.embed_tokens", weights=weights prefix=f"{prefix}decoder.embed_tokens", weights=weights
@ -760,7 +760,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix=f"{prefix and prefix + '.'}decoder.embed_tokens", prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
weights=weights, weights=weights,
) )