diff --git a/server/Makefile b/server/Makefile
index 44cf7c5c..e8b0364e 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -1,4 +1,4 @@
-transformers_commit := 517563354a3226ecfc3dca6e7a38012668d7156a
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 
 gen-server:
 	# Compile protos
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 0035c1c6..3e2f5c66 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -9,7 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -62,7 +61,7 @@ def get_model(
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ef3f0260..c979b7bc 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -72,9 +72,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index d04a3bce..a90a299e 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -102,9 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 5e1960f4..8fabefe3 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -30,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False
 
 
-class GPTNeox(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
+class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
@@ -224,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index bece913a..0f7f4df9 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -82,9 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py
index cf6214ce..8e90a59c 100644
--- a/server/text_generation_server/utils/watermark.py
+++ b/server/text_generation_server/utils/watermark.py
@@ -43,7 +43,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
-    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
+    def _get_greenlist_ids(
+        self, input_ids: torch.LongTensor, max_value: int
+    ) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
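Context for the change: the dedicated `GPTNeox` wrapper existed only to drop `position_ids` before calling the model. With the pinned `transformers` commit bumped, the diff assumes GPT-NeoX accepts `position_ids` directly, so the generic `CausalLM` path can be used for the non-sharded case and `GPTNeoxSharded.forward` can pass `position_ids` through. The sketch below illustrates that assumption only; the checkpoint `EleutherAI/pythia-70m` and the way `position_ids` are derived from the attention mask are illustrative choices, not part of this diff.

```python
# Minimal sketch (not from the diff): a GPT-NeoX checkpoint called through the
# plain transformers API with explicit position_ids, as the generic CausalLM
# forward path would do after this change.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")

inputs = tokenizer("def hello():", return_tensors="pt")

# Derive position ids from the attention mask, padding positions masked to 1.
position_ids = inputs["attention_mask"].cumsum(-1) - 1
position_ids.masked_fill_(inputs["attention_mask"] == 0, 1)

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        position_ids=position_ids,
        use_cache=True,
    )

print(outputs.logits.shape)          # (batch, sequence_length, vocab_size)
print(len(outputs.past_key_values))  # one (key, value) entry per layer
```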