Removing dead code.

Nicolas Patry 2023-12-08 17:33:30 +00:00
parent ba16994e8a
commit e95a5a897b
2 changed files with 4 additions and 30 deletions

View File

@@ -11,8 +11,6 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from loguru import logger

from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
@@ -320,7 +318,6 @@ class FlashCausalLMBatch(Batch):
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        # logger.info(f"Filter {request_ids}")
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
@@ -471,7 +468,6 @@ class FlashCausalLMBatch(Batch):
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
        # Batch attributes
        requests = []
        requests_idx_mapping = {}
@@ -501,7 +497,6 @@ class FlashCausalLMBatch(Batch):
                )
            ),
        )
        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -788,8 +783,6 @@ class FlashCausalLM(Model):
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -806,8 +799,6 @@ class FlashCausalLM(Model):
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots
            # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")

        try:
            out = self.forward(batch)
        except Exception as e:

View File

@@ -16,8 +16,6 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from loguru import logger


class NextTokenChooser:
    def __init__(
        self,
@@ -148,18 +146,17 @@ class StoppingCriteria:
        )


def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
    # import datetime
    # start = datetime.datetime.now()
    # Very trivial approach: find the first match in the string.
    # This is much less refined than a true n-gram lookup, but it works
    # reasonably well in grounded mode and is far faster, with much better
    # worst-case complexity, since everything happens on device.
    B = accepted_ids.shape[0]
    device = input_ids.device
    dtype = input_ids.dtype
    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
    return speculative_ids
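
For reference, the indexing above can be traced on toy tensors. A minimal sketch (illustrative values, not part of this commit), assuming a single sequence (B=1) and speculate=2:

import torch

input_ids = torch.tensor([[5, 9, 7, 5, 9, 7, 5]])  # sequence so far (B=1, L=7)
next_ids = torch.tensor([9])                        # token just sampled
accepted_ids = torch.tensor([1])                    # accepted token count per sequence
speculate = 2

B = accepted_ids.shape[0]
# Seed on each sequence's last accepted token -> tensor([9])
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
# Position right after the first occurrence of the seed -> tensor([2])
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
# Read `speculate` consecutive positions from there, clamped to the sequence end
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)  # tensor([[7, 5]])

Having just produced a 9, the sketch speculates the continuation 7, 5 that followed the first earlier 9, which is the grounded-mode behavior the function's comment describes.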
@@ -232,10 +229,6 @@ class HeterogeneousNextTokenChooser:
        self.device = device

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
        import datetime
        # from loguru import logger
        start = datetime.datetime.now()
        if speculated_ids is not None:
            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
@@ -245,10 +238,6 @@ class HeterogeneousNextTokenChooser:
            S = 1
        scores = scores.view(B, S, -1)
        # if verbose:
        #     logger.info(f"Reshape {datetime.datetime.now() - start}")

        all_next_ids = []
        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
        for j in range(S):
            _scores = scores[:, j]
@@ -266,8 +255,6 @@ class HeterogeneousNextTokenChooser:
            next_ids[:, j] = _next_ids
        next_ids = next_ids.view(B * S)
        scores = scores.view(B * S, -1)
        # if verbose:
        #     logger.info(f"Scores {datetime.datetime.now() - start}")

        if speculated_ids is not None:
            accepted_ids = []
@@ -299,8 +286,6 @@ class HeterogeneousNextTokenChooser:
            speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
        # if verbose:
        #     logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")

        logprobs = torch.log_softmax(scores, -1)
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
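
The accepted_ids bookkeeping in this hunk comes from the verification step of speculative decoding. As a hedged toy illustration of the prefix-acceptance rule (not the exact loop, which is elided from this diff): draft tokens are kept only while they agree with what the full model samples, and the model's own next token is always kept.

import torch

speculated = torch.tensor([7, 5, 9])    # draft tokens for one sequence
validated = torch.tensor([7, 5, 2, 4])  # full model's sampled token at each slot
matches = (speculated == validated[:-1]).int()  # position-wise agreement -> [1, 1, 0]
# Length of the matching prefix, plus the one token the model itself produced
accepted = int(matches.cumprod(dim=0).sum()) + 1  # -> 3 (keeps 7, 5, then 2)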
@@ -314,8 +299,6 @@ class HeterogeneousNextTokenChooser:
            speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
        else:
            speculative_ids = None
        # if verbose:
        #     logger.info(f"new speculative ids {datetime.datetime.now() - start}")

        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
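
The torch.gather in the hunk above selects each row's chosen-token log-probability from the full distribution; a self-contained sketch with toy sizes:

import torch

scores = torch.randn(4, 32)             # (B * S, vocab_size), toy sizes
next_ids = torch.tensor([3, 31, 0, 7])  # chosen token per row
logprobs = torch.log_softmax(scores, -1)
# One log-probability per row, mirroring next_logprobs in the method above
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)  # shape (4,)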