Merge pull request #4 from rsnm2/stopping-next-token-chooser

Stopping next token chooser
Robert Shaw 2023-08-27 20:12:10 -06:00 committed by GitHub
commit b55270320c
6 changed files with 1115 additions and 256 deletions
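As context for the diffs below, this PR replaces the hard-coded greedy sampling and ad-hoc stopping checks with per-request NextTokenChooser and StoppingCriteria objects driven by a new GenerationParameters model. The snippet below is an illustrative sketch, not code from this PR, of how those pieces fit together; it assumes the server's utils module is importable and uses a stand-in vocabulary size and eos_token_id.

import numpy as np
from utils import GenerationParameters, NextTokenChooser, StoppingCriteria

# default parameters: repetition_penalty=1, temperature=1.0, do_sample=False -> greedy argmax
params = GenerationParameters(max_new_tokens=32)

chooser = NextTokenChooser(
    repetition_penalty=params.repetition_penalty,
    temperature=params.temperature,
    top_k=params.top_k,
    top_p=params.top_p,
    do_sample=params.do_sample,
    seed=params.seed,
)
criteria = StoppingCriteria(eos_token_id=0, max_new_tokens=params.max_new_tokens)  # stand-in eos id

# stand-in logits with shape (batch=1, vocab); the real scores come from the last decoder position
scores = np.random.rand(1, 1000)
token_id = chooser(input_ids=np.array([[1, 2, 3]]), scores=scores)
stop, finish_reason = criteria(generated_token_id=token_id)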

View File

@@ -1,68 +0,0 @@
import fastapi, uvicorn
from contextlib import asynccontextmanager
from threading import Thread
from queue import Queue

from router import DeepSparseRouter, batching_task
from utils import GenerateRequest

TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"

def serve(
    model_path=MODEL_PATH,
    tokenizer_path=TOKENIZER_PATH,
    host="0.0.0.0",
    port=5543
):
    router = None

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # rebind the enclosing `router` so the /generate handler below can see it
        nonlocal router

        print("\n-------------------- Building Router --------------------\n")
        router = DeepSparseRouter(
            model_path=model_path,
            tokenizer_path=tokenizer_path
        )

        print("\n-------------------- Starting Batching Task --------------------\n")
        batching_thread = Thread(target=batching_task, args=[router])
        batching_thread.start()

        print("\n-------------------- Launching App --------------------\n")
        yield

        print("\n-------------------- Joining Batching Task --------------------\n")
        router.stop_batching_task()
        batching_thread.join()

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.get("/generate/{prompt}")
    async def generate(prompt: str):
        # submit the request to the router, which batches it onto the engine
        response_stream = Queue()
        router.submit_request(
            GenerateRequest(
                prompt=prompt,
                max_generated_tokens=100,
                response_stream=response_stream
            )
        )

        # drain generations off the stream until the request is marked stopped
        response_string = prompt
        generation = response_stream.get()
        while not generation.stopped:
            response_string += generation.token
            generation = response_stream.get()

        return response_string

    uvicorn.run(
        app,
        host=host,
        port=port,
        workers=1
    )

if __name__ == "__main__":
    serve()

View File

@@ -2,7 +2,7 @@ from queue import Queue
 from typing import List, Dict, Optional, Tuple
 from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
-from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters

 class DeepSparseRouter:
     def __init__(
@@ -11,10 +11,7 @@ class DeepSparseRouter:
         model_path: Optional[str] = None,
         tokenizer_path: Optional[str] = None
     ):
-        assert (
-            service is not None or
-            (model_path is not None and tokenizer_path is not None)
-        )
+        assert service is not None or (model_path is not None and tokenizer_path is not None)

         if service is not None:
             self.service = service
@@ -38,8 +35,8 @@ class DeepSparseRouter:
         # unblock the batching task with a dummy request if blocked
         self.queue.append(GenerateRequest(
-            prompt="dummy",
-            max_generated_tokens=1,
+            inputs="stop",
+            generation_parameters=GenerationParameters(max_new_tokens=1),
             response_stream=Queue()
         ))
@@ -166,17 +163,18 @@ class DeepSparseQueue:
         # if block = True, this blocks until something ready
         # if block = False, the queue has data (if not an exception is raised)
-        # while queue.empty() == False does not guarentee data
-        # the queue is only subscribed to by one thread (this one)
-        # since batching_task is the only function that calls next_batch
+        # while queue.empty() == False typically does not guarantee data on the next queue.get(),
+        # this queue is only consumed by one thread (this one), since batching_task is the
+        # only function that calls next_batch, so in our case it does
         generate_request = self.queue.get(block=block)
         generate_requests = {self.next_request_id: generate_request}

         # format into request
         request = Request(
             id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
+            inputs=generate_request.inputs,
+            generation_parameters=generate_request.generation_parameters,
         )
         self.next_request_id += 1
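A usage note on the router change above (an illustrative sketch, not part of the diff): callers now describe a request with inputs plus a GenerationParameters object instead of a raw prompt and max_generated_tokens, for example:

from queue import Queue
from utils import GenerateRequest, GenerationParameters

# hypothetical caller-side construction mirroring the new GenerateRequest fields
request = GenerateRequest(
    inputs="def fibonacci(n):",
    generation_parameters=GenerationParameters(max_new_tokens=64),
    response_stream=Queue(),
)
# router.submit_request(request)   # assuming a DeepSparseRouter instance named `router`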

View File

@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from utils import Request, Batch, CachedBatch, Generation
+from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser

 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -16,6 +16,8 @@ class DeepSparseCausalLMBatch:
     requests_idx_mapping: Dict[int,int]
     input_ids_list: List[np.ndarray]
     past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
+    stopping_criteria_list: List[StoppingCriteria]
+    next_token_chooser_list: List[NextTokenChooser]

     @classmethod
     def from_batch(
@@ -27,34 +29,54 @@ class DeepSparseCausalLMBatch:
         # parse batch
         requests_idx_mapping = {}
         input_ids_list = []
+        stopping_criteria_list = []
+        next_token_chooser_list = []

-        # setup tokenizer for deepsparse left padding
-        tokenizer.padding_side = "left"
-        if not tokenizer.pad_token:
-            tokenizer.pad_token = tokenizer.eos_token
-        padding, truncation = "longest", False
-
         # loop through items in the batch
         for idx, r in enumerate(batch.requests):
             requests_idx_mapping[r.id] = idx

-            # setup inputs_ids, past_key_values
+            # setup input_ids, stopping criteria
             tokenized_inputs = tokenizer(
-                r.prompt,
+                r.inputs,
                 return_tensors="np",
-                padding=padding,
-                truncation=truncation,
+                padding="longest",
+                truncation=False,
                 return_token_type_ids=False,
                 max_length=DEEPSPARSE_SEQUENCE_LENGTH
             )
             input_ids_list.append(tokenized_inputs["input_ids"])

+            # deepsparse able to accept up to seq len tokens
+            num_input_tokens = tokenized_inputs["input_ids"].shape[1]
+            model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
+            stopping_criteria_list.append(
+                StoppingCriteria(
+                    eos_token_id=tokenizer.eos_token_id,
+                    max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
+                )
+            )
+
+            # get next token chooser based on input
+            next_token_chooser_list.append(
+                NextTokenChooser(
+                    repetition_penalty=r.generation_parameters.repetition_penalty,
+                    temperature=r.generation_parameters.temperature,
+                    top_k=r.generation_parameters.top_k,
+                    top_p=r.generation_parameters.top_p,
+                    do_sample=r.generation_parameters.do_sample,
+                    seed=r.generation_parameters.seed,
+                )
+            )
+
         return cls(
             batch_id=batch.id,
             requests=batch.requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=[None] * len(batch.requests),
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list,
         )

     def to_cached_batch(self) -> CachedBatch:
@@ -71,10 +93,12 @@ class DeepSparseCausalLMBatch:
     def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
         assert(len(request_ids) > 0)

         requests_idx_mapping = {}
         requests = []
         input_ids_list = []
         past_key_values_list = []
+        stopping_criteria_list = []
+        next_token_chooser_list = []

         # loop through requests, keep ones that should remain
         for new_idx, request_id in enumerate(request_ids):
@@ -86,12 +110,16 @@ class DeepSparseCausalLMBatch:
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
             past_key_values_list.append(self.past_key_values_list[old_idx])
+            stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
+            next_token_chooser_list.append(self.next_token_chooser_list[old_idx])

         # update batch state
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids_list = input_ids_list
         self.past_key_values_list = past_key_values_list
+        self.stopping_criteria_list = stopping_criteria_list
+        self.next_token_chooser_list = next_token_chooser_list

         return self
@@ -100,10 +128,12 @@ class DeepSparseCausalLMBatch:
     def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
         assert len(batches) > 1, "must have more than 1 batch to concatenate"

         requests_idx_mapping = {}
         requests = []
         input_ids_list = []
         past_key_values_list = []
+        stopping_criteria_list = []
+        next_token_chooser_list = []

         start_index = 0
         for i, batch in enumerate(batches):
@@ -113,6 +143,8 @@ class DeepSparseCausalLMBatch:
             requests.extend(batch.requests)
             input_ids_list.extend(batch.input_ids_list)
             past_key_values_list.extend(batch.past_key_values_list)
+            stopping_criteria_list.extend(batch.stopping_criteria_list)
+            next_token_chooser_list.extend(batch.next_token_chooser_list)

             # merge the request_id to index mapping
             if i == 0:
@@ -128,14 +160,16 @@ class DeepSparseCausalLMBatch:
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
-            past_key_values_list=past_key_values_list
+            past_key_values_list=past_key_values_list,
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list
         )

 class DeepSparseCausalLM:
     def __init__(
         self,
         model_path: str,
         tokenizer_path: str,
     ):
         # setup tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@@ -151,30 +185,6 @@ class DeepSparseCausalLM:
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
         )

-    # TODO: switch to NextTokenChooser
-    def sample_token(
-        self,
-        logits: np.ndarray
-    ):
-        # assert b=1 for now
-        assert(logits.shape[0] == 1)
-
-        # grab logits for the last item in the sequence
-        # shape == (batch, seq, vocabulary_size)
-        return np.argmax(logits[0,-1,:])
-
-    # TODO: switch to StoppingCriteria
-    def should_stop(
-        self,
-        num_tokens_processed: int,
-        generated_token_id: int
-    ):
-        if num_tokens_processed >= self.model.sequence_length:
-            return True
-        if generated_token_id == self.tokenizer.eos_token_id:
-            return True
-        return False
-
     def generate_token(
         self,
         batch: DeepSparseCausalLMBatch,
@@ -188,13 +198,22 @@ class DeepSparseCausalLM:
         # b) sample and check stopping criteria
         # c) create generation
         # d) update batch
-        for i, (request, input_ids, past_key_values,) in enumerate(
-            zip(
-                batch.requests,
-                batch.input_ids_list,
-                batch.past_key_values_list
-            )
-        ):
+        iterator = zip(
+            batch.requests,
+            batch.input_ids_list,
+            batch.past_key_values_list,
+            batch.stopping_criteria_list,
+            batch.next_token_chooser_list,
+        )
+        for i, (
+            request,
+            input_ids,
+            past_key_values,
+            stopping_criteria,
+            next_token_chooser
+        ) in enumerate(iterator):
+            # assert input_ids is b=1
             assert len(input_ids.shape) == 2
             assert input_ids.shape[0] == 1
@@ -203,12 +222,10 @@ class DeepSparseCausalLM:
             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
-            generated_token_id = self.sample_token(logits)
+            generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
             generated_token = self.tokenizer.decode(generated_token_id)
-            stop = self.should_stop(
-                num_tokens_processed=input_ids.shape[1] + 1,
-                generated_token_id = generated_token_id
-            )
+            stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)

             if not stop:
                 all_stopped = False
@@ -217,11 +234,12 @@ class DeepSparseCausalLM:
                 request_id=request.id,
                 token=generated_token,
                 token_id=generated_token_id,
-                stopped=stop
+                stopped=stop,
+                finish_reason=finish_reason
             ))

             # d) update batch
-            # TODO: this does not occur in place)
+            # TODO: this does not occur in place
             assert len(batch.input_ids_list[i].shape) == 2
             assert batch.input_ids_list[i].shape[0] == 1
             batch.input_ids_list[i] = np.append(
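One detail worth noting in from_batch above: because the DeepSparse engine is compiled for a fixed sequence length, each request's max_new_tokens is capped by the space left after its prompt. A quick worked example of that arithmetic, with illustrative values only:

DEEPSPARSE_SEQUENCE_LENGTH = 128

num_input_tokens = 100            # tokenized prompt length for this request (example value)
requested_max_new_tokens = 64     # r.generation_parameters.max_new_tokens
model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens             # 28
effective_max_new_tokens = min(requested_max_new_tokens, model_max_new_tokens)   # 28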

View File

@@ -56,7 +56,7 @@ class DeepSparseService:
         assert len(generations) == 1
         self.cache.set(next_ds_batch)

-        return generations[0], next_ds_batch.to_cached_batch()
+        return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)

     def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         assert len(batches) != 0, "Must provide at least one batch"

View File

@@ -1,12 +1,120 @@
+import numpy as np
 from dataclasses import dataclass
+from pydantic import BaseModel, Field
 from queue import Queue
+from enum import Enum
 from typing import List, Optional

+from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
+
+# TODO: sample for b > 1 with vectorized code
+class Greedy:
+    def __call__(self, logits: np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        return np.argmax(logits[0,:])
+
+# TODO: sample for b > 1 with vectorized code
+# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
+class Sampling:
+    def __init__(self, seed: int = 42):
+        self.generator = np.random.default_rng(seed=seed)
+
+    def __call__(self, logits: np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        probs = softmax(logits, axis=1)
+        return self.generator.choice(probs.shape[1], p=probs[0,:])
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        repetition_penalty: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 42,
+    ):
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty and repetition_penalty != 1.0 else None
+        )
+
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+        )
+        if has_warpers:
+            self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
+        else:
+            self.warpers = None
+
+        self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
+
+    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> int:
+        if self.repetition_processor is not None:
+            scores = self.repetition_processor(input_ids=input_ids, scores=scores)
+        if self.warpers is not None:
+            scores = self.warpers(scores=scores)
+
+        return self.choice(scores)
+
+class FinishReason(Enum):
+    FINISH_REASON_LENGTH = 1
+    FINISH_REASON_EOS_TOKEN = 2
+
+class StoppingCriteria:
+    def __init__(
+        self,
+        eos_token_id: int,
+        max_new_tokens: int,
+    ):
+        assert max_new_tokens > 0
+        self.max_new_tokens = max_new_tokens
+        self.eos_token_id = eos_token_id
+        self.current_tokens = 0
+
+    def __call__(self, generated_token_id: int):
+        self.current_tokens += 1
+        if self.current_tokens >= self.max_new_tokens:
+            return True, FinishReason.FINISH_REASON_LENGTH
+        if generated_token_id == self.eos_token_id:
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN
+        return False, None
+
+class GenerationParameters(BaseModel):
+    max_new_tokens: int = Field(default=20)
+    repetition_penalty: float = Field(default=1)
+    do_sample: bool = Field(default=False)
+    temperature: float = Field(default=1.0)
+    top_k: Optional[int] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    seed: int = Field(default=42)
+
+class GenerateRequestInputs(BaseModel):
+    inputs: str
+    generation_parameters: GenerationParameters
+
+class GenerateRequestOutputs(BaseModel):
+    response_text: str = Field(default="")
+    finish_reason: Optional[FinishReason] = Field(default=None)
+
 @dataclass
 class Request:
     id: int
-    prompt: str
-    max_generated_tokens: int
+    inputs: str
+    generation_parameters: GenerationParameters

 @dataclass
 class Batch:
@@ -27,9 +135,18 @@ class Generation:
     token: Optional[str]
     token_id: Optional[str]
     stopped: bool
+    finish_reason: FinishReason = None

 @dataclass
 class GenerateRequest:
-    prompt: str
-    max_generated_tokens: int
+    inputs: str
+    generation_parameters: GenerationParameters
     response_stream: Queue[Generation]
+
+    @classmethod
+    def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
+        return cls(
+            inputs=gr_inputs.inputs,
+            generation_parameters=gr_inputs.generation_parameters,
+            response_stream=Queue()
+        )
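The pydantic models added above suggest the HTTP layer now accepts structured generation parameters. Below is a hedged sketch of how GenerateRequestInputs converts into the internal GenerateRequest; the actual endpoint lives in the suppressed server diff below, so the surrounding route is an assumption.

from utils import GenerateRequest, GenerateRequestInputs, GenerateRequestOutputs, GenerationParameters

# client-facing input model, as a JSON request body would be validated
gr_inputs = GenerateRequestInputs(
    inputs="def hello():",
    generation_parameters=GenerationParameters(max_new_tokens=16),
)

# convert to the internal dataclass used by the router; this also allocates the response Queue
generate_request = GenerateRequest.from_gr_inputs(gr_inputs)

# output model a server endpoint would populate from the streamed Generations
outputs = GenerateRequestOutputs()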

File diff suppressed because it is too large.