Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Removing dead code.
parent ba16994e8a
commit e95a5a897b
@@ -11,8 +11,6 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
-from loguru import logger
-
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -320,7 +318,6 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -471,7 +468,6 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
-        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -501,7 +497,6 @@ class FlashCausalLMBatch(Batch):
                 )
             ),
         )
-        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")
 
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -788,8 +783,6 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
-        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
@@ -806,8 +799,6 @@ class FlashCausalLM(Model):
             batch.block_tables_tensor = block_tables_tensor
             batch.slots = slots
 
-            # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
-
         try:
             out = self.forward(batch)
         except Exception as e:
@@ -16,8 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
-from loguru import logger
-
 class NextTokenChooser:
     def __init__(
         self,
@@ -148,18 +146,17 @@ class StoppingCriteria:
         )
 
 def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
-    # import datetime
-    # start = datetime.datetime.now()
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
     B = accepted_ids.shape[0]
     device = input_ids.device
-    dtype = input_ids.dtype
-    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
     seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ]
     indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
     all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
     all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
 
-    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
     speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
 
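The new comments describe the idea behind create_n_gram_speculation: for each sequence, find the earlier occurrence of the last accepted token in the prompt and propose whatever followed it, entirely on device. A minimal runnable sketch of that gather trick on toy tensors (the standalone function name and the example values are made up for illustration; the tensor operations mirror the lines kept above):

# Toy illustration of the gather-based n-gram lookup (a sketch, not the library API).
import torch

def ngram_speculate(input_ids, next_ids, accepted_ids, speculate):
    B = accepted_ids.shape[0]
    device = input_ids.device
    # The last accepted token of each row is the "seed" to search for.
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    # Position right after the first match of the seed in each row.
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    # Take `speculate` consecutive positions, clamped to the sequence length.
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    return input_ids.gather(dim=-1, index=all_indices)

input_ids = torch.tensor([[5, 7, 9, 7, 2, 0]])  # one sequence
next_ids = torch.tensor([7])                    # model just produced token 7
accepted_ids = torch.tensor([1])                # one token accepted this step
print(ngram_speculate(input_ids, next_ids, accepted_ids, speculate=2))
# tensor([[9, 7]]) -> the tokens that followed the first earlier occurrence of 7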
@@ -232,10 +229,6 @@ class HeterogeneousNextTokenChooser:
         self.device = device
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
-        import datetime
-        # from loguru import logger
-
-        start = datetime.datetime.now()
        if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -245,10 +238,6 @@ class HeterogeneousNextTokenChooser:
             S = 1
         scores = scores.view(B, S, -1)
 
-        # if verbose:
-        # logger.info(f"Reshape {datetime.datetime.now() - start}")
-
-        all_next_ids = []
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
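For context on the reshape kept in this hunk: when speculated_ids is present, the model returned one row of logits per input position, so scores is viewed as (B, S, vocab) and the chooser runs once per speculative position. A hedged toy sketch of that shape juggling, using greedy argmax as a stand-in for the actual HeterogeneousNextTokenChooser pipeline (all values below are illustrative):

import torch

B, speculate, vocab = 2, 3, 11
speculated_ids = torch.zeros((B, speculate), dtype=torch.long)  # draft tokens
scores = torch.randn(B * (speculate + 1), vocab)                # flat logits

S = speculated_ids.shape[1] + 1
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), dtype=torch.long)
for j in range(S):
    _scores = scores[:, j]                    # logits for speculative position j
    next_ids[:, j] = _scores.argmax(dim=-1)   # greedy stand-in for the chooser
print(next_ids.shape)  # torch.Size([2, 4])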
@@ -266,8 +255,6 @@ class HeterogeneousNextTokenChooser:
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B*S)
         scores = scores.view( B* S, -1)
-        # if verbose:
-        # logger.info(f"Scores {datetime.datetime.now() - start}")
 
         if speculated_ids is not None:
             accepted_ids = []
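The if speculated_ids branch that starts here counts how many draft tokens were accepted; the counting loop itself falls outside this hunk. As a purely illustrative sketch of the general idea (not the elided code): speculative decoding keeps the leading run of draft tokens that match the verifier's choices, plus the one token the verifier produces itself.

import torch

speculated = torch.tensor([4, 9, 1])     # draft tokens for one sequence
verified = torch.tensor([4, 9, 7, 3])    # verifier's choice at each position
matches = (speculated == verified[:-1])  # position-by-position comparison
# Accepted count = length of the leading run of matches, plus the verifier's own token.
accepted = int(matches.long().cumprod(dim=0).sum()) + 1
print(accepted)  # 3 -> tokens 4, 9 and then the verifier's token 7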
@@ -299,8 +286,6 @@ class HeterogeneousNextTokenChooser:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
-        # if verbose:
-        # logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
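The two kept lines at the end of this hunk compute per-token log-probabilities by gathering the chosen ids out of the log-softmaxed scores. A tiny self-contained example of that gather pattern (example values are made up):

import torch

scores = torch.tensor([[2.0, 0.0, 1.0],
                       [0.0, 3.0, 0.0]])
next_ids = torch.tensor([0, 1])  # chosen token per row

logprobs = torch.log_softmax(scores, -1)
# Pick the log-probability of the chosen token out of each row.
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
print(next_logprobs)  # one log-probability per row, for tokens 0 and 1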
@@ -314,8 +299,6 @@ class HeterogeneousNextTokenChooser:
                 speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
         else:
             speculative_ids = None
-        # if verbose:
-        # logger.info(f"new speculative ids {datetime.datetime.now() - start}")
 
         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
 