diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f2a472dc..2fdebf58 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -114,7 +114,9 @@ def get_model( santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder return santacoder_cls(model_id, revision, quantize=quantize) - config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=True) + config = AutoConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=True + ) model_type = config.model_type if model_type == "bloom": diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py index 71b19b96..ec0ed18e 100644 --- a/server/text_generation_server/models/vectorized_causal_lm.py +++ b/server/text_generation_server/models/vectorized_causal_lm.py @@ -18,7 +18,9 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria -from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser +from text_generation_server.utils.tokens_heterogeneous import ( + HeterogeneousNextTokenChooser, +) tracer = trace.get_tracer(__name__) @@ -32,7 +34,9 @@ class VectorizedCausalLMBatch(Batch): # Decoder values attention_mask: torch.Tensor position_ids: torch.Tensor - past_key_values: Optional[List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]] + past_key_values: Optional[ + List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]] + ] # All tokens input_ids: torch.Tensor @@ -52,11 +56,11 @@ class VectorizedCausalLMBatch(Batch): # Maximum number of tokens this batch will grow to max_tokens: int - kv_cache_seq_dim:int=2 + kv_cache_seq_dim: int = 2 # TODO: Get from requests (should these be lists?) 
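For context on the `kv_cache_seq_dim` field introduced above: decoder-only models typically keep one KV tensor per layer shaped `[batch, num_heads, seq_len, head_dim]`, so the sequence axis is dimension 2. Below is a minimal sketch of slicing such a cache along that axis, in the spirit of the batch filtering further down in this diff; the helper name and shapes are illustrative assumptions, not code from this PR.

```python
import torch

# Illustrative KV cache entry: [batch, num_heads, seq_len, head_dim] (assumed layout)
layer_kv = torch.zeros(4, 16, 10, 64)
kv_cache_seq_dim = 2


def trim_kv_cache(tensor, keep_indices, seq_slice, seq_dim):
    """Keep a subset of sequences and a window of positions by building an
    index for every axis up to the sequence dimension (hypothetical helper)."""
    index = [keep_indices, *(slice(None) for _ in range(1, seq_dim)), seq_slice]
    return tensor[tuple(index)]


trimmed = trim_kv_cache(layer_kv, [0, 2], slice(3, 10), kv_cache_seq_dim)
print(trimmed.shape)  # torch.Size([2, 16, 7, 64])
```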
- details:bool=os.environ.get("RETURN_DETAILS") is not None - generate_stream:bool=os.environ.get("GENERATE_STREAM") is not None + details: bool = os.environ.get("RETURN_DETAILS") is not None + generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None def to_pb(self) -> generate_pb2.Batch: return generate_pb2.Batch( @@ -74,15 +78,22 @@ class VectorizedCausalLMBatch(Batch): device: torch.device, ) -> "VectorizedCausalLMBatch": inputs = [r.inputs for r in pb.requests] - offsets = [None]*len(inputs) - token_offsets = [None]*len(inputs) - requests_idx_mapping = {r.id:i for i, r in enumerate(pb.requests)} + offsets = [None] * len(inputs) + token_offsets = [None] * len(inputs) + requests_idx_mapping = {r.id: i for i, r in enumerate(pb.requests)} # Parse batch - stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests] - max_new_tokens=(stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias) + stopping_criterias = [ + StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) + for r in pb.requests + ] + max_new_tokens = ( + stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias + ) - next_token_chooser= HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], device) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + [r.parameters for r in pb.requests], device + ) tokenized_inputs = tokenizer( inputs, @@ -96,7 +107,7 @@ class VectorizedCausalLMBatch(Batch): input_lengths = tokenized_inputs["attention_mask"].sum(1) max_input_length = input_lengths.max().item() - input_shape=(pb.size, max_input_length + max(max_new_tokens)) + input_shape = (pb.size, max_input_length + max(max_new_tokens)) # Allocate maximum attention_mask attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) @@ -112,7 +123,10 @@ class VectorizedCausalLMBatch(Batch): max_tokens = len(inputs) * max_input_length + sum(max_new_tokens) - generate_stream=cls.generate_stream or any(stopping_criteria.stop_sequence_criterias for stopping_criteria in stopping_criterias) + generate_stream = cls.generate_stream or any( + stopping_criteria.stop_sequence_criterias + for stopping_criteria in stopping_criterias + ) return cls( batch_id=pb.id, @@ -133,7 +147,9 @@ class VectorizedCausalLMBatch(Batch): ) @tracer.start_as_current_span("filter") - def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]: + def filter( + self, requests: List[generate_pb2.Request] + ) -> Optional["VectorizedCausalLMBatch"]: if len(requests) == 0: raise ValueError("Batch must have at least one request") if len(requests) == len(self): @@ -143,70 +159,108 @@ class VectorizedCausalLMBatch(Batch): keep_indices = [self.requests_idx_mapping[r.id] for r in self.requests] # New values after filtering - self.requests_idx_mapping={r.id:i for i, r in enumerate(self.requests)} - self.input_lengths=[self.input_lengths[i] for i in keep_indices] + self.requests_idx_mapping = {r.id: i for i, r in enumerate(self.requests)} + self.input_lengths = [self.input_lengths[i] for i in keep_indices] self.offsets = [self.offsets[i] for i in keep_indices] self.token_offsets = [self.token_offsets[i] for i in keep_indices] - self.next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in self.requests], self.input_ids.device) + self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + [r.parameters for r in self.requests], self.input_ids.device + ) self.stopping_criterias = 
[self.stopping_criterias[i] for i in keep_indices]
-        remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
+        remaining_decode_tokens = [
+            stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            for stopping_criteria in self.stopping_criterias
+        ]

         # Select the remaining indices and remove unnecessary padding
-        max_input_length=max(self.input_lengths)
-        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+max(remaining_decode_tokens))
-        self.max_input_length=max_input_length
-        self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
+        max_input_length = max(self.input_lengths)
+        sequence_slice = slice(
+            self.max_input_length - max_input_length,
+            self.max_input_length + max(remaining_decode_tokens),
+        )
+        self.max_input_length = max_input_length
+        self.max_tokens = len(self.requests) * self.max_input_length + sum(
+            remaining_decode_tokens
+        )

-        self.input_ids = self.input_ids[keep_indices,sequence_slice]
-        self.position_ids = self.position_ids[keep_indices,sequence_slice]
-        self.attention_mask = self.attention_mask[keep_indices,sequence_slice]
+        self.input_ids = self.input_ids[keep_indices, sequence_slice]
+        self.position_ids = self.position_ids[keep_indices, sequence_slice]
+        self.attention_mask = self.attention_mask[keep_indices, sequence_slice]

         tensors_to_update = []
         if self.past_key_values is not None:
-            if not isinstance(self.past_key_values,(list, tuple)):
-                raise NotImplementedError(f"Unsupported kv cache type: {type(self.past_key_values)}")
+            if not isinstance(self.past_key_values, (list, tuple)):
+                raise NotImplementedError(
+                    f"Unsupported kv cache type: {type(self.past_key_values)}"
+                )
             for layer_kv in self.past_key_values:
                 if isinstance(layer_kv, torch.Tensor):
                     tensors_to_update.append(layer_kv)
-                elif isinstance(layer_kv,(list, tuple)):
+                elif isinstance(layer_kv, (list, tuple)):
                     tensors_to_update.extend(layer_kv)
                 else:
-                    raise NotImplementedError(f"Unsupported layer kv cache type: {type(layer_kv)}")
+                    raise NotImplementedError(
+                        f"Unsupported layer kv cache type: {type(layer_kv)}"
+                    )

-            kv_cache_slice=[keep_indices, *(slice(None) for _ in range(1, self.kv_cache_seq_dim)), sequence_slice]
+            kv_cache_slice = [
+                keep_indices,
+                *(slice(None) for _ in range(1, self.kv_cache_seq_dim)),
+                sequence_slice,
+            ]
             for tensor in tensors_to_update:
                 # Update tensors in-place to allow incremental garbage collection
-                tensors_to_update.data=tensor[kv_cache_slice]
+                tensor.data = tensor[kv_cache_slice]

         return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
-        if len(batches)==0:
+    def concatenate(
+        cls, batches: List["VectorizedCausalLMBatch"]
+    ) -> "VectorizedCausalLMBatch":
+        if len(batches) == 0:
             raise ValueError("Cannot concatenate empty list.")
-        requests=[request for batch in batches for request in batch.requests]
-        batch_sizes=[len(batch.requests) for batch in batches]
-        batch_size=sum(batch_sizes)
+        requests = [request for batch in batches for request in batch.requests]
+        batch_sizes = [len(batch.requests) for batch in batches]
+        batch_size = sum(batch_sizes)

-        end_indices=torch.tensor(batch_sizes).cumsum(0).tolist()
-        start_indices=[0]+end_indices[:-1]
+        end_indices = torch.tensor(batch_sizes).cumsum(0).tolist()
+        start_indices = [0] + end_indices[:-1]

         input_lengths = [length for batch in batches
for length in batch.input_lengths] offsets = [offset for batch in batches for offset in batch.offsets] - token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets] - next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in requests], batches[0].input_ids.device) - stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias] + token_offsets = [ + token_offset for batch in batches for token_offset in batch.token_offsets + ] + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + [r.parameters for r in requests], batches[0].input_ids.device + ) + stopping_criterias = [ + stopping_criteria + for batch in batches + for stopping_criteria in batch.stopping_criterias + ] - requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()} + requests_idx_mapping = { + k: v + start_index + for batch, start_index in zip(batches, start_indices) + for k, v in batch.requests_idx_mapping.items() + } - max_input_length=max(input_lengths) - left_indices=[max_input_length-batch.max_input_length for batch in batches] + max_input_length = max(input_lengths) + left_indices = [max_input_length - batch.max_input_length for batch in batches] - input_shape=(batch_size, max_input_length + max(batch.input_ids.size(1)-batch.max_input_length for batch in batches)) - device=batches[0].input_ids.device + input_shape = ( + batch_size, + max_input_length + + max( + batch.input_ids.size(1) - batch.max_input_length for batch in batches + ), + ) + device = batches[0].input_ids.device # Allocate maximum attention_mask attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) @@ -217,56 +271,84 @@ class VectorizedCausalLMBatch(Batch): # TODO : only needed for prefill input_ids[:, :max_input_length].fill_(0) - for batch,start_index, end_index, left_index in zip(batches, start_indices, end_indices, left_indices): - attention_mask[start_index:end_index, left_index:max_input_length].copy_(batch.attention_mask[:, :batch.max_input_length]) - input_ids[start_index:end_index, left_index:max_input_length].copy_(batch.input_ids[:, :batch.max_input_length]) + for batch, start_index, end_index, left_index in zip( + batches, start_indices, end_indices, left_indices + ): + attention_mask[start_index:end_index, left_index:max_input_length].copy_( + batch.attention_mask[:, : batch.max_input_length] + ) + input_ids[start_index:end_index, left_index:max_input_length].copy_( + batch.input_ids[:, : batch.max_input_length] + ) position_ids = attention_mask.cumsum(-1).sub_(1) position_ids[:, :max_input_length].relu_() - max_tokens = sum(batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches) + max_tokens = sum( + batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) + for batch in batches + ) - kv_formats=None + kv_formats = None for batch in batches: if batch.past_key_values is None: raise ValueError("Only concatenate prefilled batches") if not isinstance(batch.past_key_values, (list, tuple)): - raise NotImplementedError(f"Unsupported kv cache type: {type(batch.past_key_values)}") + raise NotImplementedError( + f"Unsupported kv cache type: {type(batch.past_key_values)}" + ) if kv_formats is None: - num_layers=len(batch.past_key_values) - if num_layers==0: + num_layers = len(batch.past_key_values) + if num_layers == 0: raise ValueError("Empty KV cache") - kv_formats = 
[0]*num_layers - elif len(batch.past_key_values)!=len(kv_formats): + kv_formats = [0] * num_layers + elif len(batch.past_key_values) != len(kv_formats): raise ValueError("Num layers is not constant") for i, layer_kv in enumerate(batch.past_key_values): if isinstance(layer_kv, (list, tuple)): kv_format = len(layer_kv) else: - kv_format=None - if kv_formats[i]==0: - if kv_format==0: + kv_format = None + if kv_formats[i] == 0: + if kv_format == 0: raise ValueError("Empty KV cache") - kv_formats[i]=kv_format - elif kv_formats[i]!=kv_format: + kv_formats[i] = kv_format + elif kv_formats[i] != kv_format: raise ValueError("Incompatible KV cache format.") - kv_cache_seq_dim=batches[0].kv_cache_seq_dim - past_key_values=[] + kv_cache_seq_dim = batches[0].kv_cache_seq_dim + past_key_values = [] for i, kv_format in enumerate(kv_formats): for j in range(1 if kv_format is None else kv_format): - tensors_to_merge=[batch.past_key_values[i] for batch in batches] + tensors_to_merge = [batch.past_key_values[i] for batch in batches] # Generally `max_input_length`, unless the model allocates more than needed. - right_indices=[left_index+tensor.size(kv_cache_seq_dim) for tensor, left_index in zip(tensors_to_merge, left_indices)] - combined_shape=[batch_size]+list(tensors_to_merge[0].shape[1:]) - combined_shape[kv_cache_seq_dim]=max(right_indices) + right_indices = [ + left_index + tensor.size(kv_cache_seq_dim) + for tensor, left_index in zip(tensors_to_merge, left_indices) + ] + combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:]) + combined_shape[kv_cache_seq_dim] = max(right_indices) # Set to zero to avoid propagating nans in padded values. - kv_cache = torch.zeros(combined_shape, dtype=tensors_to_merge[0].dtype, device=device) - for tensor, start_index, end_index, left_index, right_index in zip(tensors_to_merge, start_indices, end_indices, left_indices, right_indices): - kv_cache[[slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), slice(left_index,right_index)]].copy_(tensor) + kv_cache = torch.zeros( + combined_shape, dtype=tensors_to_merge[0].dtype, device=device + ) + for tensor, start_index, end_index, left_index, right_index in zip( + tensors_to_merge, + start_indices, + end_indices, + left_indices, + right_indices, + ): + kv_cache[ + [ + slice(start_index, end_index), + *(slice(None) for _ in range(1, kv_cache_seq_dim)), + slice(left_index, right_index), + ] + ].copy_(tensor) if kv_format is None: past_key_values.append(kv_cache) - elif j==0: + elif j == 0: past_key_values.append([kv_cache]) else: past_key_values[-1].append(kv_cache) @@ -350,58 +432,75 @@ class VectorizedCausalLM(Model): def generate_token( self, batch: VectorizedCausalLMBatch ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]: - key_length=batch.max_input_length - if key_length>batch.input_ids.size(1): + key_length = batch.max_input_length + if key_length > batch.input_ids.size(1): raise RuntimeError("Cannot generate more than `max_tokens`.") - query_length=key_length if batch.past_key_values is None else 1 - input_ids=batch.input_ids[:, key_length-query_length: key_length] + query_length = key_length if batch.past_key_values is None else 1 + input_ids = batch.input_ids[:, key_length - query_length : key_length] outputs = self.model.forward( input_ids=input_ids, - attention_mask=batch.attention_mask[:, : key_length], - position_ids=batch.position_ids[:, key_length-query_length: key_length], + attention_mask=batch.attention_mask[:, :key_length], + 
position_ids=batch.position_ids[:, key_length - query_length : key_length], past_key_values=batch.past_key_values, ) # TODO: Post-processing - next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details) + next_token_ids, logprobs = batch.next_token_chooser( + input_ids, outputs.logits, batch.details + ) if batch.generate_stream: # TODO: self.decode_token, offsets? - next_token_texts=self.tokenizer.batch_decode(next_token_ids.tolist()) + next_token_texts = self.tokenizer.batch_decode(next_token_ids.tolist()) if batch.details: - token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist() - if query_length>1: - prefill_token_ids=batch.input_ids[:, :key_length].tolist() - prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist() - prefill_tokens=[] - for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths): - prefill_token_ids_=prefill_token_ids_[-input_length:] + token_logprobs = ( + logprobs[:, -1, :] + .gather(1, next_token_ids.unsqueeze(1)) + .squeeze(1) + .tolist() + ) + if query_length > 1: + prefill_token_ids = batch.input_ids[:, :key_length].tolist() + prefill_logprobs = ( + logprobs.gather(2, batch.input_ids[:, 1:key_length, None]) + .squeeze(2) + .tolist() + ) + prefill_tokens = [] + for prefill_token_ids_, prefill_logprobs_, input_length in zip( + prefill_token_ids, prefill_logprobs, batch.input_lengths + ): + prefill_token_ids_ = prefill_token_ids_[-input_length:] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids_, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens.append(PrefillTokens( - prefill_token_ids_, [math.nan, *prefill_logprobs_], prefill_texts - )) + prefill_tokens.append( + PrefillTokens( + prefill_token_ids_, + [math.nan, *prefill_logprobs_], + prefill_texts, + ) + ) # Update batch # TODO: Why do we need all input ids? batch.input_ids[:, key_length].copy_(next_token_ids) - batch.past_key_values=outputs.past_key_values - batch.input_lengths=[length+1 for length in batch.input_lengths] - batch.max_input_length+=1 + batch.past_key_values = outputs.past_key_values + batch.input_lengths = [length + 1 for length in batch.input_lengths] + batch.max_input_length += 1 # TODO: Vectorize some of this? 
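The prefill logprobs gathered above follow the usual shift between logits and targets: the logits at position `t` score the token at position `t + 1`, and the first prompt token gets no logprob (hence the leading `math.nan`). A small self-contained sketch of that gather pattern, with made-up shapes rather than anything taken from this PR:

```python
import math

import torch

batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(vocab_size, (batch_size, seq_len))

# Logits at position t predict token t + 1, so gather targets shifted by one.
logprobs = torch.log_softmax(logits, dim=-1)
prefill_logprobs = logprobs[:, :-1].gather(2, input_ids[:, 1:, None]).squeeze(2)

# One logprob per prompt token, with no score for the very first token.
per_token = [[math.nan, *row.tolist()] for row in prefill_logprobs]
assert all(len(row) == seq_len for row in per_token)
```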
generations: List[Generation] = [] - next_batch=None + next_batch = None for i, next_token_id in enumerate(next_token_ids): - next_token_text=next_token_texts[i] if batch.generate_stream else "" - stopping_criterias=batch.stopping_criterias[i] + next_token_text = next_token_texts[i] if batch.generate_stream else "" + stopping_criterias = batch.stopping_criterias[i] stop, reason = stopping_criterias( next_token_id, next_token_text, @@ -421,10 +520,9 @@ class VectorizedCausalLM(Model): generated_text = None next_batch = batch - generation = Generation( batch.requests[i].id, - prefill_tokens[i] if batch.details and query_length>1 else None, + prefill_tokens[i] if batch.details and query_length > 1 else None, next_token_id, token_logprobs[i] if batch.details else 0.0, next_token_text, @@ -435,4 +533,3 @@ class VectorizedCausalLM(Model): generations.append(generation) return generations, next_batch - diff --git a/server/text_generation_server/utils/tokens_heterogeneous.py b/server/text_generation_server/utils/tokens_heterogeneous.py index 96d6f480..599bf0cb 100644 --- a/server/text_generation_server/utils/tokens_heterogeneous.py +++ b/server/text_generation_server/utils/tokens_heterogeneous.py @@ -24,8 +24,10 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. """ - def __init__(self, penalty: List[float], device:torch.device): - self.penalty = torch.tensor(penalty, dtype=torch.float32, device=device).unsqueeze(1) + def __init__(self, penalty: List[float], device: torch.device): + self.penalty = torch.tensor( + penalty, dtype=torch.float32, device=device + ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: score = torch.gather(scores, 1, input_ids) @@ -36,6 +38,7 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): scores.scatter_(1, input_ids, score) return scores + class HeterogeneousTemperatureLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution). @@ -47,13 +50,16 @@ class HeterogeneousTemperatureLogitsWarper(LogitsWarper): The value used to module the logits distribution. """ - def __init__(self, temperature: List[float], device:torch.device): - self.temperature = torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1) + def __init__(self, temperature: List[float], device: torch.device): + self.temperature = torch.tensor( + temperature, dtype=torch.float32, device=device + ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: scores.div_(self.temperature) return scores + class HeterogeneousTopPLogitsWarper(LogitsWarper): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. @@ -70,8 +76,16 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): Minimum number of tokens that cannot be filtered. 
""" - def __init__(self, top_p: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1): - self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(1) + def __init__( + self, + top_p: List[float], + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze( + 1 + ) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @@ -86,10 +100,13 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) scores.masked_fill_(indices_to_remove, self.filter_value) return scores + class HeterogeneousTopKLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. @@ -105,10 +122,20 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): Minimum number of tokens that cannot be filtered. """ - def __init__(self, top_k: List[int], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1): + def __init__( + self, + top_k: List[int], + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): self.max_top_k = max(top_k) - self.top_k = torch.tensor([max(x - 1, min_tokens_to_keep-1) for x in top_k], dtype=torch.int64,device=device).unsqueeze(1) - zeros=[x == 0 for x in top_k] + self.top_k = torch.tensor( + [max(x - 1, min_tokens_to_keep - 1) for x in top_k], + dtype=torch.int64, + device=device, + ).unsqueeze(1) + zeros = [x == 0 for x in top_k] if any(zeros): self.top_k_mask = torch.tensor(zeros, dtype=torch.bool, device=device) else: @@ -116,13 +143,13 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): self.filter_value = filter_value def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - if scores.size(-1)>self.max_top_k: # Safety check - max_top_k=scores.size(-1) - top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed. + if scores.size(-1) > self.max_top_k: # Safety check + max_top_k = scores.size(-1) + top_k = torch.clamp_max(self.top_k, max_top_k) # Run only if needed. else: - max_top_k=self.max_top_k - top_k=self.top_k - kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) + max_top_k = self.max_top_k + top_k = self.top_k + kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) if self.top_k_mask is not None: kth_scores.masked_fill_(self.top_k_mask, self.filter_value) # Remove all tokens with a probability less than the last token of the top-k @@ -147,7 +174,13 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): Minimum number of tokens that cannot be filtered. 
""" - def __init__(self, mass: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1): + def __init__( + self, + mass: List[float], + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): self.filter_value = filter_value self.mass = torch.tensor(mass, dtype=torch.float32, device=device).unsqueeze(1) self.min_tokens_to_keep = min_tokens_to_keep @@ -167,11 +200,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): # Remove tokens with cumulative mass above the threshold last_ind = (cumulative_probs < self.mass).sum(dim=1) last_ind[last_ind < 0] = 0 - sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather( + 1, last_ind.view(-1, 1) + ) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) scores = scores.masked_fill_(indices_to_remove, self.filter_value) return scores @@ -181,103 +218,113 @@ class HeterogeneousSampling: r""" Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. """ - def __init__(self, do_sample:List[bool], seeds: List[int], device: torch.device): + + def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): self.seeds = seeds - self.greedy=Greedy() + self.greedy = Greedy() # TODO: Most seeds are ignored - self.sampling=Sampling(seeds[0], device) - self.do_sample=torch.tensor(do_sample, dtype=torch.bool, device=device) + self.sampling = Sampling(seeds[0], device) + self.do_sample = torch.tensor(do_sample, dtype=torch.bool, device=device) def __call__(self, logits): return torch.where(self.do_sample, self.sampling(logits), self.greedy(logits)) + class HeterogeneousNextTokenChooser: def __init__( self, *, - batch_size:int, - device:torch.device, - watermark:Optional[Union[bool,List[Optional[bool]]]]=None, - temperature:Optional[Union[float,List[Optional[float]]]]=None, - repetition_penalty:Optional[Union[float,List[Optional[float]]]]=None, - top_k:Optional[Union[int,List[Optional[int]]]]=None, - top_p:Optional[Union[float,List[Optional[float]]]]=None, - typical_p:Optional[Union[float,List[Optional[float]]]]=None, - do_sample:Optional[Union[bool,List[Optional[bool]]]]=None, - seeds:Optional[Union[int,List[Optional[int]]]]=None, + batch_size: int, + device: torch.device, + watermark: Optional[Union[bool, List[Optional[bool]]]] = None, + temperature: Optional[Union[float, List[Optional[float]]]] = None, + repetition_penalty: Optional[Union[float, List[Optional[float]]]] = None, + top_k: Optional[Union[int, List[Optional[int]]]] = None, + top_p: Optional[Union[float, List[Optional[float]]]] = None, + typical_p: Optional[Union[float, List[Optional[float]]]] = None, + do_sample: Optional[Union[bool, List[Optional[bool]]]] = None, + seeds: Optional[Union[int, List[Optional[int]]]] = None, ): # TODO: Most seeds are ignored - seeds=self._standardize(seeds, batch_size, 0) - do_sample=self._standardize(do_sample, batch_size, False) + seeds = self._standardize(seeds, batch_size, 0) + do_sample = self._standardize(do_sample, batch_size, False) warpers = LogitsProcessorList() - 
watermark=self._standardize(watermark, batch_size, False) + watermark = self._standardize(watermark, batch_size, False) if any(watermark): raise NotImplementedError("Watermarking not implemented") - repetition_penalty=self._standardize(repetition_penalty, batch_size, 1.0) - if any([x!=1.0 for x in repetition_penalty]): - warpers.append(HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, device)) + repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0) + if any([x != 1.0 for x in repetition_penalty]): + warpers.append( + HeterogeneousRepetitionPenaltyLogitsProcessor( + repetition_penalty, device + ) + ) - temperature=self._standardize(temperature, batch_size, 1.0) - if any([x!=1.0 for x in temperature]): - do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)] + temperature = self._standardize(temperature, batch_size, 1.0) + if any([x != 1.0 for x in temperature]): + do_sample = [ + sample or x != 1.0 for x, sample in zip(temperature, do_sample) + ] warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, device)) - top_k=self._standardize(top_k, batch_size, 0) - n_top_k=sum([x!=0 for x in top_k]) - if n_top_k>0: - do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)] + top_k = self._standardize(top_k, batch_size, 0) + n_top_k = sum([x != 0 for x in top_k]) + if n_top_k > 0: + do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) - top_p=self._standardize(top_p, batch_size, 1.0) - if any([x<1.0 for x in top_p]): - do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)] + top_p = self._standardize(top_p, batch_size, 1.0) + if any([x < 1.0 for x in top_p]): + do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] warpers.append(HeterogeneousTopPLogitsWarper(top_p, device)) - typical_p=self._standardize(typical_p, batch_size, 1.0) - if any([x<1.0 for x in typical_p]): - do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)] + typical_p = self._standardize(typical_p, batch_size, 1.0) + if any([x < 1.0 for x in typical_p]): + do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, device)) - self.warpers=warpers + self.warpers = warpers - num_do_sample=sum(do_sample) - if num_do_sample==0: - self.choice=Greedy() - elif num_do_sample
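The chooser finishes by selecting a strategy based on `num_do_sample`, presumably all-greedy, fully sampled, or the mixed `HeterogeneousSampling` path defined earlier. A minimal sketch of the mixed case, assuming argmax for greedy rows and multinomial sampling over softmax probabilities for the rest; it only illustrates the `torch.where` pattern and is not the PR's exact code:

```python
import torch

batch_size, vocab_size = 4, 10
logits = torch.randn(batch_size, vocab_size)
do_sample = torch.tensor([False, True, False, True])

# Compute both candidates for every row, then keep the one each request asked for.
greedy_ids = logits.argmax(dim=-1)
probs = torch.softmax(logits, dim=-1)
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(1)
next_token_ids = torch.where(do_sample, sampled_ids, greedy_ids)
print(next_token_ids.shape)  # torch.Size([4])
```

Computing both candidates for every row costs one extra argmax or multinomial per step, but it keeps the hot path fully vectorized instead of branching per request.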