diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index 72232d64..7547623c 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -14,7 +14,7 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import StoppingCriteria
 
 tracer = trace.get_tracer(__name__)
 
@@ -48,6 +48,8 @@ class VectorizedCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
+    kv_cache_seq_dim:int=2
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -64,7 +66,6 @@ class VectorizedCausalLMBatch(Batch):
         device: torch.device,
     ) -> "VectorizedCausalLMBatch":
         inputs = []
-        next_token_choosers = []
         stopping_criterias = []
         offsets = []
         token_offsets = []
@@ -75,14 +76,10 @@ class VectorizedCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            next_token_chooser=NextTokenChooser.from_pb(r.parameters, device)
-            # TODO: Implement
-            assert len(next_token_chooser.warpers)==0
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             offsets.append(None)
             token_offsets.append(None)
-            next_token_choosers.append(next_token_chooser)
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -134,7 +131,7 @@ class VectorizedCausalLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
-            next_token_chooser=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             max_tokens=max_tokens,
@@ -142,7 +139,52 @@ class VectorizedCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]:
-        raise NotImplementedError()
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        self.requests = requests
+        keep_indices = [self.requests_idx_mapping[r.id] for r in self.requests]
+
+        # New values after filtering
+        self.requests_idx_mapping={r.id:i for i, r in enumerate(self.requests)}
+        self.input_lengths=[self.input_lengths[i] for i in keep_indices]
+        self.offsets = [self.offsets[i] for i in keep_indices]
+        self.token_offsets = [self.token_offsets[i] for i in keep_indices]
+        self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
+        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
+        remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
+        self.padding_right_offset=max(remaining_decode_tokens)
+
+        # Select the remaining indices and remove unnecessary padding
+        max_input_length=max(self.input_lengths)
+        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+self.padding_right_offset)
+        self.max_input_length=max_input_length
+        self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
+
+        self.input_ids = self.input_ids[keep_indices,sequence_slice]
+        self.position_ids = self.position_ids[keep_indices,sequence_slice]
+        self.attention_mask = self.attention_mask[keep_indices,sequence_slice]
+
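+        # Collect every per-layer KV-cache tensor so the same batch/sequence
+        # slicing can be applied below, whether a layer stores a single tensor
+        # or a (key, value) pair.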
+        tensors_to_update = []
+        if self.past_key_values is not None:
+            if not isinstance(self.past_key_values,(list, tuple)):
+                raise NotImplementedError(f"Unsupported kv cache type: {type(self.past_key_values)}")
+            for layer_kv in self.past_key_values:
+                if isinstance(layer_kv, torch.Tensor):
+                    tensors_to_update.append(layer_kv)
+                elif isinstance(layer_kv,(list, tuple)):
+                    tensors_to_update.extend(layer_kv)
+                else:
+                    raise NotImplementedError(f"Unsupported layer kv cache type: {type(layer_kv)}")
+
+        kv_cache_slice=[keep_indices, *(slice(None) for _ in range(1, self.kv_cache_seq_dim)), sequence_slice]
+        for tensor in tensors_to_update:
+            # Update tensors in-place to allow incremental garbage collection
+            tensor.data=tensor[kv_cache_slice]
+
+        return self
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -157,74 +199,76 @@ class VectorizedNextTokenChooser:
     def __init__(
         self,
         batch_size:int,
-        watermark=None,
-        temperature=None,
-        repetition_penalty=None,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=None,
-        seed:int=0,
-        device="cpu",
+        watermark:Optional[List[Optional[bool]]]=None,
+        temperature:Optional[List[Optional[float]]]=None,
+        repetition_penalty:Optional[List[Optional[float]]]=None,
+        top_k:Optional[List[Optional[int]]]=None,
+        top_p:Optional[List[Optional[float]]]=None,
+        typical_p:Optional[List[Optional[float]]]=None,
+        do_sample:Optional[List[Optional[bool]]]=None,
+        seeds:Optional[List[Optional[int]]]=None,
+        device:torch.device="cpu",
     ):
         self.batch_size=batch_size
        self.filter_value = -float("Inf")
+        self.device=device
 
-        do_sample=self._standardize(do_sample, False)
+        # TODO: Seeds are ignored
+        self.seeds=self._standardize(seeds, 0)
+        self.do_sample=self._standardize(do_sample, False)
 
-        watermark=self._standardize(watermark, False)
-        if any(watermark):
+        self.watermark=self._standardize(watermark, False)
+        if any(self.watermark):
             raise NotImplementedError("Watermarking not implemented")
 
-        repetition_penalty=self._standardize(repetition_penalty, 1.0)
-        if any([x!=1.0 for x in repetition_penalty]):
-            self.repetition_penalty=torch.tensor(repetition_penalty, dtype=torch.float32, device=device).unsqueeze(1)
+        self.repetition_penalty=self._standardize(repetition_penalty, 1.0)
+        if any([x!=1.0 for x in self.repetition_penalty]):
+            self.repetition_penalty_t=torch.tensor(self.repetition_penalty, dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.repetition_penalty=None
+            self.repetition_penalty_t=None
 
-        temperature=self._standardize(temperature, 1.0)
-        if any([x!=1.0 for x in temperature]):
-            do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
-            self.temperature=torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
+        self.temperature=self._standardize(temperature, 1.0)
+        if any([x!=1.0 for x in self.temperature]):
+            self.do_sample=[sample or x!=1.0 for x, sample in zip(self.temperature, self.do_sample)]
+            self.temperature_t=torch.tensor(self.temperature, dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.temperature=None
+            self.temperature_t=None
 
-        top_k=self._standardize(top_k, 0)
+        self.top_k=self._standardize(top_k, 0)
         n_top_k=sum([x!=0 for x in top_k])
         if n_top_k>0:
-            do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
-            self.max_top_k=max(top_k)
-            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
+            self.do_sample=[sample or x!=0 for x, sample in zip(self.top_k, self.do_sample)]
+            self.max_top_k=max(self.top_k)
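+            # 0-based index of each request's k-th best score, gathered from the
+            # torch.topk output in __call__ to read the per-request cutoff value.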
+            self.top_k_t=torch.tensor([max(x-1,0) for x in self.top_k], dtype=torch.int64, device=self.device).unsqueeze(1)
             if n_top_k<self.batch_size:
-                self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device)
+                self.top_k_mask=torch.tensor([x==0 for x in self.top_k], dtype=torch.bool, device=self.device)
             else:
                 self.top_k_mask=None
         else:
             self.max_top_k=None
-            self.top_k=None
+            self.top_k_t=None
             self.top_k_mask=None
 
-        top_p=self._standardize(top_p, 1.0)
-        if any([x<1.0 for x in top_p]):
-            do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)]
-            self.top_p_inv=torch.tensor([1.0-x for x in top_p], dtype=torch.float32, device=device).unsqueeze(1)
+        self.top_p=self._standardize(top_p, 1.0)
+        if any([x<1.0 for x in self.top_p]):
+            self.do_sample=[sample or x<1.0 for x, sample in zip(self.top_p, self.do_sample)]
+            self.top_p_t=torch.tensor([1.0-x for x in self.top_p], dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.top_p_inv=None
+            self.top_p_t=None
 
-        typical_p=self._standardize(typical_p, 1.0)
-        if any([x<1.0 for x in typical_p]):
-            do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
-            self.typical_p=torch.tensor(typical_p, dtype=torch.float32, device=device).unsqueeze(1)
+        self.typical_p=self._standardize(typical_p, 1.0)
+        if any([x<1.0 for x in self.typical_p]):
+            self.do_sample=[sample or x<1.0 for x, sample in zip(self.typical_p, self.do_sample)]
+            self.typical_p_t=torch.tensor(self.typical_p, dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.typical_p=None
+            self.typical_p_t=None
 
-        num_do_sample=sum(do_sample)
-        self.do_sample=num_do_sample>0
-        if self.do_sample and num_do_sample<self.batch_size:
-            self.do_sample_v=torch.tensor(do_sample, dtype=torch.bool, device=device).unsqueeze(1)
+        self.num_do_sample=sum(self.do_sample)
+        if self.num_do_sample and self.num_do_sample<self.batch_size:
+            self.do_sample_t=torch.tensor(self.do_sample, dtype=torch.bool, device=self.device).unsqueeze(1)
         else:
-            self.do_sample_v=None
+            self.do_sample_t=None
@@ -243,20 +287,20 @@ class VectorizedNextTokenChooser:
     def __call__(self, input_ids, scores):
-        if self.repetition_penalty is not None:
+        if self.repetition_penalty_t is not None:
             score = torch.gather(scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-            score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
+            score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
             scores.scatter_(1, input_ids, score)
 
-        if self.temperature is not None:
-            scores.div_(self.temperature)
+        if self.temperature_t is not None:
+            scores.div_(self.temperature_t)
 
         if self.max_top_k is not None:
             if scores.size(-1)<self.max_top_k: # Safety check
                 max_top_k=scores.size(-1)
-                top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed.
+                top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
             else:
                 max_top_k=self.max_top_k
-                top_k=self.top_k
+                top_k=self.top_k_t
             kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
             if self.top_k_mask is not None:
                 kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
@@ -264,17 +308,17 @@ class VectorizedNextTokenChooser:
             indices_to_remove = scores < kth_scores
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
-        if self.top_p_inv is not None:
+        if self.top_p_t is not None:
             # TODO: Merge with top_k
             sorted_logits, sorted_indices = torch.sort(scores, descending=True)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
-            sorted_indices_to_remove = cumulative_probs <= self.top_p_inv
+            sorted_indices_to_remove = cumulative_probs <= self.top_p_t
             # scatter sorted tensors to original indexing
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
-        if self.typical_p is not None:
+        if self.typical_p_t is not None:
             # calculate entropy
             normalized = torch.nn.functional.log_softmax(scores, dim=-1)
             p = torch.exp(normalized)
@@ -285,7 +329,7 @@ class VectorizedNextTokenChooser:
             sorted_logits = scores.gather(-1, sorted_indices)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative mass above the threshold
-            last_ind = (cumulative_probs < self.typical_p).sum(dim=1)
+            last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
             last_ind[last_ind < 0] = 0
             sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
@@ -294,11 +338,11 @@ class VectorizedNextTokenChooser:
         # Compute logprobs
         logprobs = torch.log_softmax(scores, dim=-1)
 
-        if self.do_sample:
+        if self.num_do_sample:
             probs = torch.nn.functional.softmax(scores, -1)
             next_token_ids = torch.multinomial(probs, num_samples=1)
-            if self.do_sample_v is not None:
-                next_token_ids=torch.where(self.do_sample_v, next_token_ids,torch.argmax(scores, dim=-1))
+            if self.do_sample_t is not None:
+                next_token_ids=torch.where(self.do_sample_t, next_token_ids,torch.argmax(scores, dim=-1))
         else:
             next_token_ids = torch.argmax(scores, dim=-1)
 
@@ -312,6 +356,7 @@ class VectorizedNextTokenChooser:
     ) -> "VectorizedNextTokenChooser":
         # TODO: Seeds are ignored
         return VectorizedNextTokenChooser(
+            batch_size=len(pb),
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
@@ -319,10 +364,26 @@ class VectorizedNextTokenChooser:
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],
             do_sample=[pb_.do_sample for pb_ in pb],
-            seed=0,
+            seeds=[pb_.seed for pb_ in pb],
             device=device,
         )
 
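+    # Filtering rebuilds the chooser from the plain-list copies of the
+    # parameters kept in __init__, so a batch subset gets matching tensors.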
+    def filter(self, keep_indices: List[int]) -> "VectorizedNextTokenChooser":
+        return VectorizedNextTokenChooser(
+            batch_size=len(keep_indices),
+            watermark=[self.watermark[i] for i in keep_indices],
+            temperature=[self.temperature[i] for i in keep_indices],
+            repetition_penalty=[self.repetition_penalty[i] for i in keep_indices],
+            top_k=[self.top_k[i] for i in keep_indices],
+            top_p=[self.top_p[i] for i in keep_indices],
+            typical_p=[self.typical_p[i] for i in keep_indices],
+            do_sample=[self.do_sample[i] for i in keep_indices],
+            seeds=[self.seeds[i] for i in keep_indices],
+            device=self.device,
+        )
+
+
 class VectorizedCausalLM(Model):
     def __init__(