fix: run lints

This commit is contained in:
drbh 2024-08-08 01:35:42 +00:00
parent 22d9249c4a
commit e01e1b7ca6
2 changed files with 8 additions and 4 deletions

View File

@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor(f"{prefix and prefix + '.'}decoder.embed_positions.weight")
weights.get_tensor(
f"{prefix and prefix + '.'}decoder.embed_positions.weight"
)
)
def forward(
@ -437,7 +439,7 @@ class OPTDecoder(OPTPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
prefix = prefix and prefix + '.'
prefix = prefix and prefix + "."
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}decoder.embed_tokens", weights=weights
@ -757,7 +759,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix=f"{prefix and prefix + '.'}decoder.embed_tokens", weights=weights
config,
prefix=f"{prefix and prefix + '.'}decoder.embed_tokens",
weights=weights,
)
def forward(