diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py index d901cae3..3fc5658f 100644 --- a/server/text_generation/models/gpt_neox.py +++ b/server/text_generation/models/gpt_neox.py @@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox): ) if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" ): tensor = Int8Params( tensor, diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 245dca12..1ae266d8 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -418,7 +418,9 @@ class Seq2SeqLM(Model): decoder_input_ids, ) in enumerate(iterator): # Select next token - next_token_id, logprobs = next_token_chooser(decoder_input_ids.view(1, -1), logits) + next_token_id, logprobs = next_token_chooser( + decoder_input_ids.view(1, -1), logits + ) # Append next token to decoder tokens decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])