Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00

Removing dead code.

This commit is contained in:
parent ba16994e8a
commit e95a5a897b
@@ -11,8 +11,6 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from loguru import logger

from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
@@ -320,7 +318,6 @@ class FlashCausalLMBatch(Batch):

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        # logger.info(f"Filter {request_ids}")
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
@@ -471,7 +468,6 @@ class FlashCausalLMBatch(Batch):
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
        # Batch attributes
        requests = []
        requests_idx_mapping = {}
@@ -501,7 +497,6 @@ class FlashCausalLMBatch(Batch):
                )
            ),
        )
        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -788,8 +783,6 @@ class FlashCausalLM(Model):
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} {batch.input_lengths_tensor}")
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

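Aside, not part of the diff: cu_seqlen_prefill is the flash-attention style cumulative-sequence-length tensor for the packed batch; it is only populated while a batch is being prefilled, which is why its presence doubles as the prefill flag above. A minimal sketch with assumed prompt lengths:

import torch

# Three prompts of length 4, 2 and 5 packed into one ragged batch (lengths are assumptions).
input_lengths = torch.tensor([4, 2, 5], dtype=torch.int32)
cu_seqlen_prefill = torch.nn.functional.pad(input_lengths.cumsum(0), (1, 0)).to(torch.int32)
# tensor([ 0,  4,  6, 11], dtype=torch.int32): row i spans tokens
# cu_seqlen_prefill[i] .. cu_seqlen_prefill[i + 1] of the flattened input.
prefill = cu_seqlen_prefill is not None  # decode steps carry None here, so this reads as the prefill test
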
@@ -806,8 +799,6 @@ class FlashCausalLM(Model):
        batch.block_tables_tensor = block_tables_tensor
        batch.slots = slots

        # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")

        try:
            out = self.forward(batch)
        except Exception as e:

@@ -16,8 +16,6 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

from loguru import logger

class NextTokenChooser:
    def __init__(
        self,
@@ -148,18 +146,17 @@ class StoppingCriteria:
        )

def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
    # import datetime
    # start = datetime.datetime.now()
    # Very trivial approach, find first match in the string.
    # This is much less refined than actual n-gram but seems to work
    # relatively OK in grounded mode and is by far much faster with
    # much less worst case complexity as everything happens on device.
    B = accepted_ids.shape[0]
    device = input_ids.device
    dtype = input_ids.dtype
    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)

    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
    return speculative_ids
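Aside, not part of the diff: a minimal sketch of what the lookup above does on a toy input. The tensor values below are illustrative assumptions, not taken from the repository.

import torch

# One sequence (B=1) whose history contains the pattern 5, 6, 7 twice.
input_ids = torch.tensor([[5, 6, 7, 5, 6, 7, 5]])
next_ids = torch.tensor([5])        # token just sampled for this sequence
accepted_ids = torch.tensor([1])    # one accepted token, so the seed is next_ids[0]
speculate = 2

B = accepted_ids.shape[0]
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]                    # tensor([5])
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1  # first earlier match of the seed, plus one -> tensor([1])
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)   # tensor([[1, 2]])
speculative_ids = input_ids.gather(dim=-1, index=all_indices)        # tensor([[6, 7]])
# The draft simply replays the tokens that followed the seed earlier in the sequence.
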
@@ -232,10 +229,6 @@ class HeterogeneousNextTokenChooser:
        self.device = device

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
        import datetime
        # from loguru import logger

        start = datetime.datetime.now()
        if speculated_ids is not None:
            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
@@ -245,10 +238,6 @@ class HeterogeneousNextTokenChooser:
            S = 1
        scores = scores.view(B, S, -1)

        # if verbose:
        # logger.info(f"Reshape {datetime.datetime.now() - start}")

        all_next_ids = []
        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
        for j in range(S):
            _scores = scores[:, j]
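Aside, not part of the diff: the reshape above views the flattened logits as (B, S, vocab) so each speculative position can be sampled independently in the loop. A small sketch of the shape arithmetic with made-up sizes (the numbers are assumptions for illustration):

import torch

vocab_size = 32000
speculated_ids = torch.zeros((2, 3), dtype=torch.long)  # 2 sequences, 3 draft tokens each (hypothetical)
scores = torch.randn(2 * (3 + 1), vocab_size)           # one row of logits per (sequence, position)

S = speculated_ids.shape[1] + 1   # 4 positions: 3 drafts being verified + the next real token
B = scores.shape[0] // S          # 2 sequences
scores = scores.view(B, S, -1)    # scores[:, j] is then the batch of logits for position j
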
@@ -266,8 +255,6 @@ class HeterogeneousNextTokenChooser:
            next_ids[:, j] = _next_ids
        next_ids = next_ids.view(B * S)
        scores = scores.view(B * S, -1)
        # if verbose:
        # logger.info(f"Scores {datetime.datetime.now() - start}")

        if speculated_ids is not None:
            accepted_ids = []
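Aside, not part of the diff: the acceptance loop itself falls outside this hunk. As a rough, assumed illustration of the usual prefix-match rule (not the elided code), a draft token counts as accepted only while every earlier draft in the same row also matched the token the model actually sampled:

import torch

def count_accepted(speculated_row: torch.Tensor, sampled_row: torch.Tensor) -> int:
    # speculated_row: draft tokens that were fed in; sampled_row: the model's choices at the same positions.
    accepted = 1  # the freshly sampled token at the first position is always kept
    for draft, sampled in zip(speculated_row.tolist(), sampled_row.tolist()):
        if draft == sampled:
            accepted += 1
        else:
            break
    return accepted

# Example: drafts [7, 8, 9] vs sampled [7, 8, 2] -> 3 tokens kept (2 matching drafts + 1 new token).
print(count_accepted(torch.tensor([7, 8, 9]), torch.tensor([7, 8, 2])))  # 3
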
@@ -299,8 +286,6 @@ class HeterogeneousNextTokenChooser:
            speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
        # if verbose:
        # logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")

        logprobs = torch.log_softmax(scores, -1)
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
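Aside, not part of the diff: the gather above picks, for every flattened row, the log-probability of the token chosen for that row. A tiny sketch with assumed shapes:

import torch

scores = torch.randn(4, 10)                 # (B * S, vocab) logits; sizes assumed for illustration
next_ids = torch.tensor([3, 1, 7, 0])       # chosen token id per row
logprobs = torch.log_softmax(scores, -1)    # normalized over the vocab dimension
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
# next_logprobs[i] == logprobs[i, next_ids[i]]
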
@@ -314,8 +299,6 @@ class HeterogeneousNextTokenChooser:
            speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
        else:
            speculative_ids = None
        # if verbose:
        # logger.info(f"new speculative ids {datetime.datetime.now() - start}")

        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids