Joel Lamy-Poirier 2023-05-05 18:48:57 -04:00
parent 0e648a71f9
commit 87b5f03958
3 changed files with 315 additions and 169 deletions

View File

@@ -114,7 +114,9 @@ def get_model(
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=True)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=True
)
model_type = config.model_type
if model_type == "bloom":
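For context on this hunk, a minimal sketch of the config-based dispatch it reformats, assuming only the public transformers AutoConfig API; the returned names are illustrative placeholders, not the server's actual model classes.

# Minimal sketch of dispatching on config.model_type; the registry below is
# an illustrative assumption, not the server's real get_model logic.
from transformers import AutoConfig

def pick_model_family(model_id, revision=None):
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=True
    )
    if config.model_type == "bloom":
        return "bloom"
    if config.model_type == "gpt_bigcode":
        return "santacoder"
    return "causal_lm"  # generic fallback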

View File

@@ -18,7 +18,9 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria
from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser
from text_generation_server.utils.tokens_heterogeneous import (
HeterogeneousNextTokenChooser,
)
tracer = trace.get_tracer(__name__)
@@ -32,7 +34,9 @@ class VectorizedCausalLMBatch(Batch):
# Decoder values
attention_mask: torch.Tensor
position_ids: torch.Tensor
past_key_values: Optional[List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]]
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]
]
# All tokens
input_ids: torch.Tensor
@@ -79,10 +83,17 @@ class VectorizedCausalLMBatch(Batch):
requests_idx_mapping = {r.id: i for i, r in enumerate(pb.requests)}
# Parse batch
stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
max_new_tokens=(stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias)
stopping_criterias = [
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
for r in pb.requests
]
max_new_tokens = (
stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias
)
next_token_chooser= HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
[r.parameters for r in pb.requests], device
)
tokenized_inputs = tokenizer(
inputs,
@@ -112,7 +123,10 @@ class VectorizedCausalLMBatch(Batch):
max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
generate_stream=cls.generate_stream or any(stopping_criteria.stop_sequence_criterias for stopping_criteria in stopping_criterias)
generate_stream = cls.generate_stream or any(
stopping_criteria.stop_sequence_criterias
for stopping_criteria in stopping_criterias
)
return cls(
batch_id=pb.id,
@@ -133,7 +147,9 @@ class VectorizedCausalLMBatch(Batch):
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]:
def filter(
self, requests: List[generate_pb2.Request]
) -> Optional["VectorizedCausalLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
@@ -148,16 +164,26 @@ class VectorizedCausalLMBatch(Batch):
self.offsets = [self.offsets[i] for i in keep_indices]
self.token_offsets = [self.token_offsets[i] for i in keep_indices]
self.next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in self.requests], self.input_ids.device)
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
[r.parameters for r in self.requests], self.input_ids.device
)
self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
remaining_decode_tokens = [
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
for stopping_criteria in self.stopping_criterias
]
# Select the remaining indices and remove unnecessary padding
max_input_length = max(self.input_lengths)
sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+max(remaining_decode_tokens))
sequence_slice = slice(
self.max_input_length - max_input_length,
self.max_input_length + max(remaining_decode_tokens),
)
self.max_input_length = max_input_length
self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
self.max_tokens = len(self.requests) * self.max_input_length + sum(
remaining_decode_tokens
)
self.input_ids = self.input_ids[keep_indices, sequence_slice]
self.position_ids = self.position_ids[keep_indices, sequence_slice]
@@ -166,16 +192,24 @@ class VectorizedCausalLMBatch(Batch):
tensors_to_update = []
if self.past_key_values is not None:
if not isinstance(self.past_key_values, (list, tuple)):
raise NotImplementedError(f"Unsupported kv cache type: {type(self.past_key_values)}")
raise NotImplementedError(
f"Unsupported kv cache type: {type(self.past_key_values)}"
)
for layer_kv in self.past_key_values:
if isinstance(layer_kv, torch.Tensor):
tensors_to_update.append(layer_kv)
elif isinstance(layer_kv, (list, tuple)):
tensors_to_update.extend(layer_kv)
else:
raise NotImplementedError(f"Unsupported layer kv cache type: {type(layer_kv)}")
raise NotImplementedError(
f"Unsupported layer kv cache type: {type(layer_kv)}"
)
kv_cache_slice=[keep_indices, *(slice(None) for _ in range(1, self.kv_cache_seq_dim)), sequence_slice]
kv_cache_slice = [
keep_indices,
*(slice(None) for _ in range(1, self.kv_cache_seq_dim)),
sequence_slice,
]
for tensor in tensors_to_update:
# Update tensors in-place to allow incremental garbage collection
tensor.data = tensor[kv_cache_slice]
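For reference, a self-contained sketch of the in-place kv-cache filtering above; the shapes, kv_cache_seq_dim value, and index values are illustrative assumptions, and the index is built as a tuple here for explicit multidimensional indexing.

# Sketch: shrink a kv-cache tensor in place along the batch and sequence dims.
import torch

batch, heads, seq, head_dim = 4, 2, 10, 8
kv_cache_seq_dim = 2  # hypothetical sequence dimension for this layout
layer_kv = torch.randn(batch, heads, seq, head_dim)

keep_indices = [0, 2]          # requests kept after filtering
sequence_slice = slice(3, 10)  # padding range that is still needed

kv_cache_slice = (
    keep_indices,
    *(slice(None) for _ in range(1, kv_cache_seq_dim)),
    sequence_slice,
)
# Rebinding .data lets callers keep the same tensor object while the old,
# larger storage becomes eligible for garbage collection.
layer_kv.data = layer_kv[kv_cache_slice]
print(layer_kv.shape)  # torch.Size([2, 2, 7, 8])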
@@ -184,7 +218,9 @@ class VectorizedCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
def concatenate(
cls, batches: List["VectorizedCausalLMBatch"]
) -> "VectorizedCausalLMBatch":
if len(batches) == 0:
raise ValueError("Cannot concatenate empty list.")
requests = [request for batch in batches for request in batch.requests]
@@ -196,16 +232,34 @@ class VectorizedCausalLMBatch(Batch):
input_lengths = [length for batch in batches for length in batch.input_lengths]
offsets = [offset for batch in batches for offset in batch.offsets]
token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in requests], batches[0].input_ids.device)
stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
token_offsets = [
token_offset for batch in batches for token_offset in batch.token_offsets
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
[r.parameters for r in requests], batches[0].input_ids.device
)
stopping_criterias = [
stopping_criteria
for batch in batches
for stopping_criteria in batch.stopping_criterias
]
requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
requests_idx_mapping = {
k: v + start_index
for batch, start_index in zip(batches, start_indices)
for k, v in batch.requests_idx_mapping.items()
}
max_input_length = max(input_lengths)
left_indices = [max_input_length - batch.max_input_length for batch in batches]
input_shape=(batch_size, max_input_length + max(batch.input_ids.size(1)-batch.max_input_length for batch in batches))
input_shape = (
batch_size,
max_input_length
+ max(
batch.input_ids.size(1) - batch.max_input_length for batch in batches
),
)
device = batches[0].input_ids.device
# Allocate maximum attention_mask
@@ -217,21 +271,32 @@ class VectorizedCausalLMBatch(Batch):
# TODO : only needed for prefill
input_ids[:, :max_input_length].fill_(0)
for batch,start_index, end_index, left_index in zip(batches, start_indices, end_indices, left_indices):
attention_mask[start_index:end_index, left_index:max_input_length].copy_(batch.attention_mask[:, :batch.max_input_length])
input_ids[start_index:end_index, left_index:max_input_length].copy_(batch.input_ids[:, :batch.max_input_length])
for batch, start_index, end_index, left_index in zip(
batches, start_indices, end_indices, left_indices
):
attention_mask[start_index:end_index, left_index:max_input_length].copy_(
batch.attention_mask[:, : batch.max_input_length]
)
input_ids[start_index:end_index, left_index:max_input_length].copy_(
batch.input_ids[:, : batch.max_input_length]
)
position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids[:, :max_input_length].relu_()
max_tokens = sum(batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches)
max_tokens = sum(
batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
for batch in batches
)
kv_formats = None
for batch in batches:
if batch.past_key_values is None:
raise ValueError("Only concatenate prefilled batches")
if not isinstance(batch.past_key_values, (list, tuple)):
raise NotImplementedError(f"Unsupported kv cache type: {type(batch.past_key_values)}")
raise NotImplementedError(
f"Unsupported kv cache type: {type(batch.past_key_values)}"
)
if kv_formats is None:
num_layers = len(batch.past_key_values)
if num_layers == 0:
@@ -257,13 +322,30 @@ class VectorizedCausalLMBatch(Batch):
for j in range(1 if kv_format is None else kv_format):
tensors_to_merge = [batch.past_key_values[i] for batch in batches]
# Generally `max_input_length`, unless the model allocates more than needed.
right_indices=[left_index+tensor.size(kv_cache_seq_dim) for tensor, left_index in zip(tensors_to_merge, left_indices)]
right_indices = [
left_index + tensor.size(kv_cache_seq_dim)
for tensor, left_index in zip(tensors_to_merge, left_indices)
]
combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:])
combined_shape[kv_cache_seq_dim] = max(right_indices)
# Set to zero to avoid propagating nans in padded values.
kv_cache = torch.zeros(combined_shape, dtype=tensors_to_merge[0].dtype, device=device)
for tensor, start_index, end_index, left_index, right_index in zip(tensors_to_merge, start_indices, end_indices, left_indices, right_indices):
kv_cache[[slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), slice(left_index,right_index)]].copy_(tensor)
kv_cache = torch.zeros(
combined_shape, dtype=tensors_to_merge[0].dtype, device=device
)
for tensor, start_index, end_index, left_index, right_index in zip(
tensors_to_merge,
start_indices,
end_indices,
left_indices,
right_indices,
):
kv_cache[
[
slice(start_index, end_index),
*(slice(None) for _ in range(1, kv_cache_seq_dim)),
slice(left_index, right_index),
]
].copy_(tensor)
if kv_format is None:
past_key_values.append(kv_cache)
elif j == 0:
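As a standalone illustration of the left-padding alignment performed during the concatenation above, a sketch with made-up batch sizes; the real code applies the same start/left index arithmetic to the kv-cache tensors as well.

# Sketch: merge two left-padded batches into one zero-initialised buffer.
import torch

masks = [torch.ones(2, 5), torch.ones(3, 3)]  # per-batch attention masks
max_input_lengths = [5, 3]

batch_size = sum(m.size(0) for m in masks)
max_input_length = max(max_input_lengths)

start_indices = [0, 2]
end_indices = [2, 5]
# Shorter batches are shifted right so every row stays left-padded.
left_indices = [max_input_length - length for length in max_input_lengths]

attention_mask = torch.zeros(batch_size, max_input_length)
for mask, length, start, end, left in zip(
    masks, max_input_lengths, start_indices, end_indices, left_indices
):
    attention_mask[start:end, left:max_input_length].copy_(mask[:, :length])

# Positions count only real tokens; padding positions are clamped to 0.
position_ids = attention_mask.cumsum(-1).sub_(1).relu_()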
@@ -364,28 +446,45 @@ class VectorizedCausalLM(Model):
past_key_values=batch.past_key_values,
)
# TODO: Post-processing
next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
next_token_ids, logprobs = batch.next_token_chooser(
input_ids, outputs.logits, batch.details
)
if batch.generate_stream:
# TODO: self.decode_token, offsets?
next_token_texts = self.tokenizer.batch_decode(next_token_ids.tolist())
if batch.details:
token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
token_logprobs = (
logprobs[:, -1, :]
.gather(1, next_token_ids.unsqueeze(1))
.squeeze(1)
.tolist()
)
if query_length > 1:
prefill_token_ids = batch.input_ids[:, :key_length].tolist()
prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
prefill_logprobs = (
logprobs.gather(2, batch.input_ids[:, 1:key_length, None])
.squeeze(2)
.tolist()
)
prefill_tokens = []
for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
for prefill_token_ids_, prefill_logprobs_, input_length in zip(
prefill_token_ids, prefill_logprobs, batch.input_lengths
):
prefill_token_ids_ = prefill_token_ids_[-input_length:]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids_,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens.append(PrefillTokens(
prefill_token_ids_, [math.nan, *prefill_logprobs_], prefill_texts
))
prefill_tokens.append(
PrefillTokens(
prefill_token_ids_,
[math.nan, *prefill_logprobs_],
prefill_texts,
)
)
# Update batch
# TODO: Why do we need all input ids?
@@ -421,7 +520,6 @@ class VectorizedCausalLM(Model):
generated_text = None
next_batch = batch
generation = Generation(
batch.requests[i].id,
prefill_tokens[i] if batch.details and query_length > 1 else None,
@@ -435,4 +533,3 @@ class VectorizedCausalLM(Model):
generations.append(generation)
return generations, next_batch
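A standalone sketch of the gather-based logprob extraction reformatted above; shapes and ids are illustrative, and the prefill pairing (position i scores the token at position i+1) is why the code reports the first prefill logprob as math.nan.

# Sketch: per-token logprobs for the sampled next token and for prefill tokens.
import torch

batch, seq, vocab = 2, 4, 10
logprobs = torch.randn(batch, seq, vocab).log_softmax(-1)

# Logprob of the chosen next token at the last position.
next_token_ids = torch.randint(vocab, (batch,))
token_logprobs = (
    logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
)

# Prefill logprobs: logits at position i score the input token at position i+1,
# so the first input token has no logprob of its own.
input_ids = torch.randint(vocab, (batch, seq))
prefill_logprobs = (
    logprobs[:, :-1].gather(2, input_ids[:, 1:, None]).squeeze(2).tolist()
)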

View File

@@ -25,7 +25,9 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
"""
def __init__(self, penalty: List[float], device: torch.device):
self.penalty = torch.tensor(penalty, dtype=torch.float32, device=device).unsqueeze(1)
self.penalty = torch.tensor(
penalty, dtype=torch.float32, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
@@ -36,6 +38,7 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
scores.scatter_(1, input_ids, score)
return scores
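The middle of this processor is elided from the hunk; assuming it applies the usual sign-aware rule from the scalar transformers processor, a self-contained sketch of the vectorised, per-row penalty looks like this (values are illustrative).

# Sketch: heterogeneous repetition penalty, one penalty value per batch row.
import torch

penalty = torch.tensor([1.0, 1.3], dtype=torch.float32).unsqueeze(1)
input_ids = torch.tensor([[2, 5, 5], [1, 4, 7]])
scores = torch.randn(2, 9)

score = torch.gather(scores, 1, input_ids)
# Assumed sign-aware rule: push seen tokens' scores away from being chosen again.
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)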
class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
@@ -48,12 +51,15 @@ class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
"""
def __init__(self, temperature: List[float], device: torch.device):
self.temperature = torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
self.temperature = torch.tensor(
temperature, dtype=torch.float32, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores.div_(self.temperature)
return scores
class HeterogeneousTopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p filtering, i.e. restricting to the smallest set of top tokens whose cumulative probability reaches top_p.
@@ -70,8 +76,16 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(1)
def __init__(
self,
top_p: List[float],
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(
1
)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@@ -86,10 +100,13 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
scores.masked_fill_(indices_to_remove, self.filter_value)
return scores
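The cumulative-probability computation is not shown in this hunk; assuming it follows the ascending-sort formulation of the scalar transformers warper, a per-row sketch (values illustrative):

# Sketch: heterogeneous top-p (nucleus) filtering with one top_p per row.
import math
import torch

top_p = torch.tensor([0.9, 0.5]).unsqueeze(1)
filter_value = -math.inf
min_tokens_to_keep = 1

scores = torch.randn(2, 10)
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

# Drop tokens whose cumulative mass falls outside each row's nucleus.
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Always keep at least min_tokens_to_keep highest-probability tokens.
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0

indices_to_remove = sorted_indices_to_remove.scatter(
    1, sorted_indices, sorted_indices_to_remove
)
scores = scores.masked_fill(indices_to_remove, filter_value)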
class HeterogeneousTopKLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
@@ -105,9 +122,19 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: List[int], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
def __init__(
self,
top_k: List[int],
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.max_top_k = max(top_k)
self.top_k = torch.tensor([max(x - 1, min_tokens_to_keep-1) for x in top_k], dtype=torch.int64,device=device).unsqueeze(1)
self.top_k = torch.tensor(
[max(x - 1, min_tokens_to_keep - 1) for x in top_k],
dtype=torch.int64,
device=device,
).unsqueeze(1)
zeros = [x == 0 for x in top_k]
if any(zeros):
self.top_k_mask = torch.tensor(zeros, dtype=torch.bool, device=device)
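This warper's __call__ is not part of the diff; assuming it thresholds each row against its k-th highest score (which is what the stored k - 1 indices suggest), a sketch:

# Sketch: heterogeneous top-k with a single batched topk call; top_k == 0
# conventionally disables filtering for that row.
import math
import torch

top_k = [3, 0]
min_tokens_to_keep = 1
filter_value = -math.inf

max_top_k = max(top_k)
# k - 1 is a 0-based index into each row's sorted top scores.
top_k_index = torch.tensor(
    [max(x - 1, min_tokens_to_keep - 1) for x in top_k], dtype=torch.int64
).unsqueeze(1)
top_k_disabled = torch.tensor([x == 0 for x in top_k], dtype=torch.bool).unsqueeze(1)

scores = torch.randn(2, 10)
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k_index)
# Disabled rows keep everything: lower their threshold to -inf.
kth_scores.masked_fill_(top_k_disabled, filter_value)
scores.masked_fill_(scores < kth_scores, filter_value)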
@@ -147,7 +174,13 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, mass: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
def __init__(
self,
mass: List[float],
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.filter_value = filter_value
self.mass = torch.tensor(mass, dtype=torch.float32, device=device).unsqueeze(1)
self.min_tokens_to_keep = min_tokens_to_keep
@@ -167,11 +200,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return scores
@@ -181,6 +218,7 @@ class HeterogeneousSampling:
r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
"""
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
self.seeds = seeds
self.greedy = Greedy()
@@ -191,6 +229,7 @@ class HeterogeneousSampling:
def __call__(self, logits):
return torch.where(self.do_sample, self.sampling(logits), self.greedy(logits))
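A minimal sketch of the compute-both-then-select pattern described in the docstring above; argmax and multinomial stand in for the server's Greedy and Sampling helpers.

# Sketch: mixed greedy / stochastic next-token selection with torch.where.
import torch

do_sample = torch.tensor([False, True])
logits = torch.randn(2, 10)

greedy_ids = logits.argmax(dim=-1)
sampled_ids = torch.multinomial(logits.softmax(-1), num_samples=1).squeeze(1)
# Both choices are computed for every row; torch.where picks per sample.
next_token_ids = torch.where(do_sample, sampled_ids, greedy_ids)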
class HeterogeneousNextTokenChooser:
def __init__(
self,
@@ -218,11 +257,17 @@ class HeterogeneousNextTokenChooser:
repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0)
if any([x != 1.0 for x in repetition_penalty]):
warpers.append(HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, device))
warpers.append(
HeterogeneousRepetitionPenaltyLogitsProcessor(
repetition_penalty, device
)
)
temperature = self._standardize(temperature, batch_size, 1.0)
if any([x != 1.0 for x in temperature]):
do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, device))
top_k = self._standardize(top_k, batch_size, 0)
@@ -264,7 +309,9 @@ class HeterogeneousNextTokenChooser:
values[i] = default
return values
def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
def __call__(
self, input_ids: torch.Tensor, scores: torch.Tensor, return_logprobs: bool
):
last_token_scores = self.warpers(input_ids, scores[:, -1, :])
next_token_ids = self.choice(last_token_scores)
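The diff ends here; for context, a sketch of applying a chain of warpers to the last position's scores before choosing the next token, using transformers' scalar warpers as stand-ins for the heterogeneous ones defined above.

# Sketch: chain warpers over the last position's scores, then pick a token.
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

warpers = LogitsProcessorList(
    [TemperatureLogitsWarper(0.7), TopKLogitsWarper(top_k=50)]
)

input_ids = torch.randint(100, (2, 4))
scores = torch.randn(2, 4, 100)  # [batch, sequence, vocab]

# Only the final position feeds next-token selection.
last_token_scores = warpers(input_ids, scores[:, -1, :])
next_token_ids = last_token_scores.softmax(-1).multinomial(1).squeeze(1)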