Mirror of https://github.com/huggingface/text-generation-inference.git
filter
parent 476d8fc379
commit 7a70928b06
@@ -14,7 +14,7 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import StoppingCriteria
 
 tracer = trace.get_tracer(__name__)
 
@@ -48,6 +48,8 @@ class VectorizedCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
+    kv_cache_seq_dim:int=2
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -64,7 +66,6 @@ class VectorizedCausalLMBatch(Batch):
         device: torch.device,
     ) -> "VectorizedCausalLMBatch":
         inputs = []
-        next_token_choosers = []
         stopping_criterias = []
         offsets = []
         token_offsets = []
@@ -75,14 +76,10 @@ class VectorizedCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            next_token_chooser=NextTokenChooser.from_pb(r.parameters, device)
-            # TODO: Implement
-            assert len(next_token_chooser.warpers)==0
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             offsets.append(None)
             token_offsets.append(None)
-            next_token_choosers.append(next_token_chooser)
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -134,7 +131,7 @@ class VectorizedCausalLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
-            next_token_chooser=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             max_tokens=max_tokens,
@@ -142,7 +139,52 @@ class VectorizedCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]:
-        raise NotImplementedError()
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        self.requests = requests
+        keep_indices = [self.requests_idx_mapping[r.id] for r in self.requests]
+
+        # New values after filtering
+        self.requests_idx_mapping={r.id:i for i, r in enumerate(self.requests)}
+        self.input_lengths=[self.input_lengths[i] for i in keep_indices]
+        self.offsets = [self.offsets[i] for i in keep_indices]
+        self.token_offsets = [self.token_offsets[i] for i in keep_indices]
+        self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
+        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
+        remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
+        self.padding_right_offset=max(remaining_decode_tokens)
+
+        # Select the remaining indices and remove unnecessary padding
+        max_input_length=max(self.input_lengths)
+        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+self.padding_right_offset)
+        self.max_input_length=max_input_length
+        self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
+
+        self.input_ids = self.input_ids[keep_indices,sequence_slice]
+        self.position_ids = self.position_ids[keep_indices,sequence_slice]
+        self.attention_mask = self.attention_mask[keep_indices,sequence_slice]
+
+        tensors_to_update = []
+        if self.past_key_values is not None:
+            if not isinstance(self.past_key_values,(list, tuple)):
+                raise NotImplementedError(f"Unsupported kv cache type: {type(self.past_key_values)}")
+            for layer_kv in self.past_key_values:
+                if isinstance(layer_kv, torch.Tensor):
+                    tensors_to_update.append(layer_kv)
+                elif isinstance(layer_kv,(list, tuple)):
+                    tensors_to_update.extend(layer_kv)
+                else:
+                    raise NotImplementedError(f"Unsupported layer kv cache type: {type(layer_kv)}")
+
+        kv_cache_slice=[keep_indices, *(slice(None) for _ in range(1, self.kv_cache_seq_dim)), sequence_slice]
+        for tensor in tensors_to_update:
+            # Update tensors in-place to allow incremental garbage collection
+            tensor.data=tensor[kv_cache_slice]
+
+        return self
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
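For orientation, a minimal standalone sketch of the KV-cache slicing pattern the new filter() relies on. The shapes and values are hypothetical, and it assumes a per-layer cache laid out as (batch, num_heads, seq_length, head_dim), which is what kv_cache_seq_dim=2 implies; the sketch indexes with a tuple.

import torch

# Hypothetical per-layer KV cache: (batch, num_heads, seq_length, head_dim)
kv = torch.randn(4, 8, 10, 64)
keep_indices = [0, 2]          # surviving requests
sequence_slice = slice(3, 10)  # trim padding that is no longer needed
kv_cache_seq_dim = 2

# Index the batch dim, leave the head dim untouched, slice the sequence dim
kv_cache_slice = (
    keep_indices,
    *(slice(None) for _ in range(1, kv_cache_seq_dim)),
    sequence_slice,
)
kv.data = kv[kv_cache_slice]   # in-place update, mirroring the loop in filter()
print(kv.shape)                # torch.Size([2, 8, 7, 64])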
@@ -157,74 +199,76 @@ class VectorizedNextTokenChooser:
     def __init__(
         self,
         batch_size:int,
-        watermark=None,
-        temperature=None,
-        repetition_penalty=None,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=None,
-        seed:int=0,
-        device="cpu",
+        watermark:Optional[List[Optional[bool]]]=None,
+        temperature:Optional[List[Optional[float]]]=None,
+        repetition_penalty:Optional[List[Optional[float]]]=None,
+        top_k:Optional[List[Optional[int]]]=None,
+        top_p:Optional[List[Optional[float]]]=None,
+        typical_p:Optional[List[Optional[float]]]=None,
+        do_sample:Optional[List[Optional[bool]]]=None,
+        seeds:Optional[List[Optional[int]]]=None,
+        device:torch.device="cpu",
     ):
         self.batch_size=batch_size
         self.filter_value = -float("Inf")
         self.device=device
 
-        do_sample=self._standardize(do_sample, False)
+        # TODO: Seeds are ignored
+        self.seeds=self._standardize(seeds, 0)
+        self.do_sample=self._standardize(do_sample, False)
 
-        watermark=self._standardize(watermark, False)
-        if any(watermark):
+        self.watermark=self._standardize(watermark, False)
+        if any(self.watermark):
             raise NotImplementedError("Watermarking not implemented")
 
-        repetition_penalty=self._standardize(repetition_penalty, 1.0)
-        if any([x!=1.0 for x in repetition_penalty]):
-            self.repetition_penalty=torch.tensor(repetition_penalty, dtype=torch.float32, device=device).unsqueeze(1)
+        self.repetition_penalty=self._standardize(repetition_penalty, 1.0)
+        if any([x!=1.0 for x in self.repetition_penalty]):
+            self.repetition_penalty_t=torch.tensor(self.repetition_penalty, dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.repetition_penalty=None
+            self.repetition_penalty_t=None
 
-        temperature=self._standardize(temperature, 1.0)
-        if any([x!=1.0 for x in temperature]):
-            do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
-            self.temperature=torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
+        self.temperature=self._standardize(temperature, 1.0)
+        if any([x!=1.0 for x in self.temperature]):
+            self.do_sample=[sample or x!=1.0 for x, sample in zip(self.temperature, self.do_sample)]
+            self.temperature_t=torch.tensor(self.temperature, dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.temperature=None
+            self.temperature_t=None
 
-        top_k=self._standardize(top_k, 0)
+        self.top_k=self._standardize(top_k, 0)
         n_top_k=sum([x!=0 for x in top_k])
         if n_top_k>0:
-            do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
-            self.max_top_k=max(top_k)
-            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
+            self.do_sample=[sample or x!=0 for x, sample in zip(self.top_k, self.do_sample)]
+            self.max_top_k=max(self.top_k)
+            self.top_k_t=torch.tensor([max(x-1,0) for x in self.top_k], dtype=torch.int64, device=self.device).unsqueeze(1)
             if n_top_k<self.batch_size:
-                self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device)
+                self.top_k_mask=torch.tensor([x==0 for x in self.top_k], dtype=torch.bool, device=self.device)
             else:
                 self.top_k_mask=None
         else:
             self.max_top_k=None
-            self.top_k=None
+            self.top_k_t=None
             self.top_k_mask=None
 
-        top_p=self._standardize(top_p, 1.0)
-        if any([x<1.0 for x in top_p]):
-            do_sample=[sample or x<1.0 for x, sample in zip(temperature, top_p)]
-            self.top_p_inv=torch.tensor([1.0-x for x in top_p], dtype=torch.float32, device=device).unsqueeze(1)
+        self.top_p=self._standardize(top_p, 1.0)
+        if any([x<1.0 for x in self.top_p]):
+            self.do_sample=[sample or x<1.0 for x, sample in zip(self.top_p, self.do_sample)]
+            self.top_p_t=torch.tensor([1.0-x for x in self.top_p], dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.top_p_inv=None
+            self.top_p_t=None
 
-        typical_p=self._standardize(typical_p, 1.0)
-        if any([x<1.0 for x in typical_p]):
-            do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
-            self.typical_p=torch.tensor(typical_p, dtype=torch.float32, device=device).unsqueeze(1)
+        self.typical_p=self._standardize(typical_p, 1.0)
+        if any([x<1.0 for x in self.typical_p]):
+            self.do_sample=[sample or x<1.0 for x, sample in zip(self.typical_p, self.do_sample)]
+            self.typical_p_t=torch.tensor(self.typical_p, dtype=torch.float32, device=self.device).unsqueeze(1)
         else:
-            self.typical_p=None
+            self.typical_p_t=None
 
-        num_do_sample=sum(do_sample)
-        self.do_sample = num_do_sample>0
-        if self.do_sample and num_do_sample<self.batch_size:
+        self.num_do_sample=sum(self.do_sample)
+        if 0<self.num_do_sample<self.batch_size:
             # Mixed greedy and probabilistic sampling. Compute both and pick the right one.
-            self.do_sample_v=torch.tensor(do_sample, dtype=torch.bool, device=device)
+            self.do_sample_t=torch.tensor(self.do_sample, dtype=torch.bool, device=self.device)
         else:
-            self.do_sample_v=None
+            self.do_sample_t=None
 
     def _standardize(self, values, default):
         if isinstance(values, list):
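The rewritten __init__ applies one pattern per parameter: keep the standardized per-request Python list on self (so the new filter() can re-index it) and only materialize a *_t tensor when some request deviates from the default. A minimal sketch of that pattern follows; the standardize helper below is a hypothetical stand-in for _standardize, whose body lies outside this hunk.

from typing import List, Optional
import torch

def standardize(values: Optional[List], default, batch_size: int) -> List:
    # Hypothetical stand-in for VectorizedNextTokenChooser._standardize:
    # broadcast missing parameters to one concrete value per request.
    if values is None:
        return [default] * batch_size
    return [default if v is None else v for v in values]

batch_size = 3
temperature = standardize([0.7, None, 1.0], 1.0, batch_size)  # raw list kept for filter()
if any(x != 1.0 for x in temperature):
    temperature_t = torch.tensor(temperature, dtype=torch.float32).unsqueeze(1)
else:
    temperature_t = None  # no request needs the warper, so __call__ can skip it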
@@ -241,22 +285,22 @@ class VectorizedNextTokenChooser:
         # Only process the last token
         scores=scores[:, -1, :]
 
-        if self.repetition_penalty is not None:
+        if self.repetition_penalty_t is not None:
             score = torch.gather(scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-            score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
+            score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
             scores.scatter_(1, input_ids, score)
 
-        if self.temperature is not None:
-            scores.div_(self.temperature)
+        if self.temperature_t is not None:
+            scores.div_(self.temperature_t)
 
-        if self.top_k is not None:
+        if self.top_k_t is not None:
             if scores.size(-1)>self.max_top_k: # Safety check
                 max_top_k=scores.size(-1)
-                top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed.
+                top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
             else:
                 max_top_k=self.max_top_k
-                top_k=self.top_k
+                top_k=self.top_k_t
             kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
             if self.top_k_mask is not None:
                 kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
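A small self-contained illustration of the per-row top-k cutoff used above, with made-up values: top_k_t stores k-1 per row, so gathering from the row-sorted top scores yields each row's own k-th best score as its threshold.

import torch

scores = torch.tensor([[0.1, 0.9, 0.3, 0.5],
                       [0.8, 0.2, 0.7, 0.4]])
top_k = [2, 3]                                  # per-request k
max_top_k = max(top_k)
top_k_t = torch.tensor([k - 1 for k in top_k]).unsqueeze(1)

# k-th largest score per row, taken from the row-wise top-max_top_k values
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k_t)
indices_to_remove = scores < kth_scores
print(scores.masked_fill(indices_to_remove, -float("inf")))
# row 0 keeps its top 2 scores, row 1 keeps its top 3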
@@ -264,17 +308,17 @@ class VectorizedNextTokenChooser:
             indices_to_remove = scores < kth_scores
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
-        if self.top_p_inv is not None:
+        if self.top_p_t is not None:
             # TODO: Merge with top_k
             sorted_logits, sorted_indices = torch.sort(scores, descending=True)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
-            sorted_indices_to_remove = cumulative_probs <= self.top_p_inv
+            sorted_indices_to_remove = cumulative_probs <= self.top_p_t
             # scatter sorted tensors to original indexing
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
-        if self.typical_p is not None:
+        if self.typical_p_t is not None:
             # calculate entropy
             normalized = torch.nn.functional.log_softmax(scores, dim=-1)
             p = torch.exp(normalized)
@@ -285,7 +329,7 @@ class VectorizedNextTokenChooser:
             sorted_logits = scores.gather(-1, sorted_indices)
             cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
             # Remove tokens with cumulative mass above the threshold
-            last_ind = (cumulative_probs < self.typical_p).sum(dim=1)
+            last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
             last_ind[last_ind < 0] = 0
             sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
@@ -294,11 +338,11 @@ class VectorizedNextTokenChooser:
         # Compute logprobs
         logprobs = torch.log_softmax(scores, dim=-1)
 
-        if self.do_sample:
+        if self.num_do_sample:
             probs = torch.nn.functional.softmax(scores, -1)
             next_token_ids = torch.multinomial(probs, num_samples=1)
-            if self.do_sample_v is not None:
-                next_token_ids=torch.where(self.do_sample_v, next_token_ids,torch.argmax(scores, dim=-1))
+            if self.do_sample_t is not None:
+                next_token_ids=torch.where(self.do_sample_t, next_token_ids,torch.argmax(scores, dim=-1))
         else:
             next_token_ids = torch.argmax(scores, dim=-1)
 
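A compact sketch of the mixed greedy/sampling selection that do_sample_t enables, with toy logits; rows flagged True are sampled, the rest fall back to argmax.

import torch

torch.manual_seed(0)
scores = torch.tensor([[2.0, 0.5, 0.1],
                       [0.1, 0.2, 3.0]])
do_sample_t = torch.tensor([True, False])  # row 0 samples, row 1 is greedy

probs = torch.nn.functional.softmax(scores, -1)
sampled = torch.multinomial(probs, num_samples=1).squeeze(1)
greedy = torch.argmax(scores, dim=-1)
next_token_ids = torch.where(do_sample_t, sampled, greedy)
print(next_token_ids)  # row 1 is always token 2; row 0 depends on the draw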
@@ -312,6 +356,7 @@ class VectorizedNextTokenChooser:
     ) -> "VectorizedNextTokenChooser":
+        # TODO: Seeds are ignored
         return VectorizedNextTokenChooser(
             batch_size=len(pb),
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
@@ -319,10 +364,26 @@ class VectorizedNextTokenChooser:
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],
             do_sample=[pb_.do_sample for pb_ in pb],
-            seed=0,
+            seeds=[pb_.seed for pb_ in pb],
             device=device,
         )
 
+    def filter(self, keep_indices: List[int]) -> "VectorizedNextTokenChooser":
+        return VectorizedNextTokenChooser(
+            batch_size=len(keep_indices),
+            watermark=[self.watermark[i] for i in keep_indices],
+            temperature=[self.temperature[i] for i in keep_indices],
+            repetition_penalty=[self.repetition_penalty[i] for i in keep_indices],
+            top_k=[self.top_k[i] for i in keep_indices],
+            top_p=[self.top_p[i] for i in keep_indices],
+            typical_p=[self.typical_p[i] for i in keep_indices],
+            do_sample=[self.do_sample[i] for i in keep_indices],
+            seeds=[self.seeds[i] for i in keep_indices],
+            device=self.device,
+        )
+
+
+
 
 
 class VectorizedCausalLM(Model):
     def __init__(
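A usage sketch of the new VectorizedNextTokenChooser.filter(); it assumes the class is importable from the module this diff touches and that _standardize broadcasts omitted parameters to per-request defaults. The parameter values are invented.

# Hypothetical: a batch of 3 requests, request 1 has finished generating.
chooser = VectorizedNextTokenChooser(
    batch_size=3,
    temperature=[0.7, 1.0, 0.9],
    top_p=[0.95, 1.0, 0.8],
    do_sample=[True, False, True],
    seeds=[1, 2, 3],
)
keep_indices = [0, 2]
chooser = chooser.filter(keep_indices)  # rebuilds lists and tensors for the surviving rows
assert chooser.batch_size == 2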