formatting

This commit is contained in:
OlivierDehaene 2023-02-01 15:30:37 +01:00
parent c25fd1e2e8
commit 651403c325
2 changed files with 6 additions and 4 deletions

View File

@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,

View File

@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(decoder_input_ids.view(1, -1), logits)
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])