Return details optionally

Joel Lamy-Poirier 2023-05-05 15:02:54 -04:00
parent a7c10f710f
commit a5bf08f6e2
GPG Key ID: 82EE2141E842DFCF

@@ -1,4 +1,6 @@
import torch
import os
import math
from dataclasses import dataclass
from opentelemetry import trace
@@ -6,6 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
from typing import Optional, Tuple, List, Type, Dict, Union
from loguru import logger
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
@@ -50,6 +53,10 @@ class VectorizedCausalLMBatch(Batch):
kv_cache_seq_dim:int=2
# TODO: Get from requests (should these be lists?)
details:bool=os.environ.get("RETURN_DETAILS") is not None
generate_stream:bool=os.environ.get("GENERATE_STREAM") is not None
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
id=self.batch_id,
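
Note on the two flags added in the hunk above: they are read from environment variables rather than from the request protobuf (the TODO acknowledges this), and only the presence of the variable matters, not its value. A minimal sketch of that behaviour, with values chosen purely for illustration:

    import os

    # Any value enables the flag; the code only checks for presence.
    os.environ["RETURN_DETAILS"] = "1"
    os.environ["GENERATE_STREAM"] = "1"

    details = os.environ.get("RETURN_DETAILS") is not None           # True
    generate_stream = os.environ.get("GENERATE_STREAM") is not None  # True
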
@@ -104,6 +111,8 @@ class VectorizedCausalLMBatch(Batch):
max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
generate_stream=cls.generate_stream or any(stopping_criteria.stop_sequence_criterias for stopping_criteria in stopping_criterias)
return cls(
batch_id=pb.id,
requests=pb.requests,
@@ -119,6 +128,7 @@ class VectorizedCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
max_tokens=max_tokens,
generate_stream=generate_stream,
)
@tracer.start_as_current_span("filter")
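
The generate_stream expression in the hunk above forces streamed decoding on whenever any request defines stop sequences, because matching a stop sequence requires the decoded token text. A schematic sketch of that rule; CriteriaStub is a hypothetical stand-in for the real stopping-criteria objects, which expose a stop_sequence_criterias list:

    class CriteriaStub:
        def __init__(self, stop_sequence_criterias):
            self.stop_sequence_criterias = stop_sequence_criterias

    # One request without stop sequences, one with a stop sequence.
    stopping_criterias = [CriteriaStub([]), CriteriaStub(["\n\n"])]

    env_flag = False  # would come from GENERATE_STREAM in the real code
    generate_stream = env_flag or any(
        c.stop_sequence_criterias for c in stopping_criterias
    )
    # generate_stream is True: at least one request needs decoded text.
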
@@ -295,7 +305,7 @@ class VectorizedNextTokenChooser:
device:torch.device="cpu",
):
self.batch_size=batch_size
self.filter_value = -float("Inf")
self.filter_value = -math.inf
self.device=device
# TODO: Seeds are ignored
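
The filter_value change above is purely cosmetic: -math.inf and -float("Inf") are the same value, the rewrite just avoids building the constant from a string. For example:

    import math
    assert -math.inf == -float("Inf")  # same value, different spelling
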
@@ -366,68 +376,79 @@ class VectorizedNextTokenChooser:
values[i]=default
return values
def __call__(self, input_ids, scores):
def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
last_token_scores=scores[:, -1, :]
if self.repetition_penalty_t is not None:
score = torch.gather(scores, 1, input_ids)
score = torch.gather(last_token_scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
scores.scatter_(1, input_ids, score)
last_token_scores.scatter_(1, input_ids, score)
if self.temperature_t is not None:
scores.div_(self.temperature_t)
last_token_scores.div_(self.temperature_t)
if self.top_k_t is not None:
if scores.size(-1)>self.max_top_k: # Safety check
max_top_k=scores.size(-1)
if last_token_scores.size(-1)>self.max_top_k: # Safety check
max_top_k=last_token_scores.size(-1)
top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
else:
max_top_k=self.max_top_k
top_k=self.top_k_t
kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
kth_scores=torch.gather(torch.topk(last_token_scores, max_top_k)[0], 1, top_k)
if self.top_k_mask is not None:
kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < kth_scores
scores = scores.masked_fill(indices_to_remove, self.filter_value)
indices_to_remove = last_token_scores < kth_scores
last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
if self.top_p_t is not None:
# TODO: Merge with top_k
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
sorted_logits, sorted_indices = torch.sort(last_token_scores, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= self.top_p_t
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
if self.typical_p_t is not None:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
normalized = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
sorted_logits = last_token_scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
if self.num_do_sample:
probs = torch.nn.functional.softmax(scores, -1)
probs = torch.nn.functional.softmax(last_token_scores, -1)
next_token_ids = torch.multinomial(probs, num_samples=1)
if self.do_sample_t is not None:
next_token_ids=torch.where(self.do_sample_t, next_token_ids,torch.argmax(scores, dim=-1))
next_token_ids=torch.where(self.do_sample_t, next_token_ids, torch.argmax(last_token_scores, dim=-1))
else:
next_token_ids = torch.argmax(scores, dim=-1)
next_token_ids = torch.argmax(last_token_scores, dim=-1)
# Compute logprobs
logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
if return_logprobs:
# Compute logprobs
if scores.size(1)==1:
scores=last_token_scores.unsqueeze(1)
else:
# TODO: Post-process all the tokens?
scores[:, -1, :]=last_token_scores
logprobs = torch.log_softmax(scores, dim=-1)
else:
logprobs=None
return next_token_ids, logprobs
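
The reworked __call__ above takes the full [batch, seq, vocab] score tensor, applies the warpers to the last position only, and materialises the log-softmax over every position only when return_logprobs is set. A self-contained sketch of that shape handling (greedy selection only, random scores, names chosen for illustration):

    import torch

    def choose_next_tokens(scores: torch.Tensor, return_logprobs: bool):
        # scores: [batch, seq, vocab]; only the last position drives token choice.
        last_token_scores = scores[:, -1, :]
        next_token_ids = torch.argmax(last_token_scores, dim=-1)
        # The full log-softmax is only computed when details are requested.
        logprobs = torch.log_softmax(scores, dim=-1) if return_logprobs else None
        return next_token_ids, logprobs

    scores = torch.randn(2, 5, 100)  # batch=2, seq=5, vocab=100
    ids, lp = choose_next_tokens(scores, return_logprobs=True)
    print(ids.shape, lp.shape)       # torch.Size([2]) torch.Size([2, 5, 100])
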
@@ -549,7 +570,29 @@ class VectorizedCausalLM(Model):
past_key_values=batch.past_key_values,
)
# TODO: Post-processing
next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
next_token_ids=next_token_ids.cpu().tolist()
if batch.generate_stream:
# TODO: self.decode_token, offsets?
next_token_texts=self.tokenizer.batch_decode(next_token_ids)
if batch.details:
token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
if query_length>1:
prefill_token_ids=batch.input_ids[:, :key_length].tolist()
prefill_logprobs=logprobs.gather(1, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
prefill_tokens=[]
for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
prefill_token_ids_=prefill_token_ids_[-input_length:]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids_,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens.append(PrefillTokens(
prefill_token_ids_, [math.nan, *prefill_logprobs_], prefill_texts
))
# Update batch
# TODO: Why do we need all input ids?
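
The prefill block above pairs each prompt token (from position 1 onward) with the log-probability the model assigned to it at the preceding position, and reports math.nan for the first token, which nothing predicts. A standalone sketch of that pairing with made-up shapes; the explicit slice and gather dimension here illustrate the intent rather than mirroring the diff's exact indexing:

    import math
    import torch

    batch, seq, vocab = 2, 4, 10
    input_ids = torch.randint(vocab, (batch, seq))
    logprobs = torch.log_softmax(torch.randn(batch, seq, vocab), dim=-1)

    # Logprob of the token at position i+1 comes from the distribution at position i.
    per_token = logprobs[:, :-1, :].gather(2, input_ids[:, 1:, None]).squeeze(2)

    # The first token has no prediction, so it is reported as NaN, as in the diff.
    prefill_logprobs = [[math.nan, *row.tolist()] for row in per_token]
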
@@ -558,19 +601,13 @@ class VectorizedCausalLM(Model):
batch.input_lengths=[length+1 for length in batch.input_lengths]
batch.max_input_length+=1
# TODO: self.decode_token, offsets?
next_token_ids=next_token_ids.cpu().tolist()
next_token_texts=self.tokenizer.batch_decode(next_token_ids)
# TODO: Why do we need logprobs?
logprobs=logprobs.cpu().tolist()
# TODO: Vectorize some of this?
generations: List[Generation] = []
next_batch=None
for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
for i, next_token_id in enumerate(next_token_ids):
next_token_text=next_token_texts[i] if batch.generate_stream else ""
stopping_criterias=batch.stopping_criterias[i]
stop, reason = stopping_criterias(
next_token_id,
@@ -594,9 +631,9 @@ class VectorizedCausalLM(Model):
generation = Generation(
batch.requests[i].id,
None, # TODO: Prefill tokens
prefill_tokens[i] if batch.details and query_length>1 else None, # TODO: Prefill tokens
next_token_id,
logprobs[i],
token_logprobs[i] if batch.details else 0.0,
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
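
Taken together, the conditionals above mean per-token detail only reaches the Generation when RETURN_DETAILS is set: prefill tokens are attached only on the prefill pass (query_length>1), and the chosen token's logprob otherwise falls back to 0.0. A compact sketch of that selection, with placeholder values standing in for the real batch fields:

    # Placeholder values standing in for batch.details, query_length, etc.
    details, query_length, i = True, 8, 0
    prefill_tokens = ["<prefill-tokens-for-request-0>"]  # one entry per request
    token_logprobs = [-1.23]                             # one entry per request

    prefill_field = prefill_tokens[i] if details and query_length > 1 else None
    logprob_field = token_logprobs[i] if details else 0.0
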