Merge pull request #4 from rsnm2/stopping-next-token-chooser

Stopping next token chooser
Robert Shaw 2023-08-27 20:12:10 -06:00 committed by GitHub
commit b55270320c
6 changed files with 1115 additions and 256 deletions


@@ -1,68 +0,0 @@
import fastapi, uvicorn
from contextlib import asynccontextmanager
from threading import Thread
from queue import Queue
from router import DeepSparseRouter, batching_task
from utils import GenerateRequest
TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
def serve(
model_path=MODEL_PATH,
tokenizer_path=TOKENIZER_PATH,
host="0.0.0.0",
port=5543
):
router = None
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
print("\n-------------------- Building Router --------------------\n")
router = DeepSparseRouter(
model_path=model_path,
tokenizer_path=tokenizer_path
)
print("\n-------------------- Starting Batching Task --------------------\n")
batching_thread = Thread(target=batching_task, args=[router])
batching_thread.start()
print("\n-------------------- Launching App --------------------\n")
yield
print("\n-------------------- Joining Batching Task --------------------\n")
router.stop_batching_task()
batching_task.join()
app = fastapi.FastAPI(lifespan=lifespan)
@app.get("/generate/{prompt}")
async def generate(prompt:str):
response_stream = Queue()
router.submit_request(
GenerateRequest(
prompt=prompt,
max_generated_tokens=100,
response_stream=response_stream
)
)
response_string = prompt
generation = response_stream.get()
while not generation.stopped:
response_string += generation.token
generation = response_stream.get()
return response_string
uvicorn.run(
app,
host=host,
port=port,
workers=1
)
if __name__ == "__main__":
serve()
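
The deleted script above submitted requests as a prompt plus max_generated_tokens. After this PR a request carries inputs and a GenerationParameters object (see the utils diff below), so the equivalent client code changes shape. A minimal sketch, assuming DeepSparseRouter still exposes submit_request as it did here and using placeholder model paths:

from queue import Queue
from threading import Thread
from router import DeepSparseRouter, batching_task
from utils import GenerateRequest, GenerationParameters

# placeholder paths; substitute a real deployment
MODEL_PATH = "/path/to/model.onnx"
TOKENIZER_PATH = "/path/to/deployment"

router = DeepSparseRouter(model_path=MODEL_PATH, tokenizer_path=TOKENIZER_PATH)
Thread(target=batching_task, args=[router]).start()

# submit one request with the new inputs + GenerationParameters shape
response_stream = Queue()
router.submit_request(
    GenerateRequest(
        inputs="def fib(n):",
        generation_parameters=GenerationParameters(max_new_tokens=100),
        response_stream=response_stream,
    )
)

# drain the stream until the request reports that it stopped
text = "def fib(n):"
generation = response_stream.get()
while not generation.stopped:
    text += generation.token
    generation = response_stream.get()
print(text)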


@@ -2,7 +2,7 @@ from queue import Queue
from typing import List, Dict, Optional, Tuple
from service.service import DeepSparseService
from service.causal_lm import DeepSparseCausalLM
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters
class DeepSparseRouter:
def __init__(
@@ -11,10 +11,7 @@ class DeepSparseRouter:
model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None
):
assert (
service is not None or
(model_path is not None and tokenizer_path is not None)
)
assert service is not None or (model_path is not None and tokenizer_path is not None)
if service is not None:
self.service = service
@@ -38,8 +35,8 @@
# unblock the batching task with a dummy request if blocked
self.queue.append(GenerateRequest(
prompt="dummy",
max_generated_tokens=1,
inputs="stop",
generation_parameters=GenerationParameters(max_new_tokens=1),
response_stream=Queue()
))
@@ -166,17 +163,18 @@ class DeepSparseQueue:
# if block = True, this blocks until something is ready
# if block = False, we assume the queue has data (if not, an exception is raised)
# while queue.empty() == False does not guarantee data
# the queue is only subscribed to by one thread (this one)
# since batching_task is the only function that calls next_batch
# while queue.empty() == False typically does not guarantee data on the next queue.get(), this
# queue is only subscribed to by one thread (this one) since batching_task is the only
# caller of next_batch, so it does in our case
generate_request = self.queue.get(block=block)
generate_requests = {self.next_request_id: generate_request}
# format into request
request = Request(
id=self.next_request_id,
prompt=generate_request.prompt,
max_generated_tokens=generate_request.max_generated_tokens
inputs=generate_request.inputs,
generation_parameters=generate_request.generation_parameters,
)
self.next_request_id += 1


@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
import numpy as np
from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
from utils import Request, Batch, CachedBatch, Generation
from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser
DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -16,6 +16,8 @@ class DeepSparseCausalLMBatch:
requests_idx_mapping: Dict[int,int]
input_ids_list: List[np.ndarray]
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
stopping_criteria_list: List[StoppingCriteria]
next_token_chooser_list: List[NextTokenChooser]
@classmethod
def from_batch(
@@ -27,34 +29,54 @@
# parse batch
requests_idx_mapping = {}
input_ids_list = []
# setup tokenizer for deepsparse left padding
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
padding, truncation = "longest", False
stopping_criteria_list = []
next_token_chooser_list = []
# loop through items in the batch
for idx, r in enumerate(batch.requests):
requests_idx_mapping[r.id] = idx
# setup input_ids, past_key_values
# setup input_ids, stopping criteria
tokenized_inputs = tokenizer(
r.prompt,
r.inputs,
return_tensors="np",
padding=padding,
truncation=truncation,
padding="longest",
truncation=False,
return_token_type_ids=False,
max_length=DEEPSPARSE_SEQUENCE_LENGTH
)
input_ids_list.append(tokenized_inputs["input_ids"])
# the deepsparse engine can accept up to sequence_length tokens in total (prompt + generated)
num_input_tokens = tokenized_inputs["input_ids"].shape[1]
model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
stopping_criteria_list.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
)
)
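# worked example (hypothetical numbers): with DEEPSPARSE_SEQUENCE_LENGTH = 128 and a prompt
# that tokenizes to 100 tokens, model_max_new_tokens = 128 - 100 = 28, so a request asking
# for max_new_tokens=100 gets StoppingCriteria(max_new_tokens=28)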
# get next token chooser based on input
next_token_chooser_list.append(
NextTokenChooser(
repetition_penalty=r.generation_parameters.repetition_penalty,
temperature=r.generation_parameters.temperature,
top_k=r.generation_parameters.top_k,
top_p=r.generation_parameters.top_p,
do_sample=r.generation_parameters.do_sample,
seed=r.generation_parameters.seed,
)
)
return cls(
batch_id=batch.id,
requests=batch.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=[None] * len(batch.requests),
stopping_criteria_list=stopping_criteria_list,
next_token_chooser_list=next_token_chooser_list,
)
def to_cached_batch(self) -> CachedBatch:
@@ -75,6 +97,8 @@ class DeepSparseCausalLMBatch:
requests = []
input_ids_list = []
past_key_values_list = []
stopping_criteria_list = []
next_token_chooser_list = []
# loop through requests, keep ones that should remain
for new_idx, request_id in enumerate(request_ids):
@@ -86,12 +110,16 @@
requests.append(self.requests[old_idx])
input_ids_list.append(self.input_ids_list[old_idx])
past_key_values_list.append(self.past_key_values_list[old_idx])
stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
next_token_chooser_list.append(self.next_token_chooser_list[old_idx])
# update batch state
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids_list = input_ids_list
self.past_key_values_list = past_key_values_list
self.stopping_criteria_list = stopping_criteria_list
self.next_token_chooser_list = next_token_chooser_list
return self
@@ -104,6 +132,8 @@ class DeepSparseCausalLMBatch:
requests = []
input_ids_list = []
past_key_values_list = []
stopping_criteria_list = []
next_token_chooser_list = []
start_index = 0
for i, batch in enumerate(batches):
@@ -113,6 +143,8 @@
requests.extend(batch.requests)
input_ids_list.extend(batch.input_ids_list)
past_key_values_list.extend(batch.past_key_values_list)
stopping_criteria_list.extend(batch.stopping_criteria_list)
next_token_chooser_list.extend(batch.next_token_chooser_list)
# merge the request_id to index mapping
if i == 0:
@@ -128,7 +160,9 @@
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=past_key_values_list
past_key_values_list=past_key_values_list,
stopping_criteria_list=stopping_criteria_list,
next_token_chooser_list=next_token_chooser_list
)
class DeepSparseCausalLM:
@@ -151,30 +185,6 @@ class DeepSparseCausalLM:
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
)
# TODO: switch to NextTokenChooser
def sample_token(
self,
logits: np.ndarray
):
# assert b=1 for now
assert(logits.shape[0] == 1)
# grab logits for the last item in the sequence
# shape == (batch, seq, vocabulary_size)
return np.argmax(logits[0,-1,:])
# TODO: switch to StoppingCriteria
def should_stop(
self,
num_tokens_processed: int,
generated_token_id: int
):
if num_tokens_processed >= self.model.sequence_length:
return True
if generated_token_id == self.tokenizer.eos_token_id:
return True
return False
def generate_token(
self,
batch: DeepSparseCausalLMBatch,
@@ -188,13 +198,22 @@
# b) sample and check stopping criteria
# c) create generation
# d) update batch
for i, (request, input_ids, past_key_values,) in enumerate(
zip(
iterator = zip(
batch.requests,
batch.input_ids_list,
batch.past_key_values_list
batch.past_key_values_list,
batch.stopping_criteria_list,
batch.next_token_chooser_list,
)
):
for i, (
request,
input_ids,
past_key_values,
stopping_criteria,
next_token_chooser
) in enumerate(iterator):
# assert input_ids is b=1
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
@@ -203,12 +222,10 @@
# b) sample token and check stopping criteria
# TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
generated_token_id = self.sample_token(logits)
generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
generated_token = self.tokenizer.decode(generated_token_id)
stop = self.should_stop(
num_tokens_processed=input_ids.shape[1] + 1,
generated_token_id = generated_token_id
)
stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
if not stop:
all_stopped = False
@@ -217,11 +234,12 @@
request_id=request.id,
token=generated_token,
token_id=generated_token_id,
stopped=stop
stopped=stop,
finish_reason=finish_reason
))
# d) update batch
# TODO: this does not occur in place)
# TODO: this does not occur in place
assert len(batch.input_ids_list[i].shape) == 2
assert batch.input_ids_list[i].shape[0] == 1
batch.input_ids_list[i] = np.append(
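
As a standalone illustration of the two helpers the loop above now delegates to, here is a minimal sketch with toy shapes and hypothetical values (a real batch uses the engine's logits and the tokenizer's ids):

import numpy as np
from utils import NextTokenChooser, StoppingCriteria

# toy logits: batch=1, seq_len=3, vocab=5 (hypothetical values, not real model output)
logits = np.random.rand(1, 3, 5).astype(np.float32)
input_ids = np.array([[2, 4, 1]])

chooser = NextTokenChooser(temperature=0.7, top_k=3, do_sample=True, seed=0)
criteria = StoppingCriteria(eos_token_id=0, max_new_tokens=2)

# same call pattern as generate_token: choose from the last position's scores,
# then ask the stopping criteria whether this request is finished
token_id = chooser(input_ids=input_ids, scores=logits[:, -1, :])
stop, finish_reason = criteria(generated_token_id=token_id)
# stop becomes True with FinishReason.FINISH_REASON_LENGTH after max_new_tokens calls,
# or earlier with FINISH_REASON_EOS_TOKEN if token_id equals eos_token_id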


@@ -56,7 +56,7 @@ class DeepSparseService:
assert len(generations) == 1
self.cache.set(next_ds_batch)
return generations[0], next_ds_batch.to_cached_batch()
return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)
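# note: next_ds_batch can presumably be None here, e.g. when the request already
# finished during Prefill under the new StoppingCriteria (such as max_new_tokens=1)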
def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
assert len(batches) != 0, "Must provide at least one batch"


@@ -1,12 +1,120 @@
import numpy as np
from dataclasses import dataclass
from pydantic import BaseModel, Field
from queue import Queue
from enum import Enum
from typing import List, Optional
from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
# TODO: sample for b > 1 with vectorized code
class Greedy:
def __call__(self, logits: np.ndarray):
# assert b=1 for now
# shape == (batch, vocabulary_size)
assert(logits.shape[0] == 1)
assert(len(logits.shape) == 2)
return np.argmax(logits[0,:])
# TODO: sample for b > 1 with vectorized code
# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
class Sampling:
def __init__(self, seed:int=42):
self.generator = np.random.default_rng(seed=seed)
def __call__(self, logits:np.ndarray):
# assert b=1 for now
# shape == (batch, vocabulary_size)
assert(logits.shape[0] == 1)
assert(len(logits.shape) == 2)
probs = softmax(logits, axis=1)
return self.generator.choice(probs.shape[1], p=probs[0,:])
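# quick contrast of the two samplers on toy logits (hypothetical values):
#   toy = np.array([[0.1, 2.0, 0.3, 0.4]])   # batch=1, vocab=4
#   Greedy()(toy)          # -> 1, always the argmax
#   Sampling(seed=0)(toy)  # -> a reproducible draw from softmax(toy)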
class NextTokenChooser:
def __init__(
self,
repetition_penalty: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
do_sample:bool = False,
seed: int = 42,
):
self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty and repetition_penalty != 1.0 else None
)
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
)
if has_warpers:
self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
else:
self.warpers = None
self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int:
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids=input_ids, scores=scores)
if self.warpers is not None:
scores = self.warpers(scores=scores)
return self.choice(scores)
class FinishReason(Enum):
FINISH_REASON_LENGTH = 1
FINISH_REASON_EOS_TOKEN = 2
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
max_new_tokens: int,
):
assert max_new_tokens > 0
self.max_new_tokens = max_new_tokens
self.eos_token_id = eos_token_id
self.current_tokens = 0
def __call__(self, generated_token_id:int):
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if generated_token_id == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
return False, None
class GenerationParameters(BaseModel):
max_new_tokens: int = Field(default=20)
repetition_penalty: float = Field(default=1)
do_sample: bool = Field(default=False)
temperature: float = Field(default=1.0)
top_k: Optional[int] = Field(default=None)
top_p: Optional[float] = Field(default=None)
seed: int = Field(default=42)
class GenerateRequestInputs(BaseModel):
inputs: str
generation_parameters: GenerationParameters
class GenerateRequestOutputs(BaseModel):
response_text: str = Field(default="")
finish_reason: Optional[FinishReason] = Field(default=None)
@dataclass
class Request:
id: int
prompt: str
max_generated_tokens: int
inputs: str
generation_parameters: GenerationParameters
@dataclass
class Batch:
@@ -27,9 +135,18 @@ class Generation:
token: Optional[str]
token_id: Optional[str]
stopped: bool
finish_reason: FinishReason = None
@dataclass
class GenerateRequest:
prompt: str
max_generated_tokens: int
inputs: str
generation_parameters: GenerationParameters
response_stream: Queue[Generation]
@classmethod
def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
return cls(
inputs=gr_inputs.inputs,
generation_parameters=gr_inputs.generation_parameters,
response_stream=Queue()
)
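
A small sketch of how a JSON-style payload could flow through the new pydantic models into a queue-backed GenerateRequest; the payload contents and the caller (presumably the server in the suppressed diff below) are assumptions:

from utils import GenerateRequest, GenerateRequestInputs

# hypothetical request body for a generate endpoint
payload = {
    "inputs": "def fib(n):",
    "generation_parameters": {"max_new_tokens": 64, "do_sample": True, "temperature": 0.8},
}

gr_inputs = GenerateRequestInputs(**payload)                  # pydantic validates and fills defaults
generate_request = GenerateRequest.from_gr_inputs(gr_inputs)  # attaches a fresh response_stream Queue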

File diff suppressed because it is too large