diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 6347b1a5..61934e84 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,3 +1,5 @@
+import inspect
+
 import torch
 
 from dataclasses import dataclass
@@ -26,7 +28,7 @@ class CausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     attention_mask: torch.Tensor
-    position_ids: torch.Tensor
+    position_ids: Optional[torch.Tensor]
     past_key_values: Optional[List[Tuple]]
 
     # All tokens
@@ -62,6 +64,7 @@ class CausalLMBatch(Batch):
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
+        use_position_ids: bool = False,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -105,11 +108,14 @@ class CausalLMBatch(Batch):
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
-
-        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
-        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
+        if use_position_ids:
+            position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
+            position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
+        else:
+            position_ids = None
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -204,11 +210,12 @@ class CausalLMBatch(Batch):
                 batch_left_offset : -batch.padding_right_offset,
             ]
 
-            # Create empty tensor
-            # position_ids is always of shape [batch_size, 1]
-            if position_ids is None:
-                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
-            position_ids[start_index:end_index] = batch.position_ids
+            if batch.position_ids is not None:
+                # Create empty tensor
+                # position_ids is always of shape [batch_size, 1]
+                if position_ids is None:
+                    position_ids = batch.position_ids.new_empty((total_batch_size, 1))
+                position_ids[start_index:end_index] = batch.position_ids
 
             for j, past in enumerate(batch.past_key_values):
                 past_keys, past_values = past
@@ -328,6 +335,10 @@ class CausalLM(Model):
             tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
+        self.use_position_ids = "position_ids" in set(
+            inspect.signature(self.model.forward).parameters.keys()
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
@@ -494,7 +505,8 @@ class CausalLM(Model):
         if len(next_batch_keep_indices) != len(batch):
             # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
+            next_batch_position_ids = batch.position_ids[next_batch_keep_indices] \
+                if batch.position_ids is not None else None
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
@@ -522,7 +534,8 @@ class CausalLM(Model):
         next_batch_attention_mask[:, -batch.padding_right_offset] = 1
 
         # Update position_ids
-        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
+        if next_batch_position_ids is not None:
+            next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
 
         next_batch = CausalLMBatch(
             batch_id=batch.batch_id,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 08a48553..5b747d2c 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -23,6 +23,7 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
         self.decode_buffer = decode_buffer
+        self.use_position_ids = False
 
     @property
     @abstractmethod
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 3e3789bf..9703e127 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -39,8 +39,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
+        kwargs = {"use_position_ids": True} if self.model.use_position_ids else {}
         batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
+            request.batch, self.model.tokenizer, self.model.device, **kwargs,
         )
 
         generations, next_batch = self.model.generate_token(batch)
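
Note (not part of the patch): the sketch below is a minimal, standalone illustration of the two ideas this change combines, namely using inspect.signature to detect whether a model's forward() accepts position_ids, and deriving position ids from a left-padded attention mask with a cumulative sum. The ModelWithPositions / ModelWithoutPositions classes are hypothetical stand-ins, not models from this repository.

import inspect

import torch


class ModelWithPositions:
    # Hypothetical stand-in for a model whose forward() takes position_ids.
    def forward(self, input_ids, attention_mask, position_ids=None):
        pass


class ModelWithoutPositions:
    # Hypothetical stand-in for a model whose forward() does not.
    def forward(self, input_ids, attention_mask):
        pass


def accepts_position_ids(model) -> bool:
    # Same kind of check the patch adds in CausalLM.__init__: feature-detect
    # the forward signature instead of hard-coding per-architecture flags.
    return "position_ids" in inspect.signature(model.forward).parameters


def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # With left padding, real tokens start at different offsets per row, so
    # the running sum of the mask (minus one) yields each token's position;
    # padded slots get a harmless constant (1), as in from_pb above.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids


if __name__ == "__main__":
    print(accepts_position_ids(ModelWithPositions()))     # True
    print(accepts_position_ids(ModelWithoutPositions()))  # False

    # Two sequences of lengths 2 and 4, left-padded to length 4.
    mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
    print(position_ids_from_mask(mask))
    # tensor([[1, 1, 0, 1],
    #         [0, 1, 2, 3]])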