Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-28 05:22:07 +00:00)

Commit 5b340a5ffd ("Dump work."), parent bdbccb774c

@@ -53,6 +53,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot)
     responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)

     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"

     assert responses == response_snapshot

@@ -283,7 +283,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )

-        logger.info("FROM PB")
+        # logger.info("FROM PB")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,

@@ -318,7 +318,7 @@ class FlashCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        logger.info("FILTER")
+        # logger.info("FILTER")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same

@@ -360,7 +360,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0

-        logger.info(f"Request ids {request_ids} {len(self.requests)}")
+        # logger.info(f"Request ids {request_ids} {len(self.requests)}")
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             batch_indices.append(idx)

@@ -371,10 +371,11 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])

             # Get length
-            logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
+            # logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
             request_input_length = self.input_lengths[idx]
             max_seqlen = max(max_seqlen, request_input_length)

+            # logger.info(f"====Appending {self.all_input_ids[idx]}")
             all_input_ids.append(self.all_input_ids[idx])

             input_lengths.append(request_input_length)

@@ -422,19 +423,21 @@ class FlashCausalLMBatch(Batch):
         self.block_tables = None

         # Index into tensors
-        logger.info(f"INPUT IDS {indices} {self.input_ids}")
+        # logger.info(f"INPUT IDS {indices} {self.input_ids}")
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        all_input_ids_tensor = self.all_input_ids_tensor[batch_indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(batch_indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices]

-        logger.info(f"{indices} {self.speculative_ids}")
+        logger.info(f"FILTER {all_input_ids_tensor} {all_input_ids}")

+        # logger.info(f"{indices} {self.speculative_ids}")
         speculative_ids = self.speculative_ids[batch_indices]
-        logger.info(f"SPEC IDS {speculative_ids}")
+        # logger.info(f"SPEC IDS {speculative_ids}")

         start_slots = torch.tensor(start_slots, dtype=torch.int64)

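The change above switches all_input_ids_tensor and speculative_ids to be sliced with batch_indices (one entry per kept request), while input_ids and position_ids keep using indices. A minimal sketch of that per-request slicing, with invented data rather than the actual FlashCausalLMBatch fields:

import torch

# Toy batch state: 3 requests, one row per request in each tensor.
input_ids = torch.tensor([11, 22, 33])
all_input_ids = torch.tensor([[11, 0, 0], [22, 0, 0], [33, 0, 0]])
speculative_ids = torch.tensor([[101, 102], [201, 202], [301, 302]])

# Keep requests 0 and 2 (e.g. request 1 finished generating).
batch_indices = torch.tensor([0, 2])

# Every per-request tensor is sliced with the same index tensor, so the
# surviving rows stay aligned across input_ids, all_input_ids and speculative_ids.
filtered_input_ids = input_ids[batch_indices]
filtered_all_input_ids = all_input_ids[batch_indices]
filtered_speculative_ids = speculative_ids[batch_indices]

print(filtered_input_ids)        # tensor([11, 33])
print(filtered_speculative_ids)  # tensor([[101, 102], [301, 302]])
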
@@ -477,20 +480,23 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
         # Batch attributes
-        logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}")
+        # logger.info(f"CONCATENATE {len(batches)}, {[b.input_ids for b in batches]}")
         requests = []
         requests_idx_mapping = {}

         blocks = 0
         total_batch_size = 0
+        total_cu_size = 0
         total_slots = 0
         max_blocks = 0
         max_length = 0
         max_seqlen = 0
         for b in batches:
-            total_batch_size += len(b.input_ids)
+            total_cu_size += len(b.input_ids)
+            total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[0]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(

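The loop above now tracks two sizes: total_batch_size counts requests, while total_cu_size counts rows of the flattened tensors, which grow when each request carries speculative candidates. A toy calculation of the two, with invented batch sizes (the real values come from len(b) and len(b.input_ids)):

# Two toy batches being concatenated: 2 and 3 requests, each request currently
# contributing 1 verified row + 2 speculative rows to the flattened tensors.
requests_per_batch = [2, 3]
rows_per_request = 1 + 2

total_batch_size = sum(requests_per_batch)                              # 5
total_cu_size = sum(n * rows_per_request for n in requests_per_batch)   # 15

# In the hunks of this commit, flattened tensors such as input_ids and
# slot_indices are sized with total_cu_size, while all_input_ids_tensor keeps
# one row per request and stays at total_batch_size.
print(total_batch_size, total_cu_size)
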
@@ -498,6 +504,7 @@ class FlashCausalLMBatch(Batch):
                 max(
                     input_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias

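The new + speculative_length term widens max_length so that a request's row in all_input_ids_tensor can absorb every candidate that might be accepted in a single step. A quick arithmetic check with made-up numbers:

input_length = 7        # prompt tokens already in the sequence
max_new_tokens = 20     # stopping-criteria budget
current_tokens = 5      # tokens generated so far
speculative_length = 2  # e.g. two speculative candidates per step

# Without speculation the row needs 7 + 20 - 5 = 22 positions;
# with it, up to 2 extra candidates can land in the same step.
max_length = input_length + max_new_tokens + speculative_length - current_tokens
print(max_length)  # 24
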
@@ -505,21 +512,21 @@ class FlashCausalLMBatch(Batch):
                 ),
             )

-        input_ids = batches[0].input_ids.new_empty(total_batch_size)
-        position_ids = batches[0].position_ids.new_empty(total_batch_size)
+        input_ids = batches[0].input_ids.new_empty(total_cu_size)
+        position_ids = batches[0].position_ids.new_empty(total_cu_size)
         slots = batches[0].slots.new_empty(total_slots)
-        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        slot_indices = batches[0].slot_indices.new_empty(total_cu_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
-            total_batch_size
+            total_cu_size
         )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
-            (total_batch_size, max_blocks)
+            (total_cu_size, max_blocks)
         )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
-            total_batch_size,
+            total_cu_size,
         )

         start_slots = []

@@ -560,7 +567,7 @@ class FlashCausalLMBatch(Batch):
             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
-            logger.info(f"IN concat {batch.slot_indices}")
+            # logger.info(f"IN concat {batch.slot_indices}")
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor

@@ -813,11 +820,13 @@ class FlashCausalLM(Model):
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )
+        logger.info(f"CHOOSER {next_input_ids} -> {accepted_ids}")
+        # logger.info(f"CHOOSER {accepted_ids}")

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
-        logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
+        # logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")

         if prefill:
             if len(batch) > 1 and prefill_logprobs:

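Reading aid for the call above: with speculation the model emits one logits row per verified position plus per speculative candidate, and the chooser returns flattened ids together with per-request acceptance counts. The snippet below only sanity-checks that row layout with a greedy stand-in; it is not the HeterogeneousNextTokenChooser implementation, and all sizes are invented:

import torch

B, S, vocab_size = 2, 2, 8  # toy sizes: 2 requests, 2 speculative slots each

# One logits row per (verified + speculative) position, B * (S + 1) in total.
next_token_logits = torch.randn(B * (S + 1), vocab_size)

# Greedy stand-in for the chooser: one id and its logprob per row.
logprobs = torch.log_softmax(next_token_logits, dim=-1)
next_input_ids = logprobs.argmax(dim=-1)                                 # [B * (S + 1)]
next_token_logprobs = logprobs.gather(1, next_input_ids[:, None]).squeeze(1)

print(next_input_ids.shape, next_token_logprobs.shape)  # torch.Size([6]) torch.Size([6])
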
@@ -887,7 +896,6 @@ class FlashCausalLM(Model):
                         start_index + 1 : start_index + out_length
                     ]

-            # logger.info(f"Request ids {request_ids} {len(self.requests)}")
             for j in range(n_accepted_ids):
                 batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                 index += 1

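The loop above scatters each accepted token into the request's row of batch.all_input_ids_tensor, advancing a shared index over the flattened token list. A self-contained rendition of that bookkeeping, with invented values:

import torch

all_input_ids_tensor = torch.zeros(2, 10, dtype=torch.long)  # [requests, max_length]
all_input_ids_tensor[0, :4] = torch.tensor([5, 6, 7, 8])     # request 0 already holds 4 tokens

# Flattened accepted tokens for the whole batch, consumed with a running index.
next_input_ids = torch.tensor([42, 43, 99])  # request 0 accepted 2 tokens, request 1 accepted 1
index = 0
i, input_length, n_accepted_ids = 0, 4, 2    # request 0

for j in range(n_accepted_ids):
    all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
    index += 1

print(all_input_ids_tensor[0, :6])  # tensor([ 5,  6,  7,  8, 42, 43])
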
@@ -900,19 +908,20 @@ class FlashCausalLM(Model):
         # Set values in batch
         # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)

-        accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
+        # accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids

-        logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}")
+        # logger.info(f"ACCEPTED IDS {accepted_ids} ")
         if accepted_ids.shape != batch.slot_indices:
             # This can happen after a concatenation
             # The slot indices is already modified for some speculative_ids
             B = batch.slot_indices.shape[0] // accepted_ids.shape[0]
-            accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
-            batch.slot_indices += accepted_ids
-        batch.position_ids = next_position_ids + accepted_ids
-        batch.input_lengths_tensor += accepted_ids
+            step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
+        # logger.info(f"ACCEPTED IDS AFTER {accepted_ids} ")
+        batch.slot_indices += step_accepted_ids
+        batch.position_ids = next_position_ids + step_accepted_ids
+        batch.input_lengths_tensor += step_accepted_ids

         if prefill and prefill_logprobs:
             # Get prefill logprobs

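Two indexing tricks in this hunk are easy to misread; here is a standalone illustration with toy values (the tensor names mirror the diff, the numbers are invented):

import torch

# Per-request count of tokens actually kept this step (already a tensor here).
accepted_ids = torch.tensor([2, 1, 3])

# next_input_ids is flattened over all accepted tokens of all requests:
# [a0, a1, b0, c0, c1, c2]
next_input_ids = torch.tensor([10, 11, 20, 30, 31, 32])

# cumsum - 1 points at the *last* accepted token of each request,
# which becomes that request's input id for the next forward pass.
last_per_request = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
print(last_per_request)  # tensor([11, 20, 32])

# After a concatenation, slot_indices can hold B rows per request; the
# per-request counts are then broadcast so every row advances by the same amount.
B = 2
step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
print(step_accepted_ids)  # tensor([2, 2, 1, 1, 3, 3])
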
@@ -964,6 +973,8 @@ class FlashCausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens

+            logger.info(f"Next token ids {next_token_ids} -> {index}")
             _next_token_ids = next_token_ids[index: index+n_accepted_ids]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]

@@ -988,6 +999,8 @@ class FlashCausalLM(Model):
             else:
                 stopped = False

+            index += n_accepted_ids
+
             # Shard generations
             # All generations will be appended in the rust sharded client
             if i % self.world_size == self.rank:

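Inside the per-request loop, the flattened token list is consumed with a running index, n_accepted_ids tokens at a time. A compact illustration with made-up data:

# Flattened over the whole batch: request 0 accepted 2 tokens, request 1 accepted 1.
next_token_ids = [10, 11, 20]
next_token_logprobs = [-0.1, -0.2, -0.3]
accepted_per_request = [2, 1]

index = 0
for i, n_accepted_ids in enumerate(accepted_per_request):
    _next_token_ids = next_token_ids[index: index + n_accepted_ids]
    _next_token_logprobs = next_token_logprobs[index: index + n_accepted_ids]
    print(i, _next_token_ids, _next_token_logprobs)
    index += n_accepted_ids
# 0 [10, 11] [-0.1, -0.2]
# 1 [20] [-0.3]
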
@@ -228,7 +228,7 @@ class HeterogeneousNextTokenChooser:
         next_ids = self.choice(scores)
         from loguru import logger
         if speculated_ids is not None:
-            logger.info(f"CHOOSER {next_ids} {speculated_ids}")
+            # logger.info(f"CHOOSER {next_ids} {speculated_ids}")
             accepted_ids = []
             B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1

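The B and S computed above split the flattened ids back into per-request groups of one verified token plus S - 1 speculated positions. The acceptance rule itself is not shown in this hunk; the sketch below is one plausible greedy-verification reading (count how many speculated tokens match what the model would have chosen), not necessarily what the commit implements:

import torch

S = 3  # 1 verified position + 2 speculated positions
B = 2  # requests

# Model's chosen id at every position, grouped per request.
next_ids = torch.tensor([5, 7, 9,    4, 8, 6]).view(B, S)
# Tokens that were speculated for the last S - 1 positions of each request.
speculated_ids = torch.tensor([[5, 7],   # both guesses match -> accept 1 + 2
                               [1, 8]])  # first guess wrong   -> accept 1 + 0

accepted_ids = []
for b in range(B):
    accepted = 1  # the verified token is always kept
    for j in range(S - 1):
        if speculated_ids[b, j] != next_ids[b, j]:
            break
        accepted += 1
    accepted_ids.append(accepted)

print(accepted_ids)  # [3, 1]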