diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b5ac1a16..ed6d9198 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -11,8 +11,6 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
-from loguru import logger
-
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -320,7 +318,6 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -471,7 +468,6 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
-        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -501,7 +497,6 @@ class FlashCausalLMBatch(Batch):
                     )
                 ),
             )
-        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")
 
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -788,8 +783,6 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
-        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
@@ -806,8 +799,6 @@ class FlashCausalLM(Model):
             batch.block_tables_tensor = block_tables_tensor
             batch.slots = slots
 
-            # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
-
         try:
             out = self.forward(batch)
         except Exception as e:
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index e3457483..a34c5afc 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -16,8 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
-from loguru import logger
-
 
 class NextTokenChooser:
     def __init__(
@@ -148,18 +146,17 @@ class StoppingCriteria:
         )
 
 def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
-    # import datetime
-    # start = datetime.datetime.now()
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
     B = accepted_ids.shape[0]
     device = input_ids.device
-    dtype = input_ids.dtype
-    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
 
     seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ]
     indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
     all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
     all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
-    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
 
     speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
@@ -232,10 +229,6 @@ class HeterogeneousNextTokenChooser:
         self.device = device
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
-        import datetime
-        # from loguru import logger
-
-        start = datetime.datetime.now()
         if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -245,10 +238,6 @@ class HeterogeneousNextTokenChooser:
             S = 1
         scores = scores.view(B, S, -1)
 
-        # if verbose:
-        #     logger.info(f"Reshape {datetime.datetime.now() - start}")
-
-        all_next_ids = []
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
@@ -266,8 +255,6 @@ class HeterogeneousNextTokenChooser:
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B*S)
         scores = scores.view( B* S, -1)
-        # if verbose:
-        #     logger.info(f"Scores {datetime.datetime.now() - start}")
 
         if speculated_ids is not None:
             accepted_ids = []
@@ -299,8 +286,6 @@ class HeterogeneousNextTokenChooser:
             speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
-        # if verbose:
-        #     logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
@@ -314,8 +299,6 @@ class HeterogeneousNextTokenChooser:
                 speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
             else:
                 speculative_ids = None
-        # if verbose:
-        #     logger.info(f"new speculative ids {datetime.datetime.now() - start}")
 
         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
 
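For reviewers who want to trace the tensor shapes in create_n_gram_speculation above, here is a minimal standalone sketch of the same device-side lookup. The function name ngram_speculate and the toy inputs are illustrative only and are not part of this patch; the body mirrors the helper as it reads after this change.

import torch


def ngram_speculate(
    input_ids: torch.Tensor,     # (B, T) token history of each sequence
    next_ids: torch.Tensor,      # (B * S,) flattened tokens sampled this step
    accepted_ids: torch.Tensor,  # (B,) number of accepted tokens per sequence
    speculate: int,              # number of speculative tokens to propose
) -> torch.Tensor:
    B = accepted_ids.shape[0]
    device = input_ids.device

    # The last accepted token of each sequence is the "seed" to look up in the history.
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    # First position where the seed occurs, then start copying right after it.
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    # Take `speculate` consecutive positions from there, clamped to the sequence end.
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    # The gather stays on device: no host round-trip, no Python loop over sequences.
    return input_ids.gather(dim=-1, index=all_indices)


if __name__ == "__main__":
    history = torch.tensor([[5, 6, 7, 5, 6, 7, 5]])  # repeating pattern
    sampled = torch.tensor([6])                      # one sequence, one sampled token
    accepted = torch.tensor([1])                     # that token was accepted
    print(ngram_speculate(history, sampled, accepted, speculate=3))
    # tensor([[7, 5, 6]]): the tokens that followed the first occurrence of 6

As the added comment in the patch notes, this is grounded speculation rather than a real n-gram model: it simply proposes whatever followed the first occurrence of the seed token, which is cheap because every step is a batched tensor op on the accelerator.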