Return details optionally

Joel Lamy-Poirier 2023-05-05 15:02:54 -04:00
parent a7c10f710f
commit a5bf08f6e2

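The new details and generate_stream flags added to VectorizedCausalLMBatch are driven by environment variables: they are enabled whenever RETURN_DETAILS / GENERATE_STREAM are set to any value, because the check is os.environ.get(...) is not None. A minimal sketch of that toggle behavior, illustrative only and not part of the diff below:

    import os

    # Any value, even "0", enables the flag; only an unset variable leaves it off.
    os.environ["RETURN_DETAILS"] = "0"
    details = os.environ.get("RETURN_DETAILS") is not None          # True
    generate_stream = os.environ.get("GENERATE_STREAM") is not None  # False unless GENERATE_STREAM is set
    print(details, generate_stream)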

@@ -1,4 +1,6 @@
 import torch
+import os
+import math
 from dataclasses import dataclass
 from opentelemetry import trace
@@ -6,6 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict, Union
 from loguru import logger
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
@@ -50,6 +53,10 @@ class VectorizedCausalLMBatch(Batch):
     kv_cache_seq_dim:int=2
+    # TODO: Get from requests (should these be lists?)
+    details:bool=os.environ.get("RETURN_DETAILS") is not None
+    generate_stream:bool=os.environ.get("GENERATE_STREAM") is not None
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -104,6 +111,8 @@ class VectorizedCausalLMBatch(Batch):
         max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
+        generate_stream=cls.generate_stream or any(stopping_criteria.stop_sequence_criterias for stopping_criteria in stopping_criterias)
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -119,6 +128,7 @@ class VectorizedCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
             max_tokens=max_tokens,
+            generate_stream=generate_stream,
         )

     @tracer.start_as_current_span("filter")
@@ -295,7 +305,7 @@ class VectorizedNextTokenChooser:
         device:torch.device="cpu",
     ):
         self.batch_size=batch_size
-        self.filter_value = -float("Inf")
+        self.filter_value = -math.inf
         self.device=device
         # TODO: Seeds are ignored
@@ -366,68 +376,79 @@ class VectorizedNextTokenChooser:
                 values[i]=default
         return values

-    def __call__(self, input_ids, scores):
+    def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
+        last_token_scores=scores[:, -1, :]
+
         if self.repetition_penalty_t is not None:
-            score = torch.gather(scores, 1, input_ids)
+            score = torch.gather(last_token_scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
             score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
-            scores.scatter_(1, input_ids, score)
+            last_token_scores.scatter_(1, input_ids, score)

         if self.temperature_t is not None:
-            scores.div_(self.temperature_t)
+            last_token_scores.div_(self.temperature_t)

         if self.top_k_t is not None:
-            if scores.size(-1)>self.max_top_k: # Safety check
-                max_top_k=scores.size(-1)
+            if last_token_scores.size(-1)>self.max_top_k: # Safety check
+                max_top_k=last_token_scores.size(-1)
                 top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
             else:
                 max_top_k=self.max_top_k
                 top_k=self.top_k_t
-            kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+            kth_scores=torch.gather(torch.topk(last_token_scores, max_top_k)[0], 1, top_k)
             if self.top_k_mask is not None:
                 kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
             # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = scores < kth_scores
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+            indices_to_remove = last_token_scores < kth_scores
+            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)

         if self.top_p_t is not None:
             # TODO: Merge wit top_k
-            sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+            sorted_logits, sorted_indices = torch.sort(last_token_scores, descending=True)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
             sorted_indices_to_remove = cumulative_probs <= self.top_p_t
             # scatter sorted tensors to original indexing
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)

         if self.typical_p_t is not None:
             # calculate entropy
-            normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+            normalized = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
             p = torch.exp(normalized)
             ent = -(normalized * p).nansum(-1, keepdim=True)
             # shift and sort
             shifted_scores = torch.abs((-normalized) - ent)
             sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-            sorted_logits = scores.gather(-1, sorted_indices)
+            sorted_logits = last_token_scores.gather(-1, sorted_indices)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative mass above the threshold
             last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
             last_ind[last_ind < 0] = 0
             sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)

         if self.num_do_sample:
-            probs = torch.nn.functional.softmax(scores, -1)
+            probs = torch.nn.functional.softmax(last_token_scores, -1)
             next_token_ids = torch.multinomial(probs, num_samples=1)
             if self.do_sample_t is not None:
-                next_token_ids=torch.where(self.do_sample_t, next_token_ids,torch.argmax(scores, dim=-1))
+                next_token_ids=torch.where(self.do_sample_t, next_token_ids, torch.argmax(last_token_scores, dim=-1))
         else:
-            next_token_ids = torch.argmax(scores, dim=-1)
+            next_token_ids = torch.argmax(last_token_scores, dim=-1)

-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
+        if return_logprobs:
+            # Compute logprobs
+            if scores.size(1)==1:
+                scores=scores.unsqueeze(1)
+            else:
+                # TODO: Post-process all the tokens?
+                scores[:, -1, :]=last_token_scores
+            logprobs = torch.log_softmax(scores, dim=-1)
+        else:
+            logprobs=None

         return next_token_ids, logprobs
@@ -549,7 +570,29 @@ class VectorizedCausalLM(Model):
             past_key_values=batch.past_key_values,
         )
         # TODO: Post-processing
-        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
+        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
+        next_token_ids=next_token_ids.cpu().tolist()
+
+        if batch.generate_stream:
+            # TODO: self.decode_token, offsets?
+            next_token_texts=self.tokenizer.batch_decode(next_token_ids)
+
+        if batch.details:
+            token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
+            if query_length>1:
+                prefill_token_ids=batch.input_ids[:, :key_length].tolist()
+                prefill_logprobs=logprobs.gather(1, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
+                prefill_tokens=[]
+                for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
+                    prefill_token_ids_=prefill_token_ids_[-input_length:]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids_,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens.append(PrefillTokens(
+                        prefill_token_ids_, [math.nan, *prefill_logprobs_], prefill_texts
+                    ))

         # Update batch
         # TODO: Why do we need all input ids?
@@ -558,19 +601,13 @@ class VectorizedCausalLM(Model):
         batch.input_lengths=[length+1 for length in batch.input_lengths]
         batch.max_input_length+=1

-        # TODO: self.decode_token, offsets?
-        next_token_ids=next_token_ids.cpu().tolist()
-        next_token_texts=self.tokenizer.batch_decode(next_token_ids)
-        # TODO: Why do we need logprobs?
-        logprobs=logprobs.cpu().tolist()
-
         # TODO: Vectorize some of this?
         generations: List[Generation] = []
         next_batch=None
-        for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
+        for i, next_token_id in enumerate(next_token_ids):
+            next_token_text=next_token_texts[i] if batch.generate_stream else ""
             stopping_criterias=batch.stopping_criterias[i]
             stop, reason = stopping_criterias(
                 next_token_id,
@@ -594,9 +631,9 @@ class VectorizedCausalLM(Model):
                 generation = Generation(
                     batch.requests[i].id,
-                    None, # TODO: Prefill tokens
+                    prefill_tokens[i] if batch.details and query_length>1 else None, # TODO: Prefill tokens
                     next_token_id,
-                    logprobs[i],
+                    token_logprobs[i] if batch.details else 0.0,
                     next_token_text,
                     next_token_id in self.all_special_ids,
                     generated_text,
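
A sketch of the reworked call site, as the hunks above show it: the chooser now receives the full logits tensor instead of only the last position, and the details flag decides whether logprobs are computed at all (names such as outputs and batch follow the diff; the surrounding setup is assumed):

    # Hedged sketch based on the call site in the diff above, not a standalone program.
    next_token_ids, logprobs = batch.next_token_chooser(
        input_ids,        # (batch, key_length) token ids
        outputs.logits,   # (batch, query_length, vocab) full logits, not just the last position
        batch.details,    # True -> logprobs is a log_softmax over all positions, otherwise None
    )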