mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Merge pull request #4 from rsnm2/stopping-next-token-chooser
Stopping next token chooser
This commit is contained in:
commit
b55270320c
@ -1,68 +0,0 @@
|
||||
import fastapi, uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
from router import DeepSparseRouter, batching_task
|
||||
from utils import GenerateRequest
|
||||
|
||||
TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
|
||||
MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
|
||||
|
||||
def serve(
|
||||
model_path=MODEL_PATH,
|
||||
tokenizer_path=TOKENIZER_PATH,
|
||||
host="0.0.0.0",
|
||||
port=5543
|
||||
):
|
||||
|
||||
router = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: fastapi.FastAPI):
|
||||
print("\n-------------------- Building Router --------------------\n")
|
||||
router = DeepSparseRouter(
|
||||
model_path=model_path,
|
||||
tokenizer_path=tokenizer_path
|
||||
)
|
||||
|
||||
print("\n-------------------- Starting Batching Task --------------------\n")
|
||||
batching_thread = Thread(target=batching_task, args=[router])
|
||||
batching_thread.start()
|
||||
|
||||
print("\n-------------------- Launching App --------------------\n")
|
||||
yield
|
||||
|
||||
print("\n-------------------- Joining Batching Task --------------------\n")
|
||||
router.stop_batching_task()
|
||||
batching_task.join()
|
||||
|
||||
app = fastapi.FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.get("/generate/{prompt}")
|
||||
async def generate(prompt:str):
|
||||
response_stream = Queue()
|
||||
router.submit_request(
|
||||
GenerateRequest(
|
||||
prompt=prompt,
|
||||
max_generated_tokens=100,
|
||||
response_stream=response_stream
|
||||
)
|
||||
)
|
||||
|
||||
response_string = prompt
|
||||
generation = response_stream.get()
|
||||
while not generation.stopped:
|
||||
response_string += generation.token
|
||||
generation = response_stream.get()
|
||||
|
||||
return response_string
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
workers=1
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
@ -2,7 +2,7 @@ from queue import Queue
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from service.service import DeepSparseService
|
||||
from service.causal_lm import DeepSparseCausalLM
|
||||
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
|
||||
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters
|
||||
|
||||
class DeepSparseRouter:
|
||||
def __init__(
|
||||
@ -11,10 +11,7 @@ class DeepSparseRouter:
|
||||
model_path: Optional[str] = None,
|
||||
tokenizer_path: Optional[str] = None
|
||||
):
|
||||
assert (
|
||||
service is not None or
|
||||
(model_path is not None and tokenizer_path is not None)
|
||||
)
|
||||
assert service is not None or (model_path is not None and tokenizer_path is not None)
|
||||
|
||||
if service is not None:
|
||||
self.service = service
|
||||
@ -38,14 +35,14 @@ class DeepSparseRouter:
|
||||
|
||||
# unblock the batching task with a dummy request if blocked
|
||||
self.queue.append(GenerateRequest(
|
||||
prompt="dummy",
|
||||
max_generated_tokens=1,
|
||||
inputs="stop",
|
||||
generation_parameters=GenerationParameters(max_new_tokens=1),
|
||||
response_stream=Queue()
|
||||
))
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
batch: Batch,
|
||||
self,
|
||||
batch: Batch,
|
||||
generate_requests: Dict[int, GenerateRequest]
|
||||
) -> Optional[CachedBatch]:
|
||||
|
||||
@ -166,17 +163,18 @@ class DeepSparseQueue:
|
||||
|
||||
# if block = True, this blocks until something ready
|
||||
# if block = False, the queue has data (if not an exception is raised)
|
||||
# while queue.empty() == False does not guarentee data
|
||||
# the queue is only subscribed to by one thread (this one)
|
||||
# since batching_task is the only function that calls next_batch
|
||||
# while queue.empty() == False typically not guarentee data on next queue.get(), this
|
||||
# queue is only subscribed to by one thread (this one) since batching_task is the only
|
||||
# so it does in our case
|
||||
|
||||
generate_request = self.queue.get(block=block)
|
||||
generate_requests = {self.next_request_id: generate_request}
|
||||
|
||||
# format into request
|
||||
request = Request(
|
||||
id=self.next_request_id,
|
||||
prompt=generate_request.prompt,
|
||||
max_generated_tokens=generate_request.max_generated_tokens
|
||||
inputs=generate_request.inputs,
|
||||
generation_parameters=generate_request.generation_parameters,
|
||||
)
|
||||
self.next_request_id += 1
|
||||
|
||||
|
@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
import numpy as np
|
||||
|
||||
from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
|
||||
from utils import Request, Batch, CachedBatch, Generation
|
||||
from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser
|
||||
|
||||
DEEPSPARSE_SEQUENCE_LENGTH = 128
|
||||
DEEPSPARSE_MULTITOKEN_LENGTH = 4
|
||||
@ -16,6 +16,8 @@ class DeepSparseCausalLMBatch:
|
||||
requests_idx_mapping: Dict[int,int]
|
||||
input_ids_list: List[np.ndarray]
|
||||
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
|
||||
stopping_criteria_list: List[StoppingCriteria]
|
||||
next_token_chooser_list: List[NextTokenChooser]
|
||||
|
||||
@classmethod
|
||||
def from_batch(
|
||||
@ -27,34 +29,54 @@ class DeepSparseCausalLMBatch:
|
||||
# parse batch
|
||||
requests_idx_mapping = {}
|
||||
input_ids_list = []
|
||||
|
||||
# setup tokenizer for deepsparse left padding
|
||||
tokenizer.padding_side = "left"
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
padding, truncation = "longest", False
|
||||
stopping_criteria_list = []
|
||||
next_token_chooser_list = []
|
||||
|
||||
# loop through items in the batch
|
||||
for idx, r in enumerate(batch.requests):
|
||||
requests_idx_mapping[r.id] = idx
|
||||
|
||||
# setup inputs_ids, past_key_values
|
||||
# setup inputs_ids, stopping crtieria
|
||||
tokenized_inputs = tokenizer(
|
||||
r.prompt,
|
||||
r.inputs,
|
||||
return_tensors="np",
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
return_token_type_ids=False,
|
||||
max_length=DEEPSPARSE_SEQUENCE_LENGTH
|
||||
)
|
||||
input_ids_list.append(tokenized_inputs["input_ids"])
|
||||
|
||||
# deepsparse able to accept up to seq len tokens
|
||||
num_input_tokens = tokenized_inputs["input_ids"].shape[1]
|
||||
model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
|
||||
stopping_criteria_list.append(
|
||||
StoppingCriteria(
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
|
||||
)
|
||||
)
|
||||
|
||||
# get next token chooser based on input
|
||||
next_token_chooser_list.append(
|
||||
NextTokenChooser(
|
||||
repetition_penalty=r.generation_parameters.repetition_penalty,
|
||||
temperature=r.generation_parameters.temperature,
|
||||
top_k=r.generation_parameters.top_k,
|
||||
top_p=r.generation_parameters.top_p,
|
||||
do_sample=r.generation_parameters.do_sample,
|
||||
seed=r.generation_parameters.seed,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
batch_id=batch.id,
|
||||
requests=batch.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids_list=input_ids_list,
|
||||
past_key_values_list=[None] * len(batch.requests),
|
||||
stopping_criteria_list=stopping_criteria_list,
|
||||
next_token_chooser_list=next_token_chooser_list,
|
||||
)
|
||||
|
||||
def to_cached_batch(self) -> CachedBatch:
|
||||
@ -71,10 +93,12 @@ class DeepSparseCausalLMBatch:
|
||||
def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
|
||||
assert(len(request_ids) > 0)
|
||||
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_ids_list = []
|
||||
past_key_values_list = []
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_ids_list = []
|
||||
past_key_values_list = []
|
||||
stopping_criteria_list = []
|
||||
next_token_chooser_list = []
|
||||
|
||||
# loop through requests, keep ones that should remain
|
||||
for new_idx, request_id in enumerate(request_ids):
|
||||
@ -86,12 +110,16 @@ class DeepSparseCausalLMBatch:
|
||||
requests.append(self.requests[old_idx])
|
||||
input_ids_list.append(self.input_ids_list[old_idx])
|
||||
past_key_values_list.append(self.past_key_values_list[old_idx])
|
||||
stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
|
||||
next_token_chooser_list.append(self.next_token_chooser_list[old_idx])
|
||||
|
||||
# update batch state
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids_list = input_ids_list
|
||||
self.past_key_values_list = past_key_values_list
|
||||
self.stopping_criteria_list = stopping_criteria_list
|
||||
self.next_token_chooser_list = next_token_chooser_list
|
||||
|
||||
return self
|
||||
|
||||
@ -100,10 +128,12 @@ class DeepSparseCausalLMBatch:
|
||||
def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
|
||||
assert len(batches) > 1, "must have more than 1 batch to concatenate"
|
||||
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_ids_list = []
|
||||
past_key_values_list = []
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_ids_list = []
|
||||
past_key_values_list = []
|
||||
stopping_criteria_list = []
|
||||
next_token_chooser_list = []
|
||||
|
||||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
@ -113,6 +143,8 @@ class DeepSparseCausalLMBatch:
|
||||
requests.extend(batch.requests)
|
||||
input_ids_list.extend(batch.input_ids_list)
|
||||
past_key_values_list.extend(batch.past_key_values_list)
|
||||
stopping_criteria_list.extend(batch.stopping_criteria_list)
|
||||
next_token_chooser_list.extend(batch.next_token_chooser_list)
|
||||
|
||||
# merge the request_id to index mapping
|
||||
if i == 0:
|
||||
@ -128,14 +160,16 @@ class DeepSparseCausalLMBatch:
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids_list=input_ids_list,
|
||||
past_key_values_list=past_key_values_list
|
||||
past_key_values_list=past_key_values_list,
|
||||
stopping_criteria_list=stopping_criteria_list,
|
||||
next_token_chooser_list=next_token_chooser_list
|
||||
)
|
||||
|
||||
class DeepSparseCausalLM:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
tokenizer_path: str,
|
||||
self,
|
||||
model_path: str,
|
||||
tokenizer_path: str,
|
||||
):
|
||||
# setup tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
@ -151,30 +185,6 @@ class DeepSparseCausalLM:
|
||||
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
|
||||
)
|
||||
|
||||
# TODO: switch to NextTokenChooser
|
||||
def sample_token(
|
||||
self,
|
||||
logits: np.ndarray
|
||||
):
|
||||
# assert b=1 for now
|
||||
assert(logits.shape[0] == 1)
|
||||
|
||||
# grab logits for the last item in the sequence
|
||||
# shape == (batch, seq, vocabulary_size)
|
||||
return np.argmax(logits[0,-1,:])
|
||||
|
||||
# TODO: switch to StoppingCriteria
|
||||
def should_stop(
|
||||
self,
|
||||
num_tokens_processed: int,
|
||||
generated_token_id: int
|
||||
):
|
||||
if num_tokens_processed >= self.model.sequence_length:
|
||||
return True
|
||||
if generated_token_id == self.tokenizer.eos_token_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def generate_token(
|
||||
self,
|
||||
batch: DeepSparseCausalLMBatch,
|
||||
@ -188,13 +198,22 @@ class DeepSparseCausalLM:
|
||||
# b) sample and check stopping criteria
|
||||
# c) create generation
|
||||
# d) update batch
|
||||
for i, (request, input_ids, past_key_values,) in enumerate(
|
||||
zip(
|
||||
batch.requests,
|
||||
batch.input_ids_list,
|
||||
batch.past_key_values_list
|
||||
)
|
||||
):
|
||||
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.input_ids_list,
|
||||
batch.past_key_values_list,
|
||||
batch.stopping_criteria_list,
|
||||
batch.next_token_chooser_list,
|
||||
)
|
||||
for i, (
|
||||
request,
|
||||
input_ids,
|
||||
past_key_values,
|
||||
stopping_criteria,
|
||||
next_token_chooser
|
||||
) in enumerate(iterator):
|
||||
# assert input_ids is b=1
|
||||
assert len(input_ids.shape) == 2
|
||||
assert input_ids.shape[0] == 1
|
||||
|
||||
@ -203,12 +222,10 @@ class DeepSparseCausalLM:
|
||||
|
||||
# b) sample token and check stopping criteria
|
||||
# TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
|
||||
generated_token_id = self.sample_token(logits)
|
||||
generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
|
||||
generated_token = self.tokenizer.decode(generated_token_id)
|
||||
stop = self.should_stop(
|
||||
num_tokens_processed=input_ids.shape[1] + 1,
|
||||
generated_token_id = generated_token_id
|
||||
)
|
||||
|
||||
stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
|
||||
if not stop:
|
||||
all_stopped = False
|
||||
|
||||
@ -217,11 +234,12 @@ class DeepSparseCausalLM:
|
||||
request_id=request.id,
|
||||
token=generated_token,
|
||||
token_id=generated_token_id,
|
||||
stopped=stop
|
||||
stopped=stop,
|
||||
finish_reason=finish_reason
|
||||
))
|
||||
|
||||
# d) update batch
|
||||
# TODO: this does not occur in place)
|
||||
# TODO: this does not occur in place
|
||||
assert len(batch.input_ids_list[i].shape) == 2
|
||||
assert batch.input_ids_list[i].shape[0] == 1
|
||||
batch.input_ids_list[i] = np.append(
|
||||
|
@ -56,7 +56,7 @@ class DeepSparseService:
|
||||
assert len(generations) == 1
|
||||
self.cache.set(next_ds_batch)
|
||||
|
||||
return generations[0], next_ds_batch.to_cached_batch()
|
||||
return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)
|
||||
|
||||
def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
|
||||
assert len(batches) != 0, "Must provide at least one batch"
|
||||
|
@ -1,12 +1,120 @@
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from pydantic import BaseModel, Field
|
||||
from queue import Queue
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
|
||||
|
||||
# TODO: sample for b > 1 with vectorized code
|
||||
class Greedy:
|
||||
def __call__(self, logits: np.ndarray):
|
||||
# assert b=1 for now
|
||||
# shape == (batch, vocabulary_size)
|
||||
assert(logits.shape[0] == 1)
|
||||
assert(len(logits.shape) == 2)
|
||||
|
||||
return np.argmax(logits[0,:])
|
||||
|
||||
# TODO: sample for b > 1 with vectorized code
|
||||
# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
|
||||
class Sampling:
|
||||
def __init__(self, seed:int=42):
|
||||
self.generator = np.random.default_rng(seed=seed)
|
||||
|
||||
def __call__(self, logits:np.ndarray):
|
||||
# assert b=1 for now
|
||||
# shape == (batch, vocabulary_size)
|
||||
assert(logits.shape[0] == 1)
|
||||
assert(len(logits.shape) == 2)
|
||||
|
||||
probs = softmax(logits, axis=1)
|
||||
return self.generator.choice(probs.shape[1], p=probs[0,:])
|
||||
|
||||
class NextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
repetition_penalty: Optional[float] = 1.0,
|
||||
temperature: Optional[float] = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
do_sample:bool = False,
|
||||
seed: int = 42,
|
||||
):
|
||||
self.repetition_processor = (
|
||||
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
||||
if repetition_penalty and repetition_penalty != 1.0 else None
|
||||
)
|
||||
|
||||
has_warpers = (
|
||||
(temperature is not None and temperature != 1.0)
|
||||
or (top_k is not None and top_k != 0)
|
||||
or (top_p is not None and top_p < 1.0)
|
||||
)
|
||||
|
||||
if has_warpers:
|
||||
self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
|
||||
else:
|
||||
self.warpers = None
|
||||
|
||||
self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
|
||||
|
||||
def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int:
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids=input_ids, scores=scores)
|
||||
|
||||
if self.warpers is not None:
|
||||
scores = self.warpers(scores=scores)
|
||||
|
||||
return self.choice(scores)
|
||||
|
||||
class FinishReason(Enum):
|
||||
FINISH_REASON_LENGTH = 1
|
||||
FINISH_REASON_EOS_TOKEN = 2
|
||||
|
||||
class StoppingCriteria:
|
||||
def __init__(
|
||||
self,
|
||||
eos_token_id: int,
|
||||
max_new_tokens: int,
|
||||
):
|
||||
assert max_new_tokens > 0
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.eos_token_id = eos_token_id
|
||||
self.current_tokens = 0
|
||||
|
||||
def __call__(self, generated_token_id:int):
|
||||
self.current_tokens += 1
|
||||
if self.current_tokens >= self.max_new_tokens:
|
||||
return True, FinishReason.FINISH_REASON_LENGTH
|
||||
|
||||
if generated_token_id == self.eos_token_id:
|
||||
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||
|
||||
return False, None
|
||||
|
||||
class GenerationParameters(BaseModel):
|
||||
max_new_tokens: int = Field(default=20)
|
||||
repetition_penalty: float = Field(default=1)
|
||||
do_sample: bool = Field(default=False)
|
||||
temperature: float = Field(default=1.0)
|
||||
top_k: Optional[int] = Field(default=None)
|
||||
top_p: Optional[float] = Field(default=None)
|
||||
seed: int = Field(default=42)
|
||||
|
||||
class GenerateRequestInputs(BaseModel):
|
||||
inputs: str
|
||||
generation_parameters: GenerationParameters
|
||||
|
||||
class GenerateRequestOutputs(BaseModel):
|
||||
response_text: str = Field(default="")
|
||||
finish_reason: Optional[FinishReason] = Field(default=None)
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
id: int
|
||||
prompt: str
|
||||
max_generated_tokens: int
|
||||
inputs: str
|
||||
generation_parameters: GenerationParameters
|
||||
|
||||
@dataclass
|
||||
class Batch:
|
||||
@ -27,9 +135,18 @@ class Generation:
|
||||
token: Optional[str]
|
||||
token_id: Optional[str]
|
||||
stopped: bool
|
||||
finish_reason: FinishReason = None
|
||||
|
||||
@dataclass
|
||||
class GenerateRequest:
|
||||
prompt: str
|
||||
max_generated_tokens: int
|
||||
response_stream: Queue[Generation]
|
||||
inputs: str
|
||||
generation_parameters: GenerationParameters
|
||||
response_stream: Queue[Generation]
|
||||
|
||||
@classmethod
|
||||
def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
|
||||
return cls(
|
||||
inputs=gr_inputs.inputs,
|
||||
generation_parameters=gr_inputs.generation_parameters,
|
||||
response_stream=Queue()
|
||||
)
|
1008
server-dev.ipynb
1008
server-dev.ipynb
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user