diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py
index 64180c0a..47a9b0e2 100644
--- a/integration-tests/models/test_flash_medusa.py
+++ b/integration-tests/models/test_flash_medusa.py
@@ -53,6 +53,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot)
     responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
 
     assert responses == response_snapshot
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 539bdf52..d1d603d0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -283,7 +283,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
-        logger.info("FROM PB")
+        # logger.info("FROM PB")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -318,7 +318,7 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        logger.info("FILTER")
+        # logger.info("FILTER")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -360,7 +360,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0
 
-        logger.info(f"Request ids {request_ids} {len(self.requests)}")
+        # logger.info(f"Request ids {request_ids} {len(self.requests)}")
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             batch_indices.append(idx)
@@ -371,10 +371,11 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
 
             # Get length
-            logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
+            # logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
             request_input_length = self.input_lengths[idx]
             max_seqlen = max(max_seqlen, request_input_length)
 
+            # logger.info(f"====Appending {self.all_input_ids[idx]}")
             all_input_ids.append(self.all_input_ids[idx])
 
             input_lengths.append(request_input_length)
@@ -422,19 +423,21 @@ class FlashCausalLMBatch(Batch):
         self.block_tables = None
 
         # Index into tensors
-        logger.info(f"INPUT IDS {indices} {self.input_ids}")
+        # logger.info(f"INPUT IDS {indices} {self.input_ids}")
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        all_input_ids_tensor = self.all_input_ids_tensor[batch_indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(batch_indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices]
-        logger.info(f"{indices} {self.speculative_ids}")
+        logger.info(f"FILTER {all_input_ids_tensor} {all_input_ids}")
+
+        # logger.info(f"{indices} {self.speculative_ids}")
         speculative_ids = self.speculative_ids[batch_indices]
-        logger.info(f"SPEC IDS {speculative_ids}")
+        # logger.info(f"SPEC IDS {speculative_ids}")
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -477,20 +480,23 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
         # Batch attributes
-        logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}")
+        # logger.info(f"CONCATENATE {len(batches)}, {[b.input_ids for b in batches]}")
         requests = []
         requests_idx_mapping = {}
 
         blocks = 0
         total_batch_size = 0
+        total_cu_size = 0
         total_slots = 0
         max_blocks = 0
         max_length = 0
         max_seqlen = 0
         for b in batches:
-            total_batch_size += len(b.input_ids)
+            total_cu_size += len(b.input_ids)
+            total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[0]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -498,6 +504,7 @@ class FlashCausalLMBatch(Batch):
                 max(
                     input_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias
@@ -505,21 +512,21 @@ class FlashCausalLMBatch(Batch):
                 ),
             )
 
-        input_ids = batches[0].input_ids.new_empty(total_batch_size)
-        position_ids = batches[0].position_ids.new_empty(total_batch_size)
+        input_ids = batches[0].input_ids.new_empty(total_cu_size)
+        position_ids = batches[0].position_ids.new_empty(total_cu_size)
         slots = batches[0].slots.new_empty(total_slots)
-        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        slot_indices = batches[0].slot_indices.new_empty(total_cu_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
-            total_batch_size
+            total_cu_size
         )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
-            (total_batch_size, max_blocks)
+            (total_cu_size, max_blocks)
         )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
-            total_batch_size,
+            total_cu_size,
         )
 
         start_slots = []
@@ -560,7 +567,7 @@ class FlashCausalLMBatch(Batch):
             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
-            logger.info(f"IN concat {batch.slot_indices}")
+            # logger.info(f"IN concat {batch.slot_indices}")
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
 
@@ -813,11 +820,13 @@ class FlashCausalLM(Model):
        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
        )
+        logger.info(f"CHOOSER {next_input_ids} -> {accepted_ids}")
+        # logger.info(f"CHOOSER {accepted_ids}")
 
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )
-        logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
+        # logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")
 
        if prefill:
            if len(batch) > 1 and prefill_logprobs:
@@ -855,7 +864,7 @@ class FlashCausalLM(Model):
 
        # For each member of the batch
        index = 0
-        for i, (
+        for i, (
            input_length,
            all_input_ids,
            n_accepted_ids
@@ -887,7 +896,6 @@ class FlashCausalLM(Model):
                        start_index + 1 : start_index + out_length
                    ]
 
-            # logger.info(f"Request ids {request_ids} {len(self.requests)}")
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1
@@ -900,19 +908,20 @@ class FlashCausalLM(Model):
        # Set values in batch
        # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
-        accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
+        # accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
-        logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}")
+        # logger.info(f"ACCEPTED IDS {accepted_ids} ")
 
        if accepted_ids.shape != batch.slot_indices:
            # This can happen after a concatenation
            # The slot indices is already modified for some speculative_ids
            B = batch.slot_indices.shape[0] // accepted_ids.shape[0]
-            accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
-        batch.slot_indices += accepted_ids
-        batch.position_ids = next_position_ids + accepted_ids
-        batch.input_lengths_tensor += accepted_ids
+            step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
+        # logger.info(f"ACCEPTED IDS AFTER {accepted_ids} ")
+        batch.slot_indices += step_accepted_ids
+        batch.position_ids = next_position_ids + step_accepted_ids
+        batch.input_lengths_tensor += step_accepted_ids
 
        if prefill and prefill_logprobs:
            # Get prefill logprobs
@@ -964,6 +973,8 @@ class FlashCausalLM(Model):
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
+
+            logger.info(f"Next token ids {next_token_ids} -> {index}")
            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
 
@@ -988,6 +999,8 @@ class FlashCausalLM(Model):
            else:
                stopped = False
 
+            index += n_accepted_ids
+
            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 2f4b4f7a..0ebaaa3f 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -228,7 +228,7 @@ class HeterogeneousNextTokenChooser:
        next_ids = self.choice(scores)
        from loguru import logger
        if speculated_ids is not None:
-            logger.info(f"CHOOSER {next_ids} {speculated_ids}")
+            # logger.info(f"CHOOSER {next_ids} {speculated_ids}")
            accepted_ids = []
            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
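
A few notes follow; they are illustrative sketches with invented values, not part of the patch.

The update `batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]` assumes a packed layout: each request's accepted tokens sit back to back in `next_input_ids`, which the later `index += n_accepted_ids` bookkeeping also relies on. A minimal sketch of why `cumsum - 1` then picks out each request's last accepted token:

    import torch

    # Packed layout: request 0 had 3 tokens accepted, request 1 had 1,
    # request 2 had 2, so the flat tensor holds 3 + 1 + 2 = 6 tokens.
    next_input_ids = torch.tensor([11, 12, 13, 21, 31, 32])
    accepted_ids = torch.tensor([3, 1, 2])

    # cumsum(-1) - 1 is the flat index of each request's last accepted
    # token, i.e. the token that seeds the next forward pass.
    last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
    print(last_accepted)  # tensor([13, 21, 32])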
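
`concatenate` now tracks two sizes: `total_batch_size` (one entry per request, sizing `all_input_ids_tensor`) and `total_cu_size` (one entry per flattened candidate position, sizing `input_ids`, `position_ids`, `slot_indices`, `input_lengths_tensor`, `block_tables_tensor` and `top_n_tokens_tensor`). A toy illustration, assuming every request carries the same number of speculative ids:

    # Invented numbers: 2 requests, 4 speculated tokens each.
    num_requests = 2
    speculative_length = 4
    S = speculative_length + 1  # speculated tokens plus the verified one

    total_batch_size = num_requests    # sizes per-request tensors
    total_cu_size = num_requests * S   # sizes flattened per-position tensors
    assert total_cu_size == 10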
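
Likewise, after a concatenation `slot_indices` holds B flattened entries per request while `accepted_ids` holds one entry per request, so the `view(-1, 1).expand(-1, B).reshape(-1)` step repeats each request's accepted count before the elementwise additions. Sketch:

    import torch

    accepted_ids = torch.tensor([2, 1])  # accepted count per request
    B = 3                                # flattened positions per request

    # Repeat each count B times: [2, 1] -> [2, 2, 2, 1, 1, 1], so
    # slot_indices, position_ids and input_lengths_tensor can all be
    # advanced by the same per-request step.
    step_accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
    print(step_accepted_ids)  # tensor([2, 2, 2, 1, 1, 1])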