mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
fix imports
parent f9e3a3bb91
commit e7826855a3
@@ -18,10 +18,7 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import (
-    StoppingCriteria,
-    HeterogeneousNextTokenChooser
-)
+from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 
 tracer = trace.get_tracer(__name__)
 
@@ -71,11 +68,11 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         position_ids = []
         cu_seqlens = [0]
@@ -228,7 +225,7 @@ class FlashCausalLMBatch(Batch):
 
             # Slice from past
             past_key_values.append(
-                self.past_key_values[:, self.cu_seqlens[idx]: self.cu_seqlens[idx + 1]]
+                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
             )
 
             all_input_ids.append(self.all_input_ids[idx])
@@ -242,7 +239,7 @@ class FlashCausalLMBatch(Batch):
 
             cumulative_length += request_input_length
             max_tokens += request_input_length + (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
 
         if single_request:
@@ -395,7 +392,7 @@ class FlashCausalLMBatch(Batch):
             end_index = cumulative_batch_size + len(batch)
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor
 
             cumulative_batch_size += len(batch)
@@ -481,14 +478,14 @@ class FlashCausalLM(Model):
         )
 
     def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
         cu_seqlens_q: Optional[torch.Tensor],
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -503,7 +500,7 @@ class FlashCausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         single_request = len(batch) == 1
@@ -512,7 +509,7 @@ class FlashCausalLM(Model):
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
                 batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
             )
         else:
             pre_allocate_past_size = None
@@ -613,9 +610,9 @@ class FlashCausalLM(Model):
 
         # For each member of the batch
         for i, (
             input_length,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -630,8 +627,8 @@ class FlashCausalLM(Model):
             # Copy batch.input_ids to prefill_token_indices
             if len(batch) > 1:
                 prefill_tokens_indices[
-                    start_index: end_index - 1
-                ] = batch.input_ids[start_index + 1: end_index]
+                    start_index : end_index - 1
+                ] = batch.input_ids[start_index + 1 : end_index]
             else:
                 # Set prefill_tokens_indices to the correct slice
                 prefill_tokens_indices = batch.input_ids
@@ -717,7 +714,7 @@ class FlashCausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:]
+                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                 generated_text = GeneratedText(
                     output_text,
@@ -732,8 +729,8 @@ class FlashCausalLM(Model):
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index: end_index - 1
+                    start_index : end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
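The substantive change in this file is the collapsed `from text_generation_server.utils import ...` line; the remaining hunks are formatting of slice expressions (`x[a : b]`). Those slices all index flattened, padding-free batches via cumulative sequence lengths. As a standalone illustration (made-up tensors and lengths, not code from the repository), this is how `cu_seqlens` picks out one request's span from a packed tensor, and how the prefill copy shifts ids by one position so each slot holds the token the corresponding logit should predict:

import torch

# Two requests of lengths 3 and 4, packed back to back with no padding.
cu_seqlens = torch.tensor([0, 3, 7])          # [0] + cumulative lengths
input_ids = torch.tensor([11, 12, 13, 21, 22, 23, 24])

# Slice one request out of the packed tensor, as in
# self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]].
idx = 1
request_ids = input_ids[cu_seqlens[idx] : cu_seqlens[idx + 1]]
print(request_ids)                            # tensor([21, 22, 23, 24])

# Prefill logprob indexing: position i of a request holds the id of token i + 1,
# mirroring prefill_tokens_indices[start_index : end_index - 1] =
#           batch.input_ids[start_index + 1 : end_index].
start_index, end_index = cu_seqlens[idx].item(), cu_seqlens[idx + 1].item()
prefill_tokens_indices = torch.zeros_like(input_ids)
prefill_tokens_indices[start_index : end_index - 1] = input_ids[start_index + 1 : end_index]
print(prefill_tokens_indices[start_index : end_index])   # tensor([22, 23, 24, 0])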
@@ -14,8 +14,9 @@ from text_generation_server.utils.tokens import (
     StoppingCriteria,
     StopSequenceCriteria,
     FinishReason,
+    Sampling,
+    Greedy,
 )
-from text_generation_server.utils.logits_process import Sampling, Greedy
 
 __all__ = [
     "convert_file",
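With this hunk, `Sampling` and `Greedy` are re-exported from `text_generation_server.utils.tokens` instead of `text_generation_server.utils.logits_process`. In the old layout, `tokens.py` imported `Sampling` and `Greedy` back from the package root, which itself imports `tokens`, a circular-import hazard; that is presumably what the commit title "fix imports" refers to. Callers of the package root are unaffected, as in this minimal sketch (a hypothetical caller, not code from the repository):

# The package root keeps exposing the same names; only their home module moved.
from text_generation_server.utils import Greedy, Sampling

greedy = Greedy()
sampler = Sampling(seed=0, device="cpu")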
@@ -14,25 +14,6 @@ from transformers import (
 )
 
 
-class Sampling:
-    def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator(device)
-        self.generator.manual_seed(seed)
-        self.seed = seed
-
-    def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, -1)
-        # Avoid GPU<->CPU sync done by torch multinomial
-        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
-        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
-        return probs.div_(q).argmax()
-
-
-class Greedy:
-    def __call__(self, logits):
-        return logits.argmax(dim=-1)
-
-
 class StaticWarper:
     def __init__(
         self,
@@ -329,46 +310,3 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
     def filter(self, indices):
         self.mass = self.mass[indices]
         return self
-
-
-class HeterogeneousSampling:
-    r"""
-    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
-    """
-
-    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
-        self.seeds = seeds
-
-        self.greedy_indices = []
-        self.sampling_mapping = {}
-        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
-            if sample:
-                self.sampling_mapping[i] = Sampling(seed, device)
-            else:
-                self.greedy_indices.append(i)
-
-        self.greedy = Greedy()
-
-    def __call__(self, logits):
-        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
-        if self.greedy_indices:
-            out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
-
-        for i, sampling in self.sampling_mapping.items():
-            out[i] = sampling(logits[i])
-        return out
-
-    def filter(self, indices):
-        new_greedy_indices = []
-        new_sampling_mapping = {}
-        for i, idx in enumerate(indices):
-            if idx in self.sampling_mapping:
-                new_sampling_mapping[i] = self.sampling_mapping[idx]
-            else:
-                new_greedy_indices.append(i)
-
-        self.greedy_indices = new_greedy_indices
-        self.sampling_mapping = new_sampling_mapping
-        return self
-
-
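The `Sampling.__call__` removed here (and re-added verbatim in `tokens.py` below) draws from the softmax distribution without `torch.multinomial`: with independent q_i ~ Exp(1), argmax_i p_i / q_i is distributed Categorical(p) (the exponential-race form of the Gumbel-max trick), which avoids the GPU<->CPU sync performed inside the linked PyTorch multinomial code. A self-contained sanity check of that equivalence (arbitrary logits and draw count, non-in-place division so the probabilities can be reused across draws; illustration only):

import torch

logits = torch.tensor([2.0, 0.5, 0.1, -1.0])
probs = torch.softmax(logits, dim=-1)

# Exponential-race sampling, the same trick as Sampling.__call__:
# argmax_i probs[i] / q[i] with q[i] ~ Exp(1) is a draw from Categorical(probs).
def sample_once(generator):
    q = torch.empty_like(probs).exponential_(1, generator=generator)
    return torch.argmax(probs / q)

gen = torch.Generator().manual_seed(42)
counts = torch.zeros_like(probs)
for _ in range(20_000):
    counts[sample_once(gen)] += 1

print(probs)                   # target distribution
print(counts / counts.sum())   # empirical frequencies, close to probs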
@@ -3,31 +3,36 @@ import torch
 
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase, LogitsProcessorList,
+    PreTrainedTokenizerBase,
+    LogitsProcessorList,
 )
 from typing import List, Tuple, Optional
 
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from text_generation_server.utils import Sampling, Greedy
-from text_generation_server.utils.logits_process import static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, \
-    HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, \
-    HeterogeneousTypicalLogitsWarper, HeterogeneousSampling
+from text_generation_server.utils.logits_process import (
+    static_warper,
+    HeterogeneousRepetitionPenaltyLogitsProcessor,
+    HeterogeneousTemperatureLogitsWarper,
+    HeterogeneousTopKLogitsWarper,
+    HeterogeneousTopPLogitsWarper,
+    HeterogeneousTypicalLogitsWarper,
+)
 
 
 class NextTokenChooser:
     def __init__(
         self,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -39,10 +44,10 @@ class NextTokenChooser:
         )
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
             self.static_warper = static_warper(
@@ -71,9 +76,9 @@ class NextTokenChooser:
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -101,11 +106,11 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
         self,
         eos_token_id: int,
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -131,9 +136,9 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -148,17 +153,17 @@ class StoppingCriteria:
 
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
         dtype: torch.dtype,
         device: torch.device,
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
     ):
         warpers = LogitsProcessorList()
 
@@ -223,10 +228,10 @@ class HeterogeneousNextTokenChooser:
 
     @classmethod
     def from_pb(
         cls,
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -240,3 +245,63 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
         )
+
+
+class Sampling:
+    def __init__(self, seed: int, device: str = "cpu"):
+        self.generator = torch.Generator(device)
+        self.generator.manual_seed(seed)
+        self.seed = seed
+
+    def __call__(self, logits):
+        probs = torch.nn.functional.softmax(logits, -1)
+        # Avoid GPU<->CPU sync done by torch multinomial
+        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+        return probs.div_(q).argmax()
+
+
+class Greedy:
+    def __call__(self, logits):
+        return logits.argmax(dim=-1)
+
+
+class HeterogeneousSampling:
+    r"""
+    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
+    """
+
+    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
+        self.seeds = seeds
+
+        self.greedy_indices = []
+        self.sampling_mapping = {}
+        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
+            if sample:
+                self.sampling_mapping[i] = Sampling(seed, device)
+            else:
+                self.greedy_indices.append(i)
+
+        self.greedy = Greedy()
+
+    def __call__(self, logits):
+        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
+        if self.greedy_indices:
+            out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
+
+        for i, sampling in self.sampling_mapping.items():
+            out[i] = sampling(logits[i])
+        return out
+
+    def filter(self, indices):
+        new_greedy_indices = []
+        new_sampling_mapping = {}
+        for i, idx in enumerate(indices):
+            if idx in self.sampling_mapping:
+                new_sampling_mapping[i] = self.sampling_mapping[idx]
+            else:
+                new_greedy_indices.append(i)
+
+        self.greedy_indices = new_greedy_indices
+        self.sampling_mapping = new_sampling_mapping
+        return self
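`HeterogeneousSampling`, moved into `tokens.py` here, applies greedy decoding to some rows of a batched logits tensor and independently seeded sampling to the others, and `filter` renumbers that per-row state when requests leave the batch. A usage sketch under those definitions (batch layout, vocabulary size and seeds are made up for illustration):

import torch

# Assuming Sampling, Greedy and HeterogeneousSampling as defined in the hunk above.
do_sample = [False, True, False]           # rows 0 and 2 greedy, row 1 sampled
seeds = [0, 1234, 0]
chooser = HeterogeneousSampling(do_sample, seeds, device=torch.device("cpu"))

logits = torch.randn(3, 32_000)            # one row of vocabulary logits per request
next_ids = chooser(logits)                 # shape [3], dtype int64

# Request 0 finished: keep only rows 1 and 2. filter() remaps the internal
# greedy indices / sampling mapping to the new batch layout.
chooser = chooser.filter([1, 2])
next_ids = chooser(logits[[1, 2]])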