From 651403c3250bf440d96f32987fd2fd59e5c38223 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 1 Feb 2023 15:30:37 +0100 Subject: [PATCH] formatting --- server/text_generation/models/gpt_neox.py | 6 +++--- server/text_generation/models/seq2seq_lm.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py index d901cae3..3fc5658f 100644 --- a/server/text_generation/models/gpt_neox.py +++ b/server/text_generation/models/gpt_neox.py @@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox): ) if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" ): tensor = Int8Params( tensor, diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 245dca12..1ae266d8 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -418,7 +418,9 @@ class Seq2SeqLM(Model): decoder_input_ids, ) in enumerate(iterator): # Select next token - next_token_id, logprobs = next_token_chooser(decoder_input_ids.view(1, -1), logits) + next_token_id, logprobs = next_token_chooser( + decoder_input_ids.view(1, -1), logits + ) # Append next token to decoder tokens decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])