huggingface/text-generation-inference (mirror of https://github.com/huggingface/text-generation-inference.git)

Commit 5b340a5ffd (parent bdbccb774c): Dump work.
@@ -53,6 +53,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot)
     responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
 
     assert responses == response_snapshot
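Note (outside the diff): the reworked assertion keeps the same check, that all four concurrent Medusa generations return identical text, but now reports every generated_text on failure. A minimal sketch of that assertion pattern with made-up data, not the integration test itself:

    # Hypothetical stand-ins for r.generated_text; the real test compares
    # Response objects returned by generate_load.
    generated_texts = ["Paris", "Paris", "Paris", "Berlin"]
    assert all(t == generated_texts[0] for t in generated_texts), f"{generated_texts}"
    # -> AssertionError: ['Paris', 'Paris', 'Paris', 'Berlin']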
@@ -283,7 +283,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
-        logger.info("FROM PB")
+        # logger.info("FROM PB")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -318,7 +318,7 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        logger.info("FILTER")
+        # logger.info("FILTER")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -360,7 +360,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0
 
-        logger.info(f"Request ids {request_ids} {len(self.requests)}")
+        # logger.info(f"Request ids {request_ids} {len(self.requests)}")
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             batch_indices.append(idx)
@@ -371,10 +371,11 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
 
             # Get length
-            logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
+            # logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
             request_input_length = self.input_lengths[idx]
             max_seqlen = max(max_seqlen, request_input_length)
 
+            # logger.info(f"====Appending {self.all_input_ids[idx]}")
             all_input_ids.append(self.all_input_ids[idx])
 
             input_lengths.append(request_input_length)
@@ -422,19 +423,21 @@ class FlashCausalLMBatch(Batch):
         self.block_tables = None
 
         # Index into tensors
-        logger.info(f"INPUT IDS {indices} {self.input_ids}")
+        # logger.info(f"INPUT IDS {indices} {self.input_ids}")
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        all_input_ids_tensor = self.all_input_ids_tensor[batch_indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(batch_indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices]
 
-        logger.info(f"{indices} {self.speculative_ids}")
+        logger.info(f"FILTER {all_input_ids_tensor} {all_input_ids}")
+
+        # logger.info(f"{indices} {self.speculative_ids}")
         speculative_ids = self.speculative_ids[batch_indices]
-        logger.info(f"SPEC IDS {speculative_ids}")
+        # logger.info(f"SPEC IDS {speculative_ids}")
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
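Note (outside the diff): the filter fix above indexes per-request tensors (all_input_ids_tensor, top_n_tokens_tensor, speculative_ids) with batch_indices while per-row tensors (input_ids, position_ids, input_lengths_tensor, block_tables_tensor) keep using indices. A minimal sketch of why the two index sets differ once speculation gives each request several flat rows; the shapes and the value of S are assumptions, not the real FlashCausalLMBatch:

    import torch

    # Assumed layout: 4 requests, S = 3 flat rows per request (speculation).
    S = 3
    batch_indices = torch.tensor([0, 2])                 # kept requests
    indices = (batch_indices[:, None] * S + torch.arange(S)).reshape(-1)
    # -> tensor([0, 1, 2, 6, 7, 8]): flat rows owned by requests 0 and 2

    input_ids = torch.arange(4 * S)                      # per-row tensor
    all_input_ids_tensor = torch.zeros(4, 8)             # per-request tensor

    kept_rows = input_ids[indices]                       # index with row indices
    kept_requests = all_input_ids_tensor[batch_indices]  # index with request indices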
@@ -477,20 +480,23 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
         # Batch attributes
-        logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}")
+        # logger.info(f"CONCATENATE {len(batches)}, {[b.input_ids for b in batches]}")
         requests = []
         requests_idx_mapping = {}
 
         blocks = 0
         total_batch_size = 0
+        total_cu_size = 0
         total_slots = 0
         max_blocks = 0
         max_length = 0
         max_seqlen = 0
         for b in batches:
-            total_batch_size += len(b.input_ids)
+            total_cu_size += len(b.input_ids)
+            total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[0]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -498,6 +504,7 @@ class FlashCausalLMBatch(Batch):
                 max(
                     input_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias
@@ -505,21 +512,21 @@ class FlashCausalLMBatch(Batch):
                 ),
             )
 
-        input_ids = batches[0].input_ids.new_empty(total_batch_size)
-        position_ids = batches[0].position_ids.new_empty(total_batch_size)
+        input_ids = batches[0].input_ids.new_empty(total_cu_size)
+        position_ids = batches[0].position_ids.new_empty(total_cu_size)
         slots = batches[0].slots.new_empty(total_slots)
-        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        slot_indices = batches[0].slot_indices.new_empty(total_cu_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
-            total_batch_size
+            total_cu_size
         )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
-            (total_batch_size, max_blocks)
+            (total_cu_size, max_blocks)
         )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
-            total_batch_size,
+            total_cu_size,
        )
 
         start_slots = []
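Note (outside the diff): the concatenate changes split the old single counter into total_batch_size (requests, len(b)) and total_cu_size (flat token rows, len(b.input_ids)), and size per-row tensors with the latter while all_input_ids_tensor stays per request. A small numeric sketch of the bookkeeping; the batch sizes and speculative length are assumptions:

    # Assumed inputs: two batches of 3 and 5 requests, speculative_length = 2,
    # so each request owns speculative_length + 1 = 3 flat rows.
    speculative_length = 2
    batch_sizes = [3, 5]

    total_batch_size = sum(batch_sizes)                                     # 8
    total_cu_size = sum(n * (speculative_length + 1) for n in batch_sizes)  # 24

    # input_ids / position_ids / slot_indices / input_lengths_tensor /
    # block_tables_tensor / top_n_tokens_tensor are allocated with total_cu_size;
    # all_input_ids_tensor keeps (total_batch_size, max_length).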
@@ -560,7 +567,7 @@ class FlashCausalLMBatch(Batch):
             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
-            logger.info(f"IN concat {batch.slot_indices}")
+            # logger.info(f"IN concat {batch.slot_indices}")
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
 
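Note (outside the diff): in the copy loop, each batch's slot_indices are shifted by cumulative_slots so they still point at the right rows of the merged slots tensor. A tiny sketch with assumed values:

    import torch

    slots_a = torch.tensor([10, 11, 12])       # slots of the first batch
    slots_b = torch.tensor([20, 21])           # slots of the second batch
    slot_indices_b = torch.tensor([0, 1])      # valid only inside slots_b

    slots = torch.cat([slots_a, slots_b])
    cumulative_slots = len(slots_a)
    shifted = slot_indices_b + cumulative_slots
    assert torch.equal(slots[shifted], slots_b[slot_indices_b])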
@@ -813,11 +820,13 @@ class FlashCausalLM(Model):
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )
-        logger.info(f"CHOOSER {next_input_ids} -> {accepted_ids}")
+        # logger.info(f"CHOOSER {accepted_ids}")
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
+        logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
+        # logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
 
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
@@ -887,7 +896,6 @@ class FlashCausalLM(Model):
                     start_index + 1 : start_index + out_length
                 ]
 
-            # logger.info(f"Request ids {request_ids} {len(self.requests)}")
             for j in range(n_accepted_ids):
                 batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                 index += 1
@@ -900,19 +908,20 @@ class FlashCausalLM(Model):
         # Set values in batch
         # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
 
-        accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
+        # accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
 
-        logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}")
+        # logger.info(f"ACCEPTED IDS {accepted_ids} ")
         if accepted_ids.shape != batch.slot_indices:
             # This can happen after a concatenation
             # The slot indices is already modified for some speculative_ids
             B = batch.slot_indices.shape[0] // accepted_ids.shape[0]
-            accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
-            batch.slot_indices += accepted_ids
-            batch.position_ids = next_position_ids + accepted_ids
-            batch.input_lengths_tensor += accepted_ids
+            step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
+            # logger.info(f"ACCEPTED IDS AFTER {accepted_ids} ")
+            batch.slot_indices += step_accepted_ids
+            batch.position_ids = next_position_ids + step_accepted_ids
+            batch.input_lengths_tensor += step_accepted_ids
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
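Note (outside the diff): the step_accepted_ids fix keeps two views of the acceptance counts. The per-request counts drive accepted_ids.cumsum(dim=-1) - 1, which picks each request's last accepted token out of the flat next_input_ids, while per-row state (slot_indices, position_ids, input_lengths_tensor) needs those counts repeated B times per request before being added. A sketch with assumed values, under the assumption that next_input_ids holds only the accepted tokens:

    import torch

    # Assumed: request 0 accepted 2 tokens, request 1 accepted 1; the flat
    # tensor contains only accepted tokens, in request order.
    next_input_ids = torch.tensor([101, 102, 201])
    accepted_ids = torch.tensor([2, 1])

    last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
    # cumsum -> [2, 3]; minus one -> rows [1, 2] -> tensor([102, 201])

    slot_indices = torch.arange(6)                        # 2 requests x 3 rows
    B = slot_indices.shape[0] // accepted_ids.shape[0]    # 3 rows per request
    step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
    # -> tensor([2, 2, 2, 1, 1, 1])
    slot_indices = slot_indices + step_accepted_ids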
@@ -964,6 +973,8 @@ class FlashCausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
+
+            logger.info(f"Next token ids {next_token_ids} -> {index}")
             _next_token_ids = next_token_ids[index: index+n_accepted_ids]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
 
@@ -988,6 +999,8 @@ class FlashCausalLM(Model):
             else:
                 stopped = False
 
+            index += n_accepted_ids
+
             # Shard generations
             # All generations will be appended in the rust sharded client
             if i % self.world_size == self.rank:
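Note (outside the diff): the added index += n_accepted_ids advances a cursor over the same flat tensor of accepted tokens, so each request's slice next_token_ids[index : index + n_accepted_ids] lines up with its own accepted count. A short sketch with assumed values:

    # Assumed accepted tokens, flattened in request order, and per-request counts.
    next_token_ids = [101, 102, 201, 301, 302, 303]
    accepted_counts = [2, 1, 3]

    index = 0
    per_request = []
    for n_accepted_ids in accepted_counts:
        per_request.append(next_token_ids[index : index + n_accepted_ids])
        index += n_accepted_ids

    assert per_request == [[101, 102], [201], [301, 302, 303]]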
@@ -228,7 +228,7 @@ class HeterogeneousNextTokenChooser:
         next_ids = self.choice(scores)
         from loguru import logger
         if speculated_ids is not None:
-            logger.info(f"CHOOSER {next_ids} {speculated_ids}")
+            # logger.info(f"CHOOSER {next_ids} {speculated_ids}")
             accepted_ids = []
             B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
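Note (outside the diff): in the chooser, next_ids comes back with one token per flat row, B requests times S = speculated_ids.shape[1] + 1 rows each, and accepted_ids is built by comparing the freshly chosen tokens against the draft tokens. The dump cuts off before that loop, so the sketch below is an assumed greedy acceptance rule, not necessarily the exact upstream logic:

    import torch

    def count_accepted(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> list:
        # next_ids: (B * S,) chosen tokens; speculated_ids: (B, S - 1) draft tokens.
        B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
        S = speculated_ids.shape[1] + 1
        next_ids = next_ids.view(B, S)
        accepted_ids = []
        for req in range(B):
            accepted = 1  # the token from the last verified position always counts
            for j in range(S - 1):
                if next_ids[req, j] != speculated_ids[req, j]:
                    break
                accepted += 1
            accepted_ids.append(accepted)
        return accepted_ids

    next_ids = torch.tensor([7, 9, 4, 5, 6, 8])
    speculated_ids = torch.tensor([[7, 1], [5, 6]])
    print(count_accepted(next_ids, speculated_ids))  # [2, 3]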