From 9285f67be543bc35c3caebbbed535b8e51f0645d Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 30 Jan 2023 15:15:34 +0100
Subject: [PATCH] black

---
 server/text_generation/models/bloom.py      |  4 +++-
 server/text_generation/models/causal_lm.py  |  7 ++++--
 server/text_generation/models/galactica.py  |  4 +++-
 server/text_generation/models/santacoder.py | 16 ++++++++------
 server/text_generation/models/seq2seq_lm.py |  2 +-
 server/text_generation/models/types.py      |  2 +-
 server/text_generation/utils.py             | 24 ++++++++++++---------
 7 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 463c0406..35f46bc2 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
             if name == "word_embeddings.weight":
                 model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index ece01271..ccd4c3ba 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )

         # List of indices to cache
@@ -389,7 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
             # add to the next batch
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index b56fc748..d047ccb6 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
             if name == "model.decoder.embed_tokens.weight":
                 model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py
index cf9f450c..4b898ab9 100644
--- a/server/text_generation/models/santacoder.py
+++ b/server/text_generation/models/santacoder.py
@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )

-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index d7f5d0bb..f965ea88 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -468,7 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
             # add to the next batch
diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py
index c97c550b..4ee3cb32 100644
--- a/server/text_generation/models/types.py
+++ b/server/text_generation/models/types.py
@@ -50,5 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
-            seed=self.seed
+            seed=self.seed,
         )
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 3001185d..1d087a42 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -33,7 +33,9 @@ class Sampling:

     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens

     @property
@@ -47,7 +49,9 @@ class Greedy:


 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
@@ -84,7 +88,7 @@ class NextTokenChooser:
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
-            seed=seed
+            seed=seed,
         )


@@ -100,10 +104,10 @@ class StopSequenceCriteria:

 class StoppingCriteria:
     def __init__(
-            self,
-            eos_token_id: int,
-            stop_sequence_criterias: List[StopSequenceCriteria],
-            max_new_tokens=20,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -128,9 +132,9 @@ class StoppingCriteria:

     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.StoppingCriteriaParameters,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
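
Every hunk in this patch is a mechanical black reformat: long `def` signatures and call sites are split to one argument per line inside parentheses, a trailing comma is added after the last argument so the exploded layout is kept, the over-indented parameter lists in `StoppingCriteria` are re-indented to black's standard hanging indent, and the long `.to(device).eval()` chain is wrapped in parentheses with each call on its own line. As a rough, non-authoritative sketch, the snippet below replays the `Sampling.__call__` change from `server/text_generation/utils.py` in black's output style; the `sample_next_tokens` helper, the explicit `generator` argument, and the `__main__` demo are invented for illustration and are not part of this patch.

import torch


def sample_next_tokens(
    logits: torch.Tensor, generator: torch.Generator
) -> torch.Tensor:
    # Toy helper, not part of the patch. Mirrors the post-black Sampling.__call__
    # hunk above: the multinomial call no longer fits on one line, so its
    # arguments move to their own line and .squeeze(1) stays attached to the
    # closing parenthesis.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(
        probs, num_samples=1, generator=generator
    ).squeeze(1)
    return next_tokens


if __name__ == "__main__":
    # Illustrative values only.
    generator = torch.Generator().manual_seed(0)
    logits = torch.randn(2, 10)
    print(sample_next_tokens(logits, generator))

The same formatting can be reproduced locally with `black server/text_generation`, or verified without rewriting files with `black --check --diff server/text_generation`, assuming black's default 88-character line length.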