fix imports

OlivierDehaene 2023-05-12 15:47:57 +02:00
parent f9e3a3bb91
commit e7826855a3
4 changed files with 139 additions and 138 deletions

View File

@@ -18,10 +18,7 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import (
-    StoppingCriteria,
-    HeterogeneousNextTokenChooser
-)
+from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

 tracer = trace.get_tracer(__name__)
@@ -71,11 +68,11 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         position_ids = []
         cu_seqlens = [0]
@@ -228,7 +225,7 @@ class FlashCausalLMBatch(Batch):
             # Slice from past
             past_key_values.append(
-                self.past_key_values[:, self.cu_seqlens[idx]: self.cu_seqlens[idx + 1]]
+                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
             )

             all_input_ids.append(self.all_input_ids[idx])
@@ -242,7 +239,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += request_input_length
             max_tokens += request_input_length + (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )

         if single_request:
@@ -395,7 +392,7 @@ class FlashCausalLMBatch(Batch):
             end_index = cumulative_batch_size + len(batch)

             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor

             cumulative_batch_size += len(batch)
@@ -481,14 +478,14 @@ class FlashCausalLM(Model):
         )

     def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
         cu_seqlens_q: Optional[torch.Tensor],
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -503,7 +500,7 @@ class FlashCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         single_request = len(batch) == 1
@@ -512,7 +509,7 @@ class FlashCausalLM(Model):
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
                 batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
             )
         else:
             pre_allocate_past_size = None
@@ -613,9 +610,9 @@ class FlashCausalLM(Model):
         # For each member of the batch
         for i, (
             input_length,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -630,8 +627,8 @@ class FlashCausalLM(Model):
                 # Copy batch.input_ids to prefill_token_indices
                 if len(batch) > 1:
                     prefill_tokens_indices[
-                        start_index: end_index - 1
-                    ] = batch.input_ids[start_index + 1: end_index]
+                        start_index : end_index - 1
+                    ] = batch.input_ids[start_index + 1 : end_index]
                 else:
                     # Set prefill_tokens_indices to the correct slice
                     prefill_tokens_indices = batch.input_ids
@@ -717,7 +714,7 @@ class FlashCausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:]
+                    all_input_ids[-stopping_criteria.current_tokens :]
                 )
                 generated_text = GeneratedText(
                     output_text,
@@ -732,8 +729,8 @@ class FlashCausalLM(Model):
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index: end_index - 1
+                    start_index : end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
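
Several of the hunks above only reformat slice syntax, but the slices themselves encode the prefill-logprob bookkeeping: position j of a packed sequence scores the token at position j + 1, so input_ids[start_index + 1 : end_index] is copied into prefill_tokens_indices[start_index : end_index - 1], and each request's logprob list gets a leading NaN for the first prompt token, which has no preceding position. A minimal standalone sketch of that shift with toy tensors (illustrative only, not the actual FlashCausalLM code):

import math

import torch

# Toy packed batch: two requests of lengths 3 and 2, flattened into one tensor.
input_ids = torch.tensor([11, 12, 13, 21, 22])
cu_seqlens = [0, 3, 5]  # cumulative sequence lengths, one entry per boundary

prefill_tokens_indices = torch.zeros_like(input_ids)
for idx in range(len(cu_seqlens) - 1):
    start_index, end_index = cu_seqlens[idx], cu_seqlens[idx + 1]
    # Position j scores the token at position j + 1, hence the shift by one.
    prefill_tokens_indices[start_index : end_index - 1] = input_ids[
        start_index + 1 : end_index
    ]

# Per request, the first prompt token has no preceding position, hence the NaN.
fake_logprobs = [-0.5, -0.7, -0.2, -0.9, -0.1]  # stand-in values, one per position
start_index, end_index = cu_seqlens[0], cu_seqlens[1]
request_prefill_logprobs = [float("nan")] + fake_logprobs[start_index : end_index - 1]
assert math.isnan(request_prefill_logprobs[0]) and len(request_prefill_logprobs) == 3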

View File

@@ -14,8 +14,9 @@ from text_generation_server.utils.tokens import (
     StoppingCriteria,
     StopSequenceCriteria,
     FinishReason,
+    Sampling,
+    Greedy,
 )
-from text_generation_server.utils.logits_process import Sampling, Greedy

 __all__ = [
     "convert_file",

View File

@@ -14,25 +14,6 @@ from transformers import (
 )


-class Sampling:
-    def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator(device)
-        self.generator.manual_seed(seed)
-        self.seed = seed
-
-    def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, -1)
-        # Avoid GPU<->CPU sync done by torch multinomial
-        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
-        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
-        return probs.div_(q).argmax()
-
-
-class Greedy:
-    def __call__(self, logits):
-        return logits.argmax(dim=-1)
-
-
 class StaticWarper:
     def __init__(
         self,
@@ -329,46 +310,3 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
     def filter(self, indices):
         self.mass = self.mass[indices]
         return self
-
-
-class HeterogeneousSampling:
-    r"""
-    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
-    """
-
-    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
-        self.seeds = seeds
-
-        self.greedy_indices = []
-        self.sampling_mapping = {}
-        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
-            if sample:
-                self.sampling_mapping[i] = Sampling(seed, device)
-            else:
-                self.greedy_indices.append(i)
-
-        self.greedy = Greedy()
-
-    def __call__(self, logits):
-        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
-        if self.greedy_indices:
-            out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
-
-        for i, sampling in self.sampling_mapping.items():
-            out[i] = sampling(logits[i])
-
-        return out
-
-    def filter(self, indices):
-        new_greedy_indices = []
-        new_sampling_mapping = {}
-        for i, idx in enumerate(indices):
-            if idx in self.sampling_mapping:
-                new_sampling_mapping[i] = self.sampling_mapping[idx]
-            else:
-                new_greedy_indices.append(i)
-
-        self.greedy_indices = new_greedy_indices
-        self.sampling_mapping = new_sampling_mapping
-        return self

View File

@@ -3,31 +3,36 @@ import torch

 from transformers import (
     RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase, LogitsProcessorList,
+    PreTrainedTokenizerBase,
+    LogitsProcessorList,
 )

 from typing import List, Tuple, Optional

 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from text_generation_server.utils import Sampling, Greedy
-from text_generation_server.utils.logits_process import static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, \
-    HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, \
-    HeterogeneousTypicalLogitsWarper, HeterogeneousSampling
+from text_generation_server.utils.logits_process import (
+    static_warper,
+    HeterogeneousRepetitionPenaltyLogitsProcessor,
+    HeterogeneousTemperatureLogitsWarper,
+    HeterogeneousTopKLogitsWarper,
+    HeterogeneousTopPLogitsWarper,
+    HeterogeneousTypicalLogitsWarper,
+)


 class NextTokenChooser:
     def __init__(
         self,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -39,10 +44,10 @@ class NextTokenChooser:
         )

         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
             self.static_warper = static_warper(
@@ -71,9 +76,9 @@ class NextTokenChooser:
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -101,11 +106,11 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
         eos_token_id: int,
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -131,9 +136,9 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -148,17 +153,17 @@ class StoppingCriteria:
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
         dtype: torch.dtype,
         device: torch.device,
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
     ):
         warpers = LogitsProcessorList()
@@ -223,10 +228,10 @@ class HeterogeneousNextTokenChooser:
     @classmethod
     def from_pb(
         cls,
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -240,3 +245,63 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
         )
+
+
+class Sampling:
+    def __init__(self, seed: int, device: str = "cpu"):
+        self.generator = torch.Generator(device)
+        self.generator.manual_seed(seed)
+        self.seed = seed
+
+    def __call__(self, logits):
+        probs = torch.nn.functional.softmax(logits, -1)
+        # Avoid GPU<->CPU sync done by torch multinomial
+        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+        return probs.div_(q).argmax()
+
+
+class Greedy:
+    def __call__(self, logits):
+        return logits.argmax(dim=-1)
+
+
+class HeterogeneousSampling:
+    r"""
+    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
+    """
+
+    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
+        self.seeds = seeds
+
+        self.greedy_indices = []
+        self.sampling_mapping = {}
+        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
+            if sample:
+                self.sampling_mapping[i] = Sampling(seed, device)
+            else:
+                self.greedy_indices.append(i)
+
+        self.greedy = Greedy()
+
+    def __call__(self, logits):
+        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
+        if self.greedy_indices:
+            out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
+
+        for i, sampling in self.sampling_mapping.items():
+            out[i] = sampling(logits[i])
+
+        return out
+
+    def filter(self, indices):
+        new_greedy_indices = []
+        new_sampling_mapping = {}
+        for i, idx in enumerate(indices):
+            if idx in self.sampling_mapping:
+                new_sampling_mapping[i] = self.sampling_mapping[idx]
+            else:
+                new_greedy_indices.append(i)
+
+        self.greedy_indices = new_greedy_indices
+        self.sampling_mapping = new_sampling_mapping
+        return self
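
The relocated Sampling class sidesteps torch.multinomial (and the GPU<->CPU sync its comment links to) with an exponential race: dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability proportional to probs[i], i.e. it is a draw from the same categorical distribution, and HeterogeneousSampling simply applies this per row for sampled requests and a plain argmax for greedy ones. A small self-contained check of that equivalence, assuming only PyTorch (illustrative, not part of the commit):

import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
probs = torch.nn.functional.softmax(logits, -1)

counts = torch.zeros_like(probs)
for _ in range(20_000):
    # Exponential race: argmax(probs / q) with q ~ Exp(1) is a Categorical(probs) draw.
    q = torch.empty_like(probs).exponential_(1)
    counts[probs.div(q).argmax()] += 1

print(probs)                  # target distribution
print(counts / counts.sum())  # empirical frequencies, close to probs for large N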