Some work to get batching working.

Nicolas Patry 2023-11-30 18:59:16 +00:00
parent b0cb4fa9d0
commit bdbccb774c
6 changed files with 329 additions and 102 deletions

View File

@@ -0,0 +1,58 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -1.7587891,
"special": false,
"text": " failed"
},
{
"id": 363,
"logprob": -0.5175781,
"special": false,
"text": " for"
},
{
"id": 1404,
"logprob": 0.0,
"special": false,
"text": " user"
},
{
"id": 376,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 1688,
"logprob": -0.20422363,
"special": false,
"text": "test"
}
]
},
"generated_text": "Test request failed for user \"test"
}

View File

@@ -0,0 +1,88 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -2.0878906,
"special": false,
"text": " for"
},
{
"id": 278,
"logprob": -3.4121094,
"special": false,
"text": " the"
},
{
"id": 376,
"logprob": -3.8457031,
"special": false,
"text": " \""
},
{
"id": 2577,
"logprob": -3.5566406,
"special": false,
"text": "Get"
},
{
"id": 599,
"logprob": -3.4746094,
"special": false,
"text": " all"
},
{
"id": 4160,
"logprob": -3.2363281,
"special": false,
"text": " users"
},
{
"id": 29908,
"logprob": -0.49023438,
"special": false,
"text": "\""
},
{
"id": 16248,
"logprob": -1.2402344,
"special": false,
"text": " endpoint"
},
{
"id": 29889,
"logprob": -0.88134766,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.41870117,
"special": false,
"text": "\n"
}
]
},
"generated_text": " for the \"Get all users\" endpoint.\n"
}

View File

@@ -0,0 +1,58 @@
import pytest


@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
    await flash_medusa_handle.health(300)
    return flash_medusa_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 5
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
    responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
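
The load test above is what actually exercises batching: four identical requests are sent concurrently and every response must agree. The generate_load fixture lives in the integration-test conftest and is not part of this diff; the sketch below (editor's illustration, with an assumed name and signature) shows the kind of helper it is taken to be.

import asyncio


async def generate_load_sketch(client, prompt, max_new_tokens, n):
    # Fire n identical generations at once so the server has to batch them;
    # with greedy decoding every response should come back identical.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True)
        for _ in range(n)
    ]
    return await asyncio.gather(*futures)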

View File

@@ -450,26 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
speculative_ids: Optional[torch.Tensor]
) -> torch.Tensor:
if speculative_ids is not None:
speculative_length = speculative_ids.shape[1]
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
# Add an extra block just in case
block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
# Copy the block tables for all members
block_tables = block_tables.expand(new_length, -1).contiguous()
slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@@ -520,7 +501,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
speculative_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
@@ -531,7 +511,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots,
input_lengths,
max_s,
speculative_ids,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
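
These hunks delete the speculative expansion that previously lived inside FlashLlamaModel.forward, together with the speculative_ids argument; an equivalent batched expansion reappears in flash_causal_lm.py further down. As a standalone illustration (editor's sketch mirroring that batched form, not code taken from the diff), the reshaping turns B next-token ids plus a (B, S) block of speculative ids into a flat batch of B * (S + 1) rows with matching position ids.

import torch


def expand_for_speculation(input_ids, position_ids, speculative_ids):
    # input_ids: (B,), position_ids: (B,), speculative_ids: (B, S)
    B, S = speculative_ids.shape
    new_length = S + 1
    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
    arange = torch.arange(new_length, device=position_ids.device)
    new_position_ids = (position_ids.view(-1, 1).expand(B, new_length) + arange).view(-1)
    return new_input_ids, new_position_ids


ids = torch.tensor([10, 20])               # next input id of each request
pos = torch.tensor([5, 7])                 # current position of each request
spec = torch.tensor([[11, 12], [21, 22]])  # S=2 speculative ids per request
print(expand_for_speculation(ids, pos, spec))
# (tensor([10, 11, 12, 20, 21, 22]), tensor([5, 6, 7, 7, 8, 9]))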

View File

@@ -3,6 +3,7 @@ import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
from loguru import logger
import numpy as np
@@ -46,7 +47,6 @@ class FlashCausalLMBatch(Batch):
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
cu_seqlen_speculative: Optional[torch.Tensor]
# Paged Attention values
@@ -123,7 +123,6 @@ class FlashCausalLMBatch(Batch):
position_ids = []
speculative_ids = []
cu_seqlen_prefill = [0]
cu_seqlen_speculative = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
@@ -163,18 +162,9 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :]
# # TODO remove this
# # Scaffolding to speculate some ids
# speculate_ids = [1, 2]
# tokenized_input.extend([1, 2])
speculate_ids = []
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
@@ -186,7 +176,6 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
next_token_chooser_parameters.append(r.parameters)
@@ -199,7 +188,8 @@ class FlashCausalLMBatch(Batch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
speculative_length = 2
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -268,9 +258,6 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
cu_seqlen_speculative = torch.tensor(
cu_seqlen_speculative, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
@@ -296,6 +283,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64
)
logger.info("FROM PB")
return cls(
batch_id=pb.id,
requests=pb.requests,
@@ -303,7 +291,6 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
cu_seqlen_speculative=cu_seqlen_speculative,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
@@ -331,6 +318,7 @@ class FlashCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
logger.info("FILTER")
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
@@ -344,6 +332,7 @@ class FlashCausalLMBatch(Batch):
# Used to index into tensors
indices = []
batch_indices = []
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
@@ -371,14 +360,18 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_max_length = 0
logger.info(f"Request ids {request_ids} {len(self.requests)}")
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
batch_indices.append(idx)
S = 1 if self.speculative_ids is None else self.speculative_ids.shape[1] + 1
indices.extend(range(idx * S, (idx + 1) * S))
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
# Get length
logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
request_input_length = self.input_lengths[idx]
max_seqlen = max(max_seqlen, request_input_length)
@@ -429,14 +422,19 @@ class FlashCausalLMBatch(Batch):
self.block_tables = None
# Index into tensors
logger.info(f"INPUT IDS {indices} {self.input_ids}")
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(batch_indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices]
logger.info(f"{indices} {self.speculative_ids}")
speculative_ids = self.speculative_ids[batch_indices]
logger.info(f"SPEC IDS {speculative_ids}")
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -472,12 +470,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
# Batch attributes
logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}")
requests = []
requests_idx_mapping = {}
@@ -488,7 +488,7 @@ class FlashCausalLMBatch(Batch):
max_length = 0
max_seqlen = 0
for b in batches:
total_batch_size += len(b)
total_batch_size += len(b.input_ids)
total_slots += len(b.slots)
blocks += b.blocks
max_blocks = max(max_blocks, b.max_blocks)
@@ -536,6 +536,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_1 = 0
cumulative_slots = 0
for i, batch in enumerate(batches):
@@ -546,23 +547,28 @@ class FlashCausalLMBatch(Batch):
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
requests_idx_mapping[k] = v + cumulative_1
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
end_index = cumulative_batch_size + len(batch.input_ids)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
start_index1 = cumulative_1
end_index1 = cumulative_1 + len(batch)
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
logger.info(f"IN concat {batch.slot_indices}")
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
top_n_tokens_tensor[start_index1:end_index1] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
start_index1:end_index1, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
block_tables_tensor[
@@ -584,11 +590,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.extend(batch.top_n_tokens)
# Update
cumulative_batch_size += len(batch)
cumulative_batch_size += len(batch.input_ids)
cumulative_slots += len(batch.slots)
cumulative_1 += len(batch)
start_slots = torch.concat(start_slots)
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
@@ -629,6 +638,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
def __del__(self):
@@ -731,18 +741,28 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE)
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
return self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
lm_head_indices=batch.prefill_head_indices,
speculative_ids=batch.speculative_ids
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
lm_head_indices=lm_head_indices,
# speculative_ids=batch.speculative_ids
)
@tracer.start_as_current_span("generate_token")
@@ -790,8 +810,6 @@ class FlashCausalLM(Model):
next_token_logits = out
# if next_token_logits.shape[0] == 3:
# import ipdb;ipdb.set_trace()
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
)
@@ -799,20 +817,15 @@ class FlashCausalLM(Model):
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
if speculative_ids is not None:
# length = len(batch) * (1 + speculative_length)
length = len(batch)
else:
length = len(batch)
# import ipdb;ipdb.set_trace()
length = len(batch)
next_position_ids = batch.position_ids.new_empty(length)
# Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
@@ -820,7 +833,6 @@ class FlashCausalLM(Model):
batch.cu_seqlen_prefill = None
else:
prefill_logprobs = None
# import ipdb;ipdb.set_trace()
next_position_ids = batch.position_ids
# Cumulative length
@@ -875,29 +887,32 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]
# logger.info(f"Request ids {request_ids} {len(self.requests)}")
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
cumulative_length += input_length
# if accepted_ids[0] > 1:
# import ipdb;ipdb.set_trace()
if len(accepted_ids) > 1:
raise Exception("Implement the batched behavior")
# if len(accepted_ids) > 1:
# raise Exception("Implement the batched behavior")
# Set values in batch
# batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
for n_accepted_ids in accepted_ids:
# TODO Make this batched
batch.input_ids = next_input_ids[-1:]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + n_accepted_ids
batch.input_lengths_tensor += n_accepted_ids
batch.slot_indices += n_accepted_ids
accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}")
if accepted_ids.shape != batch.slot_indices.shape:
# This can happen after a concatenation
# The slot indices are already modified for some speculative_ids
B = batch.slot_indices.shape[0] // accepted_ids.shape[0]
accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
batch.slot_indices += accepted_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
if prefill and prefill_logprobs:
# Get prefill logprobs
@@ -955,26 +970,22 @@ class FlashCausalLM(Model):
next_token_texts = []
for j in range(index, index + n_accepted_ids):
# Generated token
all_input_ids.append(next_token_ids[j])
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
# Evaluate stopping criteria
for next_token_id in _next_token_ids:
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
stopped = True
break
if not stop:
else:
stopped = False
# Shard generations
@@ -1068,8 +1079,27 @@ class FlashCausalLM(Model):
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
if prefill:
batch.max_seqlen += speculative_length
batch.max_seqlen = batch.max_seqlen + 1
# Model Forward
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
batch.input_ids = torch.cat([batch.input_ids.unsqueeze(-1), batch.speculative_ids], dim=1).view(-1)
if batch.position_ids.shape[0] != B * new_length:
arange = torch.arange(new_length).unsqueeze(0).to(device=batch.position_ids.device)
batch.position_ids = (batch.position_ids.view((-1, 1)).expand(B, new_length) + arange).view(-1)
batch.slot_indices = (batch.slot_indices.view((-1, 1)).expand(B, new_length) + arange.to(dtype=batch.slot_indices.dtype)).view(-1)
batch.input_lengths_tensor = (batch.input_lengths_tensor.view((-1, 1)).expand(B, new_length) + arange.to(dtype=batch.input_lengths_tensor.dtype)).view(-1)
batch.max_seqlen = batch.max_seqlen + speculative_length
# Add an extra block just in case
block_tables = torch.cat([batch.block_tables_tensor, batch.block_tables_tensor[:, -1:] + 1], dim=1)
# Copy the block tables for all members
# Contiguous because paged attention assumes contiguity
batch.block_tables_tensor = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
batch.lm_head_indices = batch.prefill_head_indices
cu_seqlen_prefill = batch.cu_seqlen_prefill
return generations, batch
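
One line above that is easy to misread is batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]. The worked example below (editor's illustration; it assumes next_input_ids is the flat, request-by-request concatenation of all accepted tokens) shows what that indexing selects.

import torch

# Three requests accepted 2, 1 and 3 tokens respectively; their accepted tokens
# are laid out back to back in one flat tensor.
next_input_ids = torch.tensor([101, 102, 201, 301, 302, 303])
accepted_ids = torch.tensor([2, 1, 3])

# cumsum - 1 points at the last accepted token of each request, which becomes the
# single input id that request feeds into the next decoding step.
last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
print(last_accepted)  # tensor([102, 201, 303])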

View File

@@ -225,21 +225,35 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
accepted_ids = []
next_ids = self.choice(scores)
from loguru import logger
if speculated_ids is not None:
validate_speculative = next_ids[:-1] == speculated_ids[0]
index = 1
for valid in validate_speculative.tolist():
if valid:
index += 1
# print(f"Validated {index - 1}")
next_ids = next_ids[:index]
scores = scores[:index]
speculative_scores = speculative_scores[index - 1:index]
accepted_ids.append(index)
logger.info(f"CHOOSER {next_ids} {speculated_ids}")
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i*S: (i + 1)*S]
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
accepted = 1
# First is always valid
indices.append(index)
for valid in validate_speculative.tolist():
if valid:
index += 1
accepted += 1
indices.append(index)
# print(f"Validated {accepted}")
accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
next_ids = next_ids[indices]
scores = scores[indices]
speculative_scores = speculative_scores[accepted_ids.cumsum(dim=-1) - 1]
else:
accepted_ids.append(1)
accepted_ids = torch.ones_like(next_ids)
logprobs = torch.log_softmax(scores, -1)
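
The per-request loop above compares the freshly sampled next_ids against the ids speculated on the previous step and counts how many were accepted. Below is a compact standalone sketch of the prefix-acceptance rule this is built around (editor's illustration that stops at the first mismatch; it is not the class above).

import torch


def accept_speculated(next_ids, speculated_ids):
    # next_ids: (B * (S + 1),) tokens sampled for the flattened batch
    # speculated_ids: (B, S) candidate tokens speculated on the previous step
    B, S = speculated_ids.shape
    accepted_ids, indices = [], []
    for i in range(B):
        block = next_ids[i * (S + 1) : (i + 1) * (S + 1)]
        accepted, index = 1, i * (S + 1)
        indices.append(index)  # the first sampled token is always kept
        for matches in (block[:-1] == speculated_ids[i]).tolist():
            if not matches:
                break
            index += 1
            accepted += 1
            indices.append(index)
        accepted_ids.append(accepted)
    return torch.tensor(accepted_ids), next_ids[indices]


next_ids = torch.tensor([7, 8, 9, 4, 5, 6])  # B=2 requests, S=2 speculated ids each
speculated = torch.tensor([[7, 8], [5, 6]])  # request 0 fully matches, request 1 does not
print(accept_speculated(next_ids, speculated))
# (tensor([3, 1]), tensor([7, 8, 9, 4]))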