Removing dead code.

Nicolas Patry 2023-12-08 17:33:30 +00:00
parent ba16994e8a
commit e95a5a897b
2 changed files with 4 additions and 30 deletions

View File

@@ -11,8 +11,6 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
-from loguru import logger
-
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -320,7 +318,6 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -471,7 +468,6 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
-        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -501,7 +497,6 @@ class FlashCausalLMBatch(Batch):
                 )
             ),
         )
-        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")

         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -788,8 +783,6 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
-        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -806,8 +799,6 @@ class FlashCausalLM(Model):
             batch.block_tables_tensor = block_tables_tensor
             batch.slots = slots
-            # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
-
         try:
             out = self.forward(batch)
         except Exception as e:

View File

@@ -16,8 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
-from loguru import logger
-

 class NextTokenChooser:
     def __init__(
         self,
@@ -148,18 +146,17 @@
         )


 def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
-    # import datetime
-    # start = datetime.datetime.now()
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
     B = accepted_ids.shape[0]
     device = input_ids.device
-    dtype = input_ids.dtype
-    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
     seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
     indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
     all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
     all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
-    # logger.info(f"All indices {all_indices} - {input_ids.shape}")

     speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
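
The comment block added in this hunk describes the whole trick: the prompt itself serves as the draft model, by looking up the first earlier occurrence of the latest token and proposing whatever followed it. A minimal standalone sketch of that indexing (toy tensors; the token values and shapes are invented for illustration, not the server's real batch layout):

    import torch

    # Toy batch: B=1 sequence, prompt length 8. Token 9 appears at
    # positions 1 and 4, so the first-match lookup has something to find.
    input_ids = torch.tensor([[5, 9, 7, 3, 9, 4, 8, 2]])
    next_ids = torch.tensor([9])      # freshly sampled token per sequence
    accepted_ids = torch.tensor([1])  # accepted-token count per sequence
    speculate = 3

    B = accepted_ids.shape[0]
    device = input_ids.device

    # Seed on the last accepted token of each sequence.
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]  # tensor([9])
    # Index of the first position matching the seed, plus one
    # (max over a bool tensor reports the first True).
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1  # tensor([2])
    # Read a window of `speculate` tokens starting there, clamped to the prompt.
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    print(input_ids.gather(dim=-1, index=all_indices))  # tensor([[7, 3, 9]])

When the model is copying from its context (the "grounded mode" the new comment mentions), these guesses are often right and verification accepts them all; when they are wrong, only the batched forward pass is wasted.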
@@ -232,10 +229,6 @@ class HeterogeneousNextTokenChooser:
         self.device = device

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
-        import datetime
-        # from loguru import logger
-        start = datetime.datetime.now()
-
         if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -245,10 +238,6 @@ class HeterogeneousNextTokenChooser:
             S = 1
         scores = scores.view(B, S, -1)

-        # if verbose:
-        # logger.info(f"Reshape {datetime.datetime.now() - start}")
-        all_next_ids = []
-
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
@@ -266,8 +255,6 @@ class HeterogeneousNextTokenChooser:
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
         scores = scores.view(B * S, -1)
-        # if verbose:
-        # logger.info(f"Scores {datetime.datetime.now() - start}")

         if speculated_ids is not None:
             accepted_ids = []
@@ -299,8 +286,6 @@ class HeterogeneousNextTokenChooser:
             speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
-        # if verbose:
-        # logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")

         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
@@ -314,8 +299,6 @@ class HeterogeneousNextTokenChooser:
                 speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
         else:
             speculative_ids = None
-        # if verbose:
-        # logger.info(f"new speculative ids {datetime.datetime.now() - start}")

         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
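
For context, the `if speculated_ids is not None` branch kept above performs the verification half of speculative decoding: a draft token survives only while it matches what the model samples at the same position, plus one final token from the model itself. A hedged sketch of that acceptance rule (`count_accepted` is an illustrative helper, not the code elided from this hunk):

    import torch

    def count_accepted(speculated: torch.Tensor, sampled: torch.Tensor) -> torch.Tensor:
        # speculated: (B, S-1) draft tokens; sampled: (B, S) tokens the model
        # actually chose at each position. Returns, per row, the matching
        # prefix length plus one for the model's own token.
        matches = (speculated == sampled[:, :-1]).long()  # 1 where draft == model
        prefix = matches.cumprod(dim=1)                   # 1s until first mismatch
        return prefix.sum(dim=1) + 1

    speculated = torch.tensor([[4, 7, 9], [4, 7, 9]])
    sampled = torch.tensor([[4, 7, 2, 5], [4, 1, 9, 5]])
    print(count_accepted(speculated, sampled))  # tensor([3, 2])

The per-row count plays the role of `accepted_ids` here: as the `speculative_scores[indices + accepted_ids - 1]` line in the hunk above suggests, it decides how far each sequence advances and which scores row seeds the next speculation round.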