diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d4edb8f7..8caf7967 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -14,6 +14,7 @@ from text_generation_server.models.t5 import T5Sharded
 
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+
     FLASH_NEOX = torch.cuda.is_available()
 except ImportError:
     FLASH_NEOX = False
@@ -48,7 +49,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 2d3c6d8e..7be4708b 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -84,10 +84,9 @@ class FlashNeoXBatch(Batch):
 
         # Parse batch
        for r in pb.requests:
-            tokenized_input = (
-                tokenizer(r.inputs, return_tensors="pt")["input_ids"]
-                .squeeze(0)
-            )
+            tokenized_input = tokenizer(r.inputs, return_tensors="pt")[
+                "input_ids"
+            ].squeeze(0)
             input_ids.append(tokenized_input)
             all_input_ids.append(tokenized_input.tolist())
 
@@ -96,9 +95,7 @@ class FlashNeoXBatch(Batch):
             input_lengths.append(input_length)
 
             # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32)
-            )
+            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
 
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -188,7 +185,6 @@ class FlashNeoXBatch(Batch):
             stopping_criterias=stopping_criterias,
         )
 
-
     def __len__(self):
         return len(self.requests)
 
@@ -319,9 +315,7 @@ class FlashNeoX(Model):
                logits = out[i].unsqueeze(0)
 
            # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids, logits
-            )
+            next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
            # Copy to cpu to avoid other copies when indexing and calling .item()
            next_token_id = next_token_id.to("cpu", non_blocking=True)
            logprobs = logprobs.to("cpu")
@@ -435,9 +429,7 @@ class FlashNeoX(Model):
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32
         )
-        next_batch_cu_seqlens = torch.tensor(
-            next_batch_cu_seqlens, dtype=torch.int32
-        )
+        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
         if len(next_batch_keep_indices) > 1:
             next_batch_input_ids = torch.concat(next_batch_input_ids)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py
index 86c777a5..1850561d 100644
--- a/server/text_generation_server/utils/watermark.py
+++ b/server/text_generation_server/utils/watermark.py
@@ -39,7 +39,8 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            assert (len(input_ids) >= 1
+            assert (
+                len(input_ids) >= 1
             ), "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1]
         else:
@@ -52,15 +53,16 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng.manual_seed(self.hash_key * prev_token)
 
     def _get_greenlist_ids(
-        self, input_ids: Union[List[int], torch.LongTensor], max_value: int, device: torch.device
+        self,
+        input_ids: Union[List[int], torch.LongTensor],
+        max_value: int,
+        device: torch.device,
     ) -> List[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
 
         greenlist_size = int(max_value * self.gamma)
-        vocab_permutation = torch.randperm(
-            max_value, device=device, generator=self.rng
-        )
+        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
 
@@ -83,7 +85,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def __call__(
         self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-        greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device)
+        greenlist_ids = self._get_greenlist_ids(
+            input_ids, scores.shape[-1], scores.device
+        )
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
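
Note on the watermark hunks: the reformatted `_get_greenlist_ids` path reduces to three steps visible in the diff: seed a `torch.Generator` from the last prefix token (`hash_key * prev_token`), permute the vocabulary with that generator, and keep the first `gamma * vocab_size` ids as the greenlist. Below is a minimal standalone sketch of that logic; the `greenlist_for_prefix` name and the default `gamma`/`hash_key` values are illustrative assumptions, not taken from this diff.

from typing import List

import torch


def greenlist_for_prefix(
    prefix_ids: List[int],
    vocab_size: int,
    gamma: float = 0.5,        # fraction of the vocab marked "green" (illustrative value)
    hash_key: int = 15485863,  # illustrative seed multiplier, not from this diff
) -> torch.Tensor:
    # Seed the generator from the last prefix token, mirroring _seed_rng.
    assert len(prefix_ids) >= 1, "requires at least a 1 token prefix sequence to seed rng"
    rng = torch.Generator(device="cpu")
    rng.manual_seed(hash_key * prefix_ids[-1])

    # Permute the vocabulary and keep the first gamma * vocab_size ids, mirroring _get_greenlist_ids.
    greenlist_size = int(vocab_size * gamma)
    vocab_permutation = torch.randperm(vocab_size, generator=rng)
    return vocab_permutation[:greenlist_size]


# Example: the greenlist is a deterministic function of the last prefix token.
print(greenlist_for_prefix([42, 7], vocab_size=50432)[:5])

Because the permutation is seeded only by the previous token, the same prefix always yields the same greenlist, which is what allows a detector to recompute it later and count how many generated tokens fall inside it.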