Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Merge pull request #4 from rsnm2/stopping-next-token-chooser
Stopping next token chooser

Commit b55270320c
@@ -1,68 +0,0 @@
-import fastapi, uvicorn
-from contextlib import asynccontextmanager
-from threading import Thread
-from queue import Queue
-from router import DeepSparseRouter, batching_task
-from utils import GenerateRequest
-
-TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
-MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
-
-def serve(
-    model_path=MODEL_PATH,
-    tokenizer_path=TOKENIZER_PATH,
-    host="0.0.0.0",
-    port=5543
-):
-
-    router = None
-
-    @asynccontextmanager
-    async def lifespan(app: fastapi.FastAPI):
-        print("\n-------------------- Building Router --------------------\n")
-        router = DeepSparseRouter(
-            model_path=model_path,
-            tokenizer_path=tokenizer_path
-        )
-
-        print("\n-------------------- Starting Batching Task --------------------\n")
-        batching_thread = Thread(target=batching_task, args=[router])
-        batching_thread.start()
-
-        print("\n-------------------- Launching App --------------------\n")
-        yield
-
-        print("\n-------------------- Joining Batching Task --------------------\n")
-        router.stop_batching_task()
-        batching_task.join()
-
-    app = fastapi.FastAPI(lifespan=lifespan)
-
-    @app.get("/generate/{prompt}")
-    async def generate(prompt:str):
-        response_stream = Queue()
-        router.submit_request(
-            GenerateRequest(
-                prompt=prompt,
-                max_generated_tokens=100,
-                response_stream=response_stream
-            )
-        )
-
-        response_string = prompt
-        generation = response_stream.get()
-        while not generation.stopped:
-            response_string += generation.token
-            generation = response_stream.get()
-
-        return response_string
-
-    uvicorn.run(
-        app,
-        host=host,
-        port=port,
-        workers=1
-    )
-
-if __name__ == "__main__":
-    serve()
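For context, the file deleted above exposed generation as a plain GET route on /generate/{prompt} with a hard-coded budget of 100 tokens. A minimal client sketch against that old route (hypothetical, not part of the diff; it assumes the server was running locally on the default port 5543 from serve() above):

import requests
from urllib.parse import quote

# The old route put the prompt in the URL path, so it must be percent-encoded.
prompt = "def fibonacci(n):"
url = "http://localhost:5543/generate/" + quote(prompt, safe="")

# The handler returned the prompt concatenated with the generated tokens.
print(requests.get(url).json())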
@@ -2,7 +2,7 @@ from queue import Queue
 from typing import List, Dict, Optional, Tuple
 from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
-from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters

 class DeepSparseRouter:
     def __init__(
@@ -11,10 +11,7 @@ class DeepSparseRouter:
         model_path: Optional[str] = None,
         tokenizer_path: Optional[str] = None
     ):
-        assert (
-            service is not None or
-            (model_path is not None and tokenizer_path is not None)
-        )
+        assert service is not None or (model_path is not None and tokenizer_path is not None)

         if service is not None:
             self.service = service
@@ -38,8 +35,8 @@ class DeepSparseRouter:

         # unblock the batching task with a dummy request if blocked
         self.queue.append(GenerateRequest(
-            prompt="dummy",
-            max_generated_tokens=1,
+            inputs="stop",
+            generation_parameters=GenerationParameters(max_new_tokens=1),
             response_stream=Queue()
         ))

@@ -166,17 +163,18 @@ class DeepSparseQueue:

         # if block = True, this blocks until something ready
         # if block = False, the queue has data (if not an exception is raised)
-        # while queue.empty() == False does not guarentee data
-        # the queue is only subscribed to by one thread (this one)
-        # since batching_task is the only function that calls next_batch
+        # while queue.empty() == False typically not guarentee data on next queue.get(), this
+        # queue is only subscribed to by one thread (this one) since batching_task is the only
+        # so it does in our case

         generate_request = self.queue.get(block=block)
         generate_requests = {self.next_request_id: generate_request}

         # format into request
         request = Request(
             id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
+            inputs=generate_request.inputs,
+            generation_parameters=generate_request.generation_parameters,
         )
         self.next_request_id += 1
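As a quick illustration of the new request shape (a sketch, not taken from the diff): the prompt now travels in the inputs field and every sampling and stopping knob is bundled into a GenerationParameters object. The submit_request call is taken from the deleted server file above and assumed to still exist; the router object is assumed to be an already-constructed DeepSparseRouter.

from queue import Queue
from utils import GenerateRequest, GenerationParameters

# Hypothetical submission: generate up to 32 tokens with top-k sampling.
request = GenerateRequest(
    inputs="def hello_world():",
    generation_parameters=GenerationParameters(max_new_tokens=32, do_sample=True, top_k=50),
    response_stream=Queue(),
)
router.submit_request(request)   # assumes router was built elsewhere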
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np

 from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from utils import Request, Batch, CachedBatch, Generation
+from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser

 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -16,6 +16,8 @@ class DeepSparseCausalLMBatch:
     requests_idx_mapping: Dict[int,int]
     input_ids_list: List[np.ndarray]
     past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
+    stopping_criteria_list: List[StoppingCriteria]
+    next_token_chooser_list: List[NextTokenChooser]

     @classmethod
     def from_batch(
@@ -27,34 +29,54 @@ class DeepSparseCausalLMBatch:
         # parse batch
         requests_idx_mapping = {}
         input_ids_list = []
-
-        # setup tokenizer for deepsparse left padding
-        tokenizer.padding_side = "left"
-        if not tokenizer.pad_token:
-            tokenizer.pad_token = tokenizer.eos_token
-        padding, truncation = "longest", False
+        stopping_criteria_list = []
+        next_token_chooser_list = []

         # loop through items in the batch
         for idx, r in enumerate(batch.requests):
             requests_idx_mapping[r.id] = idx

-            # setup inputs_ids, past_key_values
+            # setup inputs_ids, stopping crtieria
             tokenized_inputs = tokenizer(
-                r.prompt,
+                r.inputs,
                 return_tensors="np",
-                padding=padding,
-                truncation=truncation,
+                padding="longest",
+                truncation=False,
                 return_token_type_ids=False,
                 max_length=DEEPSPARSE_SEQUENCE_LENGTH
             )
             input_ids_list.append(tokenized_inputs["input_ids"])

+            # deepsparse able to accept up to seq len tokens
+            num_input_tokens = tokenized_inputs["input_ids"].shape[1]
+            model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
+            stopping_criteria_list.append(
+                StoppingCriteria(
+                    eos_token_id=tokenizer.eos_token_id,
+                    max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
+                )
+            )
+
+            # get next token chooser based on input
+            next_token_chooser_list.append(
+                NextTokenChooser(
+                    repetition_penalty=r.generation_parameters.repetition_penalty,
+                    temperature=r.generation_parameters.temperature,
+                    top_k=r.generation_parameters.top_k,
+                    top_p=r.generation_parameters.top_p,
+                    do_sample=r.generation_parameters.do_sample,
+                    seed=r.generation_parameters.seed,
+                )
+            )

         return cls(
             batch_id=batch.id,
             requests=batch.requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=[None] * len(batch.requests),
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list,
         )

     def to_cached_batch(self) -> CachedBatch:
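The budget arithmetic added to from_batch above caps each request so that prompt plus generated tokens never exceed the fixed DeepSparse sequence length. A small worked example with hypothetical numbers:

DEEPSPARSE_SEQUENCE_LENGTH = 128

num_input_tokens = 100            # hypothetical: the prompt tokenized to 100 tokens
requested = 64                    # hypothetical: r.generation_parameters.max_new_tokens
model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens   # 28

effective_max_new_tokens = min(requested, model_max_new_tokens)
print(effective_max_new_tokens)   # 28 -> StoppingCriteria stops after 28 new tokens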
@@ -71,10 +93,12 @@ class DeepSparseCausalLMBatch:
     def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
         assert(len(request_ids) > 0)

         requests_idx_mapping = {}
         requests = []
         input_ids_list = []
         past_key_values_list = []
+        stopping_criteria_list = []
+        next_token_chooser_list = []

         # loop through requests, keep ones that should remain
         for new_idx, request_id in enumerate(request_ids):
@@ -86,12 +110,16 @@ class DeepSparseCausalLMBatch:
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
             past_key_values_list.append(self.past_key_values_list[old_idx])
+            stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
+            next_token_chooser_list.append(self.next_token_chooser_list[old_idx])

         # update batch state
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids_list = input_ids_list
         self.past_key_values_list = past_key_values_list
+        self.stopping_criteria_list = stopping_criteria_list
+        self.next_token_chooser_list = next_token_chooser_list

         return self

@@ -100,10 +128,12 @@ class DeepSparseCausalLMBatch:
     def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
         assert len(batches) > 1, "must have more than 1 batch to concatenate"

         requests_idx_mapping = {}
         requests = []
         input_ids_list = []
         past_key_values_list = []
+        stopping_criteria_list = []
+        next_token_chooser_list = []

         start_index = 0
         for i, batch in enumerate(batches):
@@ -113,6 +143,8 @@ class DeepSparseCausalLMBatch:
             requests.extend(batch.requests)
             input_ids_list.extend(batch.input_ids_list)
             past_key_values_list.extend(batch.past_key_values_list)
+            stopping_criteria_list.extend(batch.stopping_criteria_list)
+            next_token_chooser_list.extend(batch.next_token_chooser_list)

             # merge the request_id to index mapping
             if i == 0:
@@ -128,14 +160,16 @@ class DeepSparseCausalLMBatch:
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
-            past_key_values_list=past_key_values_list
+            past_key_values_list=past_key_values_list,
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list
         )

 class DeepSparseCausalLM:
     def __init__(
         self,
         model_path: str,
         tokenizer_path: str,
     ):
         # setup tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@@ -151,30 +185,6 @@ class DeepSparseCausalLM:
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
         )

-    # TODO: switch to NextTokenChooser
-    def sample_token(
-        self,
-        logits: np.ndarray
-    ):
-        # assert b=1 for now
-        assert(logits.shape[0] == 1)
-
-        # grab logits for the last item in the sequence
-        # shape == (batch, seq, vocabulary_size)
-        return np.argmax(logits[0,-1,:])
-
-    # TODO: switch to StoppingCriteria
-    def should_stop(
-        self,
-        num_tokens_processed: int,
-        generated_token_id: int
-    ):
-        if num_tokens_processed >= self.model.sequence_length:
-            return True
-        if generated_token_id == self.tokenizer.eos_token_id:
-            return True
-        return False
-
     def generate_token(
         self,
         batch: DeepSparseCausalLMBatch,
@@ -188,13 +198,22 @@ class DeepSparseCausalLM:
         # b) sample and check stopping criteria
         # c) create generation
         # d) update batch
-        for i, (request, input_ids, past_key_values,) in enumerate(
-            zip(
-                batch.requests,
-                batch.input_ids_list,
-                batch.past_key_values_list
-            )
-        ):
+        iterator = zip(
+            batch.requests,
+            batch.input_ids_list,
+            batch.past_key_values_list,
+            batch.stopping_criteria_list,
+            batch.next_token_chooser_list,
+        )
+
+        for i, (
+            request,
+            input_ids,
+            past_key_values,
+            stopping_criteria,
+            next_token_chooser
+        ) in enumerate(iterator):
+            # assert input_ids is b=1
             assert len(input_ids.shape) == 2
             assert input_ids.shape[0] == 1

@@ -203,12 +222,10 @@ class DeepSparseCausalLM:

             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
-            generated_token_id = self.sample_token(logits)
+            generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
             generated_token = self.tokenizer.decode(generated_token_id)
-            stop = self.should_stop(
-                num_tokens_processed=input_ids.shape[1] + 1,
-                generated_token_id = generated_token_id
-            )
+            stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
+
             if not stop:
                 all_stopped = False

@@ -217,11 +234,12 @@ class DeepSparseCausalLM:
                 request_id=request.id,
                 token=generated_token,
                 token_id=generated_token_id,
-                stopped=stop
+                stopped=stop,
+                finish_reason=finish_reason
             ))

             # d) update batch
-            # TODO: this does not occur in place)
+            # TODO: this does not occur in place
             assert len(batch.input_ids_list[i].shape) == 2
             assert batch.input_ids_list[i].shape[0] == 1
             batch.input_ids_list[i] = np.append(
@@ -56,7 +56,7 @@ class DeepSparseService:
         assert len(generations) == 1
         self.cache.set(next_ds_batch)

-        return generations[0], next_ds_batch.to_cached_batch()
+        return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)

     def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         assert len(batches) != 0, "Must provide at least one batch"
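This guard matters when every request in the prefilled batch stops on its very first token (for example the router's dummy unblock request, which asks for max_new_tokens=1): generate_token then presumably returns no surviving batch, so the cached-batch half of the tuple must tolerate None. A hypothetical caller-side sketch; the Prefill name and call shape are assumed from the surrounding service, not shown in this hunk:

# Hypothetical caller handling the now-optional cached batch after prefill.
generation, cached_batch = service.Prefill(batch)    # assumed call shape
if cached_batch is None:
    pass                                             # every request already finished; nothing to decode
else:
    cached_batches.append(cached_batch)              # hypothetical bookkeeping on the caller side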
@@ -1,12 +1,120 @@
+import numpy as np
 from dataclasses import dataclass
+from pydantic import BaseModel, Field
 from queue import Queue
+from enum import Enum
 from typing import List, Optional
+from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
+
+# TODO: sample for b > 1 with vectorized code
+class Greedy:
+    def __call__(self, logits: np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        return np.argmax(logits[0,:])
+
+# TODO: sample for b > 1 with vectorized code
+# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
+class Sampling:
+    def __init__(self, seed:int=42):
+        self.generator = np.random.default_rng(seed=seed)
+
+    def __call__(self, logits:np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        probs = softmax(logits, axis=1)
+        return self.generator.choice(probs.shape[1], p=probs[0,:])
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        repetition_penalty: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        do_sample:bool = False,
+        seed: int = 42,
+    ):
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty and repetition_penalty != 1.0 else None
+        )
+
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+        )
+
+        if has_warpers:
+            self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
+        else:
+            self.warpers = None
+
+        self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
+
+    def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int:
+        if self.repetition_processor is not None:
+            scores = self.repetition_processor(input_ids=input_ids, scores=scores)
+
+        if self.warpers is not None:
+            scores = self.warpers(scores=scores)
+
+        return self.choice(scores)
+
+class FinishReason(Enum):
+    FINISH_REASON_LENGTH = 1
+    FINISH_REASON_EOS_TOKEN = 2
+
+class StoppingCriteria:
+    def __init__(
+        self,
+        eos_token_id: int,
+        max_new_tokens: int,
+    ):
+        assert max_new_tokens > 0
+        self.max_new_tokens = max_new_tokens
+        self.eos_token_id = eos_token_id
+        self.current_tokens = 0
+
+    def __call__(self, generated_token_id:int):
+        self.current_tokens += 1
+        if self.current_tokens >= self.max_new_tokens:
+            return True, FinishReason.FINISH_REASON_LENGTH
+
+        if generated_token_id == self.eos_token_id:
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN
+
+        return False, None
+
+class GenerationParameters(BaseModel):
+    max_new_tokens: int = Field(default=20)
+    repetition_penalty: float = Field(default=1)
+    do_sample: bool = Field(default=False)
+    temperature: float = Field(default=1.0)
+    top_k: Optional[int] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    seed: int = Field(default=42)
+
+class GenerateRequestInputs(BaseModel):
+    inputs: str
+    generation_parameters: GenerationParameters
+
+class GenerateRequestOutputs(BaseModel):
+    response_text: str = Field(default="")
+    finish_reason: Optional[FinishReason] = Field(default=None)

 @dataclass
 class Request:
     id: int
-    prompt: str
-    max_generated_tokens: int
+    inputs: str
+    generation_parameters: GenerationParameters

 @dataclass
 class Batch:
@@ -27,9 +135,18 @@ class Generation:
     token: Optional[str]
     token_id: Optional[str]
     stopped: bool
+    finish_reason: FinishReason = None

 @dataclass
 class GenerateRequest:
-    prompt: str
-    max_generated_tokens: int
+    inputs: str
+    generation_parameters: GenerationParameters
     response_stream: Queue[Generation]
+
+    @classmethod
+    def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
+        return cls(
+            inputs=gr_inputs.inputs,
+            generation_parameters=gr_inputs.generation_parameters,
+            response_stream=Queue()
+        )
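Taken together, NextTokenChooser and StoppingCriteria replace the sample_token and should_stop helpers removed from DeepSparseCausalLM. A standalone usage sketch, not part of the diff: it assumes it runs inside the server directory so that utils (and its logits_process dependency) import, and it fakes the model with random logits over a toy vocabulary. With all defaults the chooser reduces to greedy argmax, so only code shown in this PR is exercised.

import numpy as np
from utils import NextTokenChooser, StoppingCriteria

vocab_size, eos_token_id = 8, 7
chooser = NextTokenChooser()                               # all defaults -> greedy argmax
stopping = StoppingCriteria(eos_token_id=eos_token_id, max_new_tokens=16)

input_ids = np.array([[3, 5, 1]])                          # shape (batch=1, seq)
stop, finish_reason = False, None
while not stop:
    logits = np.random.randn(1, vocab_size)                # stand-in for last-position model logits
    token_id = chooser(input_ids=input_ids, scores=logits)
    input_ids = np.append(input_ids, [[token_id]], axis=1)
    stop, finish_reason = stopping(generated_token_id=token_id)

print(finish_reason)   # FINISH_REASON_LENGTH or FINISH_REASON_EOS_TOKEN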
server-dev.ipynb (1008): file diff suppressed because it is too large.