diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index 35d037a0..fc6d4760 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -1,4 +1,6 @@
 import torch
+import os
+import math
 from dataclasses import dataclass
 
 from opentelemetry import trace
@@ -6,6 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict, Union
 from loguru import logger
+
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
@@ -50,6 +53,10 @@ class VectorizedCausalLMBatch(Batch):
 
     kv_cache_seq_dim:int=2
 
+    # TODO: Get from requests (should these be lists?)
+    details:bool=os.environ.get("RETURN_DETAILS") is not None
+    generate_stream:bool=os.environ.get("GENERATE_STREAM") is not None
+
     def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
@@ -104,6 +111,8 @@ class VectorizedCausalLMBatch(Batch):
 
         max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
 
+        generate_stream=cls.generate_stream or any(stopping_criteria.stop_sequence_criterias for stopping_criteria in stopping_criterias)
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -119,6 +128,7 @@ class VectorizedCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
             max_tokens=max_tokens,
+            generate_stream=generate_stream,
         )
 
     @tracer.start_as_current_span("filter")
@@ -295,7 +305,7 @@ class VectorizedNextTokenChooser:
         device:torch.device="cpu",
     ):
         self.batch_size=batch_size
-        self.filter_value = -float("Inf")
+        self.filter_value = -math.inf
         self.device=device
 
         # TODO: Seeds are ignored
@@ -366,68 +376,79 @@ class VectorizedNextTokenChooser:
                 values[i]=default
         return values
 
-    def __call__(self, input_ids, scores):
+    def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
+        last_token_scores=scores[:, -1, :]
+
         if self.repetition_penalty_t is not None:
-            score = torch.gather(scores, 1, input_ids)
+            score = torch.gather(last_token_scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
             score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
-            scores.scatter_(1, input_ids, score)
+            last_token_scores.scatter_(1, input_ids, score)
 
         if self.temperature_t is not None:
-            scores.div_(self.temperature_t)
+            last_token_scores.div_(self.temperature_t)
 
         if self.top_k_t is not None:
-            if scores.size(-1)>self.max_top_k: # Safety check
-                max_top_k=scores.size(-1)
+            if last_token_scores.size(-1)<self.max_top_k: # Safety check: the vocabulary may be smaller than max_top_k
+                max_top_k=last_token_scores.size(-1)
                 top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
             else:
                 max_top_k=self.max_top_k
                 top_k=self.top_k_t
-            kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+            kth_scores=torch.gather(torch.topk(last_token_scores, max_top_k)[0], 1, top_k)
             if self.top_k_mask is not None:
                 kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
             # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = scores < kth_scores
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+            indices_to_remove = last_token_scores < kth_scores
+            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
 
         if self.top_p_t is not None:
             # TODO: Merge wit top_k
-            sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+            sorted_logits, sorted_indices = torch.sort(last_token_scores, descending=True)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
             sorted_indices_to_remove = cumulative_probs <= self.top_p_t
             # scatter sorted tensors to original indexing
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
 
         if self.typical_p_t is not None:
             # calculate entropy
-            normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+            normalized = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
             p = torch.exp(normalized)
             ent = -(normalized * p).nansum(-1, keepdim=True)
             # shift and sort
             shifted_scores = torch.abs((-normalized) - ent)
             sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-            sorted_logits = scores.gather(-1, sorted_indices)
+            sorted_logits = last_token_scores.gather(-1, sorted_indices)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative mass above the threshold
             last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
             last_ind[last_ind < 0] = 0
             sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
 
         if self.num_do_sample:
-            probs = torch.nn.functional.softmax(scores, -1)
+            probs = torch.nn.functional.softmax(last_token_scores, -1)
             next_token_ids = torch.multinomial(probs, num_samples=1)
             if self.do_sample_t is not None:
-                next_token_ids=torch.where(self.do_sample_t, next_token_ids,torch.argmax(scores, dim=-1))
+                next_token_ids=torch.where(self.do_sample_t, next_token_ids, torch.argmax(last_token_scores, dim=-1))
         else:
-            next_token_ids = torch.argmax(scores, dim=-1)
+            next_token_ids = torch.argmax(last_token_scores, dim=-1)
 
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
+        if return_logprobs:
+            # Compute logprobs
+            if scores.size(1)==1:
+                scores=last_token_scores.unsqueeze(1) # Rebuild the 3-D scores from the post-processed last-token scores; unsqueezing `scores` itself would yield a 4-D tensor
+            else:
+                # TODO: Post-process all the tokens?
+                scores[:, -1, :]=last_token_scores
+
+            logprobs = torch.log_softmax(scores, dim=-1)
+        else:
+            logprobs=None
 
         return next_token_ids, logprobs
@@ -549,7 +570,29 @@ class VectorizedCausalLM(Model):
             past_key_values=batch.past_key_values,
         )
         # TODO: Post-processing
-        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
+        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
+
+        if batch.details: # Needs next_token_ids as a tensor, so run before the list conversion below
+            token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
+            if query_length>1:
+                prefill_token_ids=batch.input_ids[:, :key_length].tolist()
+                prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
+                prefill_tokens=[]
+                for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
+                    prefill_token_ids_=prefill_token_ids_[-input_length:]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids_,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens.append(PrefillTokens(
+                        prefill_token_ids_, [math.nan, *prefill_logprobs_], prefill_texts
+                    ))
+
+        next_token_ids=next_token_ids.cpu().tolist()
+        if batch.generate_stream:
+            # TODO: self.decode_token, offsets?
+            next_token_texts=self.tokenizer.batch_decode(next_token_ids)
 
         # Update batch
         # TODO: Why do we need all input ids?
@@ -558,19 +601,13 @@ class VectorizedCausalLM(Model):
         batch.input_lengths=[length+1 for length in batch.input_lengths]
         batch.max_input_length+=1
 
-        # TODO: self.decode_token, offsets?
-        next_token_ids=next_token_ids.cpu().tolist()
-        next_token_texts=self.tokenizer.batch_decode(next_token_ids)
-
-        # TODO: Why do we need logprobs?
-        logprobs=logprobs.cpu().tolist()
-
         # TODO: Vectorize some of this?
         generations: List[Generation] = []
         next_batch=None
 
-        for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
+        for i, next_token_id in enumerate(next_token_ids):
+            next_token_text=next_token_texts[i] if batch.generate_stream else ""
             stopping_criterias=batch.stopping_criterias[i]
             stop, reason = stopping_criterias(
                 next_token_id,
@@ -594,9 +631,9 @@ class VectorizedCausalLM(Model):
 
             generation = Generation(
                 batch.requests[i].id,
-                None, # TODO: Prefill tokens
+                prefill_tokens[i] if batch.details and query_length>1 else None, # TODO: Prefill tokens
                 next_token_id,
-                logprobs[i],
+                token_logprobs[i] if batch.details else 0.0,
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
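
Note on the prefill-logprob hunk above: a minimal standalone sketch (not part of the patch; the shapes, sizes, and variable names are illustrative) of the alignment that `logprobs.gather(2, ...)` and the `[math.nan, *prefill_logprobs_]` construction rely on.

```python
import math
import torch

# Assumed shapes, inferred from the diff: logits[:, t, :] is the distribution
# over the token at position t + 1, so the logprob of input token t comes from
# the prediction at position t - 1. The first prompt token has no preceding
# context, hence the NaN placeholder.
batch_size, key_length, vocab_size = 2, 4, 10
logits = torch.randn(batch_size, key_length, vocab_size)
input_ids = torch.randint(vocab_size, (batch_size, key_length))

logprobs = torch.log_softmax(logits, dim=-1)
# Gather along the vocab dimension (dim=2): entry [b, t] is the logprob of
# input_ids[b, t + 1] under the distribution predicted at position t.
prefill_logprobs = logprobs.gather(2, input_ids[:, 1:, None]).squeeze(2)

per_request = [[math.nan, *row.tolist()] for row in prefill_logprobs]
assert all(len(row) == key_length for row in per_request)
```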
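
The `details` / `generate_stream` toggles are dataclass field defaults, so the environment variables are read once, when the class body executes at import time, not per request. A hedged usage sketch, assuming the server package is importable; any value, even an empty string, enables a flag, since only `is not None` is checked:

```python
import os

# Must be set before the module is first imported: the field defaults are
# evaluated once, at class-definition time.
os.environ["RETURN_DETAILS"] = "1"
os.environ["GENERATE_STREAM"] = "1"

from text_generation_server.models.vectorized_causal_lm import (
    VectorizedCausalLMBatch,
)

assert VectorizedCausalLMBatch.details
assert VectorizedCausalLMBatch.generate_stream
```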