Dump work.

Nicolas Patry 2023-11-30 22:05:51 +00:00
parent bdbccb774c
commit 5b340a5ffd
3 changed files with 41 additions and 28 deletions

View File

@@ -53,6 +53,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot)
responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot
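The assertion now carries the generated texts in its failure message, so a diverging load test reports what actually differed. A minimal sketch of the same pattern with made-up strings (not part of the patch):

responses = ["Hello world", "Hello world", "Hello world", "Hello world"]

# When the texts differ, the f-string puts the offending values straight
# into the AssertionError instead of a bare "assert failed".
assert all(
    r == responses[0] for r in responses
), f"Responses diverged: {responses}"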

View File

@@ -283,7 +283,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64
)
logger.info("FROM PB")
# logger.info("FROM PB")
return cls(
batch_id=pb.id,
requests=pb.requests,
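Most edits in this file comment out ad-hoc logger.info calls used while debugging. As a sketch of one alternative (assuming loguru, which this codebase already uses), the calls could be demoted to DEBUG and silenced through the sink level instead of being commented out:

import sys
from loguru import logger

# Replace the default sink with one that only shows INFO and above.
logger.remove()
logger.add(sys.stderr, level="INFO")

# Invisible at INFO level, but can be re-enabled by switching the sink
# to level="DEBUG" without touching the call site.
logger.debug("FROM PB: building FlashCausalLMBatch from protobuf")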
@@ -318,7 +318,7 @@ class FlashCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
logger.info("FILTER")
# logger.info("FILTER")
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
@@ -360,7 +360,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_max_length = 0
logger.info(f"Request ids {request_ids} {len(self.requests)}")
# logger.info(f"Request ids {request_ids} {len(self.requests)}")
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
batch_indices.append(idx)
@@ -371,10 +371,11 @@ class FlashCausalLMBatch(Batch):
requests.append(self.requests[idx])
# Get length
logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
# logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
request_input_length = self.input_lengths[idx]
max_seqlen = max(max_seqlen, request_input_length)
# logger.info(f"====Appending {self.all_input_ids[idx]}")
all_input_ids.append(self.all_input_ids[idx])
input_lengths.append(request_input_length)
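For context, the filter loop resolves each surviving request_id to its row in the old batch through requests_idx_mapping and gathers the per-request bookkeeping from there. A tiny standalone illustration with invented values (not part of the patch):

requests_idx_mapping = {101: 0, 102: 1, 103: 2}   # request_id -> row in old batch
input_lengths = [7, 12, 5]

request_ids = [103, 101]                          # requests kept by the filter
batch_indices = [requests_idx_mapping[r] for r in request_ids]   # [2, 0]
kept_lengths = [input_lengths[i] for i in batch_indices]         # [5, 7]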
@@ -422,19 +423,21 @@ class FlashCausalLMBatch(Batch):
self.block_tables = None
# Index into tensors
logger.info(f"INPUT IDS {indices} {self.input_ids}")
# logger.info(f"INPUT IDS {indices} {self.input_ids}")
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
all_input_ids_tensor = self.all_input_ids_tensor[batch_indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(batch_indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices]
logger.info(f"{indices} {self.speculative_ids}")
logger.info(f"FILTER {all_input_ids_tensor} {all_input_ids}")
# logger.info(f"{indices} {self.speculative_ids}")
speculative_ids = self.speculative_ids[batch_indices]
logger.info(f"SPEC IDS {speculative_ids}")
# logger.info(f"SPEC IDS {speculative_ids}")
start_slots = torch.tensor(start_slots, dtype=torch.int64)
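The substantive fix in this hunk is that all_input_ids_tensor, next_token_chooser, top_n_tokens_tensor and speculative_ids are now filtered with batch_indices (one entry per request) instead of indices, which under speculative decoding holds one entry per generated slot, i.e. speculative_length + 1 rows per request. A shape sketch with invented numbers (2 kept requests, 2 speculative tokens), not part of the patch:

import torch

batch_indices = torch.tensor([0, 2])               # one index per kept request
indices = torch.tensor([0, 1, 2, 6, 7, 8])         # (spec + 1) indices per kept request

all_input_ids_tensor = torch.zeros(4, 16)          # [num_requests, max_length]
input_ids = torch.arange(12)                       # [num_requests * (spec + 1)]

per_request = all_input_ids_tensor[batch_indices]  # shape [2, 16]
per_slot = input_ids[indices]                      # shape [6]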
@@ -477,20 +480,23 @@ class FlashCausalLMBatch(Batch):
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
# Batch attributes
logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}")
# logger.info(f"CONCATENATE {len(batches)}, {[b.input_ids for b in batches]}")
requests = []
requests_idx_mapping = {}
blocks = 0
total_batch_size = 0
total_cu_size = 0
total_slots = 0
max_blocks = 0
max_length = 0
max_seqlen = 0
for b in batches:
total_batch_size += len(b.input_ids)
total_cu_size += len(b.input_ids)
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[0]
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
@@ -498,6 +504,7 @@ class FlashCausalLMBatch(Batch):
max(
input_length
+ stopping_criteria.max_new_tokens
+ speculative_length
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
@@ -505,21 +512,21 @@ class FlashCausalLMBatch(Batch):
),
)
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
input_ids = batches[0].input_ids.new_empty(total_cu_size)
position_ids = batches[0].position_ids.new_empty(total_cu_size)
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
slot_indices = batches[0].slot_indices.new_empty(total_cu_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
total_cu_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
(total_cu_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
total_cu_size,
)
start_slots = []
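concatenate now distinguishes two sizes: total_batch_size counts requests (len(b)), while total_cu_size counts rows of input_ids, which with speculative decoding is requests * (speculative_length + 1). Per-slot tensors (input_ids, position_ids, slot_indices, input_lengths_tensor, block_tables_tensor, top_n_tokens_tensor) are allocated with total_cu_size; per-request tensors such as all_input_ids_tensor keep total_batch_size, and max_length now reserves room for the speculative tokens. The arithmetic, sketched with made-up numbers (not part of the patch):

speculative_length = 2
batch_sizes = [3, 5]                               # len(b) for the merged batches

total_batch_size = sum(batch_sizes)                                        # 8 requests
total_cu_size = sum(n * (speculative_length + 1) for n in batch_sizes)     # 24 rows
assert total_cu_size == total_batch_size * (speculative_length + 1)

# max_length per request now also accounts for the speculative slots:
input_length, max_new_tokens, current_tokens = 10, 20, 4
max_length = input_length + max_new_tokens + speculative_length - current_tokens  # 28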
@@ -560,7 +567,7 @@ class FlashCausalLMBatch(Batch):
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
logger.info(f"IN concat {batch.slot_indices}")
# logger.info(f"IN concat {batch.slot_indices}")
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
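For context, each merged batch's slot_indices are shifted by the number of slots already placed by earlier batches, so they keep addressing the right rows of the concatenated slots tensor. A toy illustration with invented values:

import torch

slots_a = torch.tensor([10, 11, 12])          # slots owned by the first batch
slots_b = torch.tensor([20, 21])              # slots owned by the second batch

slot_indices_b = torch.tensor([0, 1])         # indices local to the second batch
cumulative_slots = len(slots_a)               # 3 slots already placed

merged_slots = torch.cat([slots_a, slots_b])
shifted = slot_indices_b + cumulative_slots   # tensor([3, 4]), valid in merged_slots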
@@ -813,11 +820,13 @@ class FlashCausalLM(Model):
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
)
logger.info(f"CHOOSER {next_input_ids} -> {accepted_ids}")
# logger.info(f"CHOOSER {accepted_ids}")
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
# logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
if prefill:
if len(batch) > 1 and prefill_logprobs:
@@ -855,7 +864,7 @@ class FlashCausalLM(Model):
# For each member of the batch
index = 0
for i, (
input_length,
all_input_ids,
n_accepted_ids
@@ -887,7 +896,6 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]
# logger.info(f"Request ids {request_ids} {len(self.requests)}")
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
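This loop scatters the flat stream of accepted tokens into each request's row of all_input_ids_tensor, starting at that request's current input_length. A compact sketch of the bookkeeping (invented sizes, not the real tensors):

import torch

all_input_ids_tensor = torch.zeros(2, 8, dtype=torch.long)   # [num_requests, max_length]
next_input_ids = torch.tensor([11, 12, 21])                  # flat accepted tokens
input_lengths = [3, 5]
accepted = [2, 1]                                             # n_accepted_ids per request

index = 0
for i, (input_length, n_accepted_ids) in enumerate(zip(input_lengths, accepted)):
    for j in range(n_accepted_ids):
        all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
        index += 1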
@@ -900,19 +908,20 @@ class FlashCausalLM(Model):
# Set values in batch
# batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
# accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}")
# logger.info(f"ACCEPTED IDS {accepted_ids} ")
if accepted_ids.shape != batch.slot_indices:
# This can happen after a concatenation
# The slot indices is already modified for some speculative_ids
B = batch.slot_indices.shape[0] // accepted_ids.shape[0]
accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
batch.slot_indices += accepted_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
# logger.info(f"ACCEPTED IDS AFTER {accepted_ids} ")
batch.slot_indices += step_accepted_ids
batch.position_ids = next_position_ids + step_accepted_ids
batch.input_lengths_tensor += step_accepted_ids
if prefill and prefill_logprobs:
# Get prefill logprobs
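The next step's input ids above are the last accepted token of every request: accepted_ids.cumsum(dim=-1) - 1 yields, per request, the flat index of that token in next_input_ids. When the batch came out of a concatenation, slot_indices carries (speculative_length + 1) rows per request, so the per-request advance is expanded to that shape (step_accepted_ids) before being added. A small sketch with invented numbers (not part of the patch):

import torch

# Flat stream for 3 requests that accepted 2, 1 and 3 tokens respectively.
next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])
accepted_ids = torch.tensor([2, 1, 3])

last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]   # tensor([12, 21, 33])

# Broadcast the per-request advance to every speculative slot of that request.
B = 2                                                             # slots per request here
step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
# tensor([2, 2, 1, 1, 3, 3])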
@@ -964,6 +973,8 @@ class FlashCausalLM(Model):
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
logger.info(f"Next token ids {next_token_ids} -> {index}")
_next_token_ids = next_token_ids[index: index+n_accepted_ids]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
@@ -988,6 +999,8 @@ class FlashCausalLM(Model):
else:
stopped = False
index += n_accepted_ids
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
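Inside the per-request generation loop, index walks the flat next_token_ids / next_token_logprobs arrays and each request consumes its n_accepted_ids entries. The cursor pattern in isolation, with invented values:

next_token_ids = [11, 12, 21, 31, 32, 33]          # flat, all requests concatenated
next_token_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]
accepted_per_request = [2, 1, 3]

index = 0
for n_accepted_ids in accepted_per_request:
    _next_token_ids = next_token_ids[index : index + n_accepted_ids]
    _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids]
    index += n_accepted_ids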

View File

@@ -228,7 +228,7 @@ class HeterogeneousNextTokenChooser:
next_ids = self.choice(scores)
from loguru import logger
if speculated_ids is not None:
logger.info(f"CHOOSER {next_ids} {speculated_ids}")
# logger.info(f"CHOOSER {next_ids} {speculated_ids}")
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
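Here next_ids holds one chosen token per slot: with B requests and S = speculated_ids.shape[1] + 1 slots per request, a speculated token is accepted only while it matches what the model would have chosen at that position. A simplified greedy-acceptance sketch with invented tensors (the real method also reconciles logprobs and emits the next speculative ids):

import torch

speculated_ids = torch.tensor([[5, 9], [7, 7]])   # [B, S-1] draft tokens
next_ids = torch.tensor([5, 9, 4,                 # request 0: both drafts confirmed
                         7, 8, 2])                # request 1: second draft rejected

B = speculated_ids.shape[0]
S = speculated_ids.shape[1] + 1
next_ids = next_ids.view(B, S)

accepted_ids = []
for b in range(B):
    accepted = 1                                  # the non-speculative token is always kept
    for j in range(S - 1):
        if next_ids[b, j] == speculated_ids[b, j]:
            accepted += 1                         # draft matched the model's own choice
        else:
            break                                 # stop at the first mismatch
    accepted_ids.append(accepted)

print(accepted_ids)                               # [3, 2]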