mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
readded changed file names
This commit is contained in:
parent
06fc85f93c
commit
cd3349f53b
161
server/deepsparse/router.py
Normal file
161
server/deepsparse/router.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
from queue import Queue
|
||||||
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
from server.deepsparse.service.service import DeepSparseService
|
||||||
|
from server.deepsparse.utils import CachedBatch, Batch, Generation, GenerateRequest, Request
|
||||||
|
|
||||||
|
# TODO: implement logic for maximum size of the queue based on memory usage
|
||||||
|
class DeepSparseQueue:
|
||||||
|
def __init__(self):
|
||||||
|
self.next_request_id: int = 0
|
||||||
|
self.next_batch_id: int = 0
|
||||||
|
self.queue: Queue[GenerateRequest] = Queue()
|
||||||
|
|
||||||
|
def append(self, generate_request: GenerateRequest):
|
||||||
|
self.queue.put(generate_request)
|
||||||
|
|
||||||
|
# TODO: enable multiple prefill requests in a batch
|
||||||
|
def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
|
||||||
|
|
||||||
|
# if not blocking, return none if empty
|
||||||
|
if not block and self.queue.empty():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# if block = True, this blocks until something ready
|
||||||
|
# if block = False, the queue has data (if not an exception is raised)
|
||||||
|
# while queue.empty() == False does not guarentee data
|
||||||
|
# the queue is only subscribed to by one thread (this one)
|
||||||
|
# since batching_task is the only function that calls next_batch
|
||||||
|
generate_request = self.queue.get(block=block)
|
||||||
|
generate_requests = {self.next_request_id: generate_request}
|
||||||
|
|
||||||
|
# format into request
|
||||||
|
request = Request(
|
||||||
|
id=self.next_request_id,
|
||||||
|
prompt=generate_request.prompt,
|
||||||
|
max_generated_tokens=generate_request.max_generated_tokens
|
||||||
|
)
|
||||||
|
self.next_request_id += 1
|
||||||
|
|
||||||
|
# format into batch
|
||||||
|
batch = Batch(
|
||||||
|
id = self.next_batch_id,
|
||||||
|
requests=[request]
|
||||||
|
)
|
||||||
|
self.next_batch_id += 1
|
||||||
|
|
||||||
|
# return batch, generate_requests
|
||||||
|
return (batch, generate_requests)
|
||||||
|
|
||||||
|
class DeepSparseRouter:
|
||||||
|
def __init__(self, service: DeepSparseService):
|
||||||
|
self.service: DeepSparseService = service
|
||||||
|
self.queue: DeepSparseQueue = DeepSparseQueue()
|
||||||
|
|
||||||
|
def generate(self, generate_request: GenerateRequest):
|
||||||
|
self.queue.append(generate_request)
|
||||||
|
|
||||||
|
def prefill(
|
||||||
|
self,
|
||||||
|
batch: Batch,
|
||||||
|
generate_requests: Dict[int, GenerateRequest]
|
||||||
|
) -> Optional[CachedBatch]:
|
||||||
|
|
||||||
|
generation, next_batch = self.service.Prefill(batch=batch)
|
||||||
|
active_generate_request_ids = self.filter_send_generations([generation], generate_requests)
|
||||||
|
return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
batches: List[CachedBatch],
|
||||||
|
generate_requests: Dict[int,GenerateRequest]
|
||||||
|
) -> Optional[CachedBatch]:
|
||||||
|
|
||||||
|
generations, next_batch = self.service.Decode(batches=batches)
|
||||||
|
active_generate_request_ids = self.filter_send_generations(generations, generate_requests)
|
||||||
|
return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)
|
||||||
|
|
||||||
|
def filter_send_generations(
|
||||||
|
self,
|
||||||
|
generations: List[Generation],
|
||||||
|
generate_requests: Dict[int, GenerateRequest]
|
||||||
|
) -> List[int]:
|
||||||
|
|
||||||
|
active_request_ids = []
|
||||||
|
for generation in generations:
|
||||||
|
# send generation to the response stream
|
||||||
|
generate_requests[generation.request_id].response_stream.put(generation)
|
||||||
|
|
||||||
|
# remove request from active requests if stopped
|
||||||
|
if generation.stopped:
|
||||||
|
generate_requests.pop(generation.request_id)
|
||||||
|
else:
|
||||||
|
active_request_ids.append(generation.request_id)
|
||||||
|
|
||||||
|
return active_request_ids
|
||||||
|
|
||||||
|
def filter_batch(
|
||||||
|
self,
|
||||||
|
batch: Optional[CachedBatch],
|
||||||
|
active_generate_request_ids: List[int]
|
||||||
|
) -> Optional[CachedBatch]:
|
||||||
|
|
||||||
|
# if batch done OR nothing to filter
|
||||||
|
if batch is None or len(batch) == len(active_generate_request_ids):
|
||||||
|
return batch
|
||||||
|
|
||||||
|
# active request_ids
|
||||||
|
batch.request_ids = active_generate_request_ids
|
||||||
|
|
||||||
|
# if all requests complete, clear cache
|
||||||
|
if len(batch) == 0:
|
||||||
|
self.service.ClearCache()
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.service.FilterBatch(batch_id=batch.batch_id, request_ids=batch.request_ids)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: update to do more sophisticated logic as to when to do a prefill
|
||||||
|
def batching_task(router: DeepSparseRouter):
|
||||||
|
while True:
|
||||||
|
# loop until no requests to process (note: this blocks if queue is empty)
|
||||||
|
next_batch = router.queue.next_batch(block=True)
|
||||||
|
while next_batch is not None:
|
||||||
|
batch, generate_requests = next_batch
|
||||||
|
|
||||||
|
# HACK for development --- breaks out of the cycle
|
||||||
|
if batch.requests[0].prompt == "stop":
|
||||||
|
return
|
||||||
|
|
||||||
|
# run prefill
|
||||||
|
cached_batch = router.prefill(
|
||||||
|
batch=batch,
|
||||||
|
generate_requests=generate_requests
|
||||||
|
)
|
||||||
|
|
||||||
|
# loop until we do not reiceve any cached batch from the service
|
||||||
|
# == until all active requests have met their stopping criteria
|
||||||
|
while cached_batch is not None:
|
||||||
|
batches = [cached_batch]
|
||||||
|
|
||||||
|
# try to get a new batch and run prefill on this batch
|
||||||
|
next_batch = router.queue.next_batch(block=False)
|
||||||
|
|
||||||
|
if next_batch is not None:
|
||||||
|
new_batch, new_generate_requests = next_batch
|
||||||
|
new_cached_batch = router.prefill(
|
||||||
|
batch=new_batch,
|
||||||
|
generate_requests=new_generate_requests
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_cached_batch is not None:
|
||||||
|
batches.append(new_cached_batch)
|
||||||
|
assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
|
||||||
|
generate_requests.update(new_generate_requests)
|
||||||
|
|
||||||
|
# run decode
|
||||||
|
cached_batch = router.decode(
|
||||||
|
batches=batches,
|
||||||
|
generate_requests=generate_requests
|
||||||
|
)
|
||||||
|
|
||||||
|
next_batch = router.queue.next_batch(block=False)
|
0
server/deepsparse/server.py
Normal file
0
server/deepsparse/server.py
Normal file
239
server/deepsparse/service/causal_lm.py
Normal file
239
server/deepsparse/service/causal_lm.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from server.deepsparse.service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
|
||||||
|
from server.deepsparse.utils import Request, Batch, CachedBatch, Generation
|
||||||
|
|
||||||
|
DEEPSPARSE_SEQUENCE_LENGTH = 128
|
||||||
|
DEEPSPARSE_MULTITOKEN_LENGTH = 4
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeepSparseCausalLMBatch:
|
||||||
|
batch_id: int
|
||||||
|
requests: List[Request]
|
||||||
|
requests_idx_mapping: Dict[int,int]
|
||||||
|
input_ids_list: List[np.ndarray]
|
||||||
|
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_batch(
|
||||||
|
cls,
|
||||||
|
batch: Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> "DeepSparseCausalLMBatch":
|
||||||
|
|
||||||
|
# parse batch
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
input_ids_list = []
|
||||||
|
|
||||||
|
# setup tokenizer for deepsparse left padding
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
if not tokenizer.pad_token:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
padding, truncation = "longest", False
|
||||||
|
|
||||||
|
# loop through items in the batch
|
||||||
|
for idx, r in enumerate(batch.requests):
|
||||||
|
requests_idx_mapping[r.id] = idx
|
||||||
|
|
||||||
|
# setup inputs_ids, past_key_values
|
||||||
|
tokenized_inputs = tokenizer(
|
||||||
|
r.prompt,
|
||||||
|
return_tensors="np",
|
||||||
|
padding=padding,
|
||||||
|
truncation=truncation,
|
||||||
|
return_token_type_ids=False,
|
||||||
|
max_length=DEEPSPARSE_SEQUENCE_LENGTH
|
||||||
|
)
|
||||||
|
input_ids_list.append(tokenized_inputs["input_ids"])
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
batch_id=batch.id,
|
||||||
|
requests=batch.requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids_list=input_ids_list,
|
||||||
|
past_key_values_list=[None] * len(batch.requests),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_cached_batch(self) -> CachedBatch:
|
||||||
|
return CachedBatch(
|
||||||
|
batch_id = self.batch_id,
|
||||||
|
request_ids=[r.id for r in self.requests],
|
||||||
|
)
|
||||||
|
|
||||||
|
# length of the batch
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.requests)
|
||||||
|
|
||||||
|
# pass list of request ids, returns batch with only those request ids
|
||||||
|
def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
|
||||||
|
assert(len(request_ids) > 0)
|
||||||
|
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
requests = []
|
||||||
|
input_ids_list = []
|
||||||
|
past_key_values_list = []
|
||||||
|
|
||||||
|
# loop through requests, keep ones that should remain
|
||||||
|
for new_idx, request_id in enumerate(request_ids):
|
||||||
|
assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch"
|
||||||
|
|
||||||
|
requests_idx_mapping[request_id] = new_idx
|
||||||
|
|
||||||
|
old_idx = self.requests_idx_mapping[request_id]
|
||||||
|
requests.append(self.requests[old_idx])
|
||||||
|
input_ids_list.append(self.input_ids_list[old_idx])
|
||||||
|
past_key_values_list.append(self.past_key_values_list[old_idx])
|
||||||
|
|
||||||
|
# update batch state
|
||||||
|
self.requests = requests
|
||||||
|
self.requests_idx_mapping = requests_idx_mapping
|
||||||
|
self.input_ids_list = input_ids_list
|
||||||
|
self.past_key_values_list = past_key_values_list
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
# combine two batches into one
|
||||||
|
@classmethod
|
||||||
|
def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
|
||||||
|
assert len(batches) > 1, "must have more than 1 batch to concatenate"
|
||||||
|
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
requests = []
|
||||||
|
input_ids_list = []
|
||||||
|
past_key_values_list = []
|
||||||
|
|
||||||
|
start_index = 0
|
||||||
|
for i, batch in enumerate(batches):
|
||||||
|
assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
|
||||||
|
|
||||||
|
# concatenate request, input_ids, and past_key_values lists
|
||||||
|
requests.extend(batch.requests)
|
||||||
|
input_ids_list.extend(batch.input_ids_list)
|
||||||
|
past_key_values_list.extend(batch.past_key_values_list)
|
||||||
|
|
||||||
|
# merge the request_id to index mapping
|
||||||
|
if i == 0:
|
||||||
|
requests_idx_mapping = batch.requests_idx_mapping
|
||||||
|
else:
|
||||||
|
for k, v in batch.requests_idx_mapping.items():
|
||||||
|
requests_idx_mapping[k] = v + start_index
|
||||||
|
|
||||||
|
start_index += len(batch)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
batch_id=batches[0].batch_id,
|
||||||
|
requests=requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids_list=input_ids_list,
|
||||||
|
past_key_values_list=past_key_values_list
|
||||||
|
)
|
||||||
|
|
||||||
|
class DeepSparseCausalLM:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
tokenizer_path: str,
|
||||||
|
):
|
||||||
|
# setup tokenizer
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
if not self.tokenizer.pad_token:
|
||||||
|
assert self.tokenizer.eos_token
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
# setup model
|
||||||
|
self.model = DeepSparseDecoderModel(
|
||||||
|
onnx_file_path = model_path,
|
||||||
|
sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
|
||||||
|
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: switch to NextTokenChooser
|
||||||
|
def sample_token(
|
||||||
|
self,
|
||||||
|
logits: np.ndarray
|
||||||
|
):
|
||||||
|
# assert b=1 for now
|
||||||
|
assert(logits.shape[0] == 1)
|
||||||
|
|
||||||
|
# grab logits for the last item in the sequence
|
||||||
|
# shape == (batch, seq, vocabulary_size)
|
||||||
|
return np.argmax(logits[0,-1,:])
|
||||||
|
|
||||||
|
# TODO: switch to StoppingCriteria
|
||||||
|
def should_stop(
|
||||||
|
self,
|
||||||
|
num_tokens_processed: int,
|
||||||
|
generated_token_id: int
|
||||||
|
):
|
||||||
|
if num_tokens_processed >= self.model.sequence_length:
|
||||||
|
return True
|
||||||
|
if generated_token_id == self.tokenizer.eos_token_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def generate_token(
|
||||||
|
self,
|
||||||
|
batch: DeepSparseCausalLMBatch,
|
||||||
|
) -> (List[Generation], Optional[DeepSparseCausalLMBatch]):
|
||||||
|
|
||||||
|
generations: List[Generation] = []
|
||||||
|
all_stopped = True
|
||||||
|
|
||||||
|
# for each member of the batch:
|
||||||
|
# a) run inference
|
||||||
|
# b) sample and check stopping criteria
|
||||||
|
# c) create generation
|
||||||
|
# d) update batch
|
||||||
|
for i, (request, input_ids, past_key_values,) in enumerate(
|
||||||
|
zip(
|
||||||
|
batch.requests,
|
||||||
|
batch.input_ids_list,
|
||||||
|
batch.past_key_values_list
|
||||||
|
)
|
||||||
|
):
|
||||||
|
assert len(input_ids.shape) == 2
|
||||||
|
assert input_ids.shape[0] == 1
|
||||||
|
|
||||||
|
# a) run inference
|
||||||
|
logits, past_key_values = self.model(input_ids, past_key_values)
|
||||||
|
|
||||||
|
# b) sample token and check stopping criteria
|
||||||
|
# TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
|
||||||
|
generated_token_id = self.sample_token(logits)
|
||||||
|
generated_token = self.tokenizer.decode(generated_token_id)
|
||||||
|
stop = self.should_stop(
|
||||||
|
num_tokens_processed=input_ids.shape[1] + 1,
|
||||||
|
generated_token_id = generated_token_id
|
||||||
|
)
|
||||||
|
if not stop:
|
||||||
|
all_stopped = False
|
||||||
|
|
||||||
|
# c) make generation
|
||||||
|
generations.append(Generation(
|
||||||
|
request_id=request.id,
|
||||||
|
token=generated_token,
|
||||||
|
token_id=generated_token_id,
|
||||||
|
stopped=stop
|
||||||
|
))
|
||||||
|
|
||||||
|
# d) update batch
|
||||||
|
# TODO: this does not occur in place)
|
||||||
|
assert len(batch.input_ids_list[i].shape) == 2
|
||||||
|
assert batch.input_ids_list[i].shape[0] == 1
|
||||||
|
batch.input_ids_list[i] = np.append(
|
||||||
|
batch.input_ids_list[i],
|
||||||
|
np.array([[generated_token_id]]),
|
||||||
|
axis=1
|
||||||
|
)
|
||||||
|
batch.past_key_values_list[i] = past_key_values
|
||||||
|
|
||||||
|
# if all elements of the batch are done, return null for batch
|
||||||
|
if all_stopped:
|
||||||
|
return generations, None
|
||||||
|
|
||||||
|
# return generation + updated batch
|
||||||
|
return generations, batch
|
241
server/deepsparse/service/model.py
Normal file
241
server/deepsparse/service/model.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
import os
|
||||||
|
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
|
||||||
|
from deepsparse import Context
|
||||||
|
from deepsparse.engine import LIB
|
||||||
|
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
|
||||||
|
from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask
|
||||||
|
|
||||||
|
PAST_KEY_VALUES_NAME = "past_key_values"
|
||||||
|
|
||||||
|
class DeepSparsePastKeyValues:
|
||||||
|
def __init__(self):
|
||||||
|
prev_num_tokens = 0
|
||||||
|
num_frozen_tokens = 1
|
||||||
|
self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
|
||||||
|
|
||||||
|
class DeepSparseDecoderEngine:
|
||||||
|
def __init__ (
|
||||||
|
self,
|
||||||
|
onnx_file_path: str,
|
||||||
|
sequence_length: int = 1024,
|
||||||
|
input_ids_length: int = 1,
|
||||||
|
engine_context: Optional[Context] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
# setup ONNX graph
|
||||||
|
onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
|
||||||
|
onnx_file_path=onnx_file_path,
|
||||||
|
batch_size=1,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
input_ids_length=input_ids_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compile engine
|
||||||
|
self.engine = create_engine(
|
||||||
|
onnx_file_path=onnx_file_path,
|
||||||
|
engine_type=DEEPSPARSE_ENGINE,
|
||||||
|
engine_args={"cached_outputs": cached_outputs},
|
||||||
|
context=engine_context,
|
||||||
|
)
|
||||||
|
print(self.engine)
|
||||||
|
|
||||||
|
# save utilties
|
||||||
|
self.past_key_value_dtype = data_type
|
||||||
|
self.onnx_inputs = self.engine.input_names
|
||||||
|
self.empty_past_key_values = self.make_empty_past_key_values()
|
||||||
|
|
||||||
|
# forward function
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
engine_inputs: Dict[str, np.ndarray],
|
||||||
|
past_key_values: DeepSparsePastKeyValues,
|
||||||
|
val_inputs: bool = True
|
||||||
|
):
|
||||||
|
# format input into lists (we pass empty past key values)
|
||||||
|
inputs = [
|
||||||
|
self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
|
||||||
|
else engine_inputs[name] for name in self.engine.input_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# validate inputs formatted correctly
|
||||||
|
if val_inputs:
|
||||||
|
self.engine._validate_inputs(inputs)
|
||||||
|
|
||||||
|
# run inference, updates past_key_values internally
|
||||||
|
output = self.engine._eng_net.execute_list_out(
|
||||||
|
inputs,
|
||||||
|
past_key_values.internal_past_key_values
|
||||||
|
)
|
||||||
|
logits = output[0]
|
||||||
|
return logits, past_key_values
|
||||||
|
|
||||||
|
# empty past kvs (dummy values to be passed around)
|
||||||
|
def make_empty_past_key_values(self):
|
||||||
|
past_key_values = {}
|
||||||
|
for idx, name in enumerate(self.onnx_inputs):
|
||||||
|
if name.startswith(PAST_KEY_VALUES_NAME):
|
||||||
|
past_key_values[name] = np.zeros(
|
||||||
|
self.engine.input_shapes[idx],
|
||||||
|
dtype=self.past_key_value_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
class DeepSparseDecoderModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
onnx_file_path: str,
|
||||||
|
sequence_length: int = 1024,
|
||||||
|
multitoken_length: int = 16,
|
||||||
|
engine_context: Optional[Context] = None,
|
||||||
|
):
|
||||||
|
self.sequence_length = sequence_length
|
||||||
|
self.multitoken_length = multitoken_length
|
||||||
|
|
||||||
|
# compile decode engine
|
||||||
|
self.singletoken_engine = DeepSparseDecoderEngine(
|
||||||
|
onnx_file_path=onnx_file_path,
|
||||||
|
engine_context=engine_context,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
input_ids_length=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compile prefill engine
|
||||||
|
self.multitoken_engine = DeepSparseDecoderEngine(
|
||||||
|
onnx_file_path=onnx_file_path,
|
||||||
|
engine_context=engine_context,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
input_ids_length=self.multitoken_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "input_ids" in self.singletoken_engine.onnx_inputs
|
||||||
|
assert "attention_mask" in self.singletoken_engine.onnx_inputs
|
||||||
|
assert "causal_mask" in self.singletoken_engine.onnx_inputs
|
||||||
|
assert "positions" in self.singletoken_engine.onnx_inputs
|
||||||
|
|
||||||
|
def engine_inputs_for_prefill(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
):
|
||||||
|
# split batch into N token_batches
|
||||||
|
num_batches = input_ids.shape[1] // self.multitoken_length
|
||||||
|
token_batches = [
|
||||||
|
input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
|
||||||
|
for i in range(0, num_batches)
|
||||||
|
]
|
||||||
|
|
||||||
|
# format inputs for each of the N token_batches
|
||||||
|
for idx, token_batch in enumerate(token_batches):
|
||||||
|
num_processed_tokens = self.multitoken_length * idx
|
||||||
|
|
||||||
|
engine_inputs = {}
|
||||||
|
engine_inputs["input_ids"] = token_batch
|
||||||
|
|
||||||
|
# make attention mask from the right
|
||||||
|
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
|
||||||
|
engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
|
||||||
|
|
||||||
|
# make positions (building from the right)
|
||||||
|
# TODO: handle case when multitoken engine is 1
|
||||||
|
assert self.multitoken_length > 1
|
||||||
|
engine_inputs["positions"] = np.arange(
|
||||||
|
num_processed_tokens, num_processed_tokens + self.multitoken_length
|
||||||
|
).reshape(1, -1).astype(np.int64)
|
||||||
|
|
||||||
|
# make causal mask (building from the right)
|
||||||
|
engine_inputs["causal_mask"] = create_causal_mask(
|
||||||
|
input_ids=engine_inputs["input_ids"],
|
||||||
|
attention_mask=engine_inputs["attention_mask"]
|
||||||
|
)
|
||||||
|
yield engine_inputs
|
||||||
|
|
||||||
|
def engine_inputs_for_decode(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
):
|
||||||
|
engine_inputs = {}
|
||||||
|
engine_inputs["input_ids"] = input_ids[:,-1:]
|
||||||
|
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
|
||||||
|
engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
|
||||||
|
|
||||||
|
engine_inputs["causal_mask"] = create_causal_mask(
|
||||||
|
engine_inputs["input_ids"],
|
||||||
|
engine_inputs["attention_mask"]
|
||||||
|
)
|
||||||
|
engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
|
||||||
|
|
||||||
|
return engine_inputs
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
past_key_values: DeepSparsePastKeyValues
|
||||||
|
) -> (np.ndarray, DeepSparsePastKeyValues):
|
||||||
|
|
||||||
|
# assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
|
||||||
|
assert len(input_ids.shape) == 2
|
||||||
|
assert input_ids.shape[0] == 1
|
||||||
|
assert input_ids.shape[1] < self.sequence_length
|
||||||
|
|
||||||
|
engine_inputs = self.engine_inputs_for_decode(input_ids)
|
||||||
|
logits, past_key_values = self.singletoken_engine(
|
||||||
|
engine_inputs,
|
||||||
|
past_key_values
|
||||||
|
)
|
||||||
|
|
||||||
|
return logits, past_key_values
|
||||||
|
|
||||||
|
def prefill(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
) -> (np.ndarray, DeepSparsePastKeyValues):
|
||||||
|
|
||||||
|
# assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
|
||||||
|
assert len(input_ids.shape) == 2
|
||||||
|
assert input_ids.shape[0] == 1
|
||||||
|
assert input_ids.shape[1] < self.sequence_length
|
||||||
|
|
||||||
|
tokens_processed = 0
|
||||||
|
|
||||||
|
# setup empty past key values
|
||||||
|
past_key_values = DeepSparsePastKeyValues()
|
||||||
|
|
||||||
|
# loop through chunks, run inference w/ multitoken engine
|
||||||
|
for engine_inputs in self.engine_inputs_for_prefill(input_ids):
|
||||||
|
logits, past_key_values = self.multitoken_engine(
|
||||||
|
engine_inputs,
|
||||||
|
past_key_values
|
||||||
|
)
|
||||||
|
tokens_processed += self.multitoken_length
|
||||||
|
|
||||||
|
# if anything left over, run inference w/ singletoken engine
|
||||||
|
while tokens_processed < input_ids.shape[1]:
|
||||||
|
logits, past_key_values = self.decode(
|
||||||
|
input_ids=input_ids[:,:tokens_processed+1],
|
||||||
|
past_key_values=past_key_values
|
||||||
|
)
|
||||||
|
tokens_processed += 1
|
||||||
|
# print(logits[:,-1:,:])
|
||||||
|
|
||||||
|
return logits, past_key_values
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
past_key_values: Optional[DeepSparsePastKeyValues] = None,
|
||||||
|
):
|
||||||
|
if past_key_values is None:
|
||||||
|
return self.prefill(input_ids)
|
||||||
|
else:
|
||||||
|
return self.decode(input_ids, past_key_values)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
past_key_values: Optional[DeepSparsePastKeyValues] = None,
|
||||||
|
):
|
||||||
|
return self.forward(input_ids, past_key_values)
|
76
server/deepsparse/service/service.py
Normal file
76
server/deepsparse/service/service.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from typing import Dict, List, Tuple
|
||||||
|
from server.deepsparse.service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
|
||||||
|
from server.deepsparse.utils import Generation, CachedBatch, Batch
|
||||||
|
|
||||||
|
class BatchCache:
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: Dict[int, DeepSparseCausalLMBatch] = {}
|
||||||
|
|
||||||
|
def pop(self, batch_id: int) -> DeepSparseCausalLMBatch:
|
||||||
|
batch = self.cache.pop(batch_id, None)
|
||||||
|
assert batch is not None, "Batch ID {batch_id} not found in cache."
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def set(self, entry: DeepSparseCausalLMBatch):
|
||||||
|
if entry is not None:
|
||||||
|
self.cache[entry.batch_id] = entry
|
||||||
|
|
||||||
|
def delete(self, batch_id: int):
|
||||||
|
batch = self.pop(batch_id)
|
||||||
|
if batch is not None:
|
||||||
|
del batch
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
keys = list(self.cache.keys())
|
||||||
|
for k in keys:
|
||||||
|
self.delete(k)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.cache.keys())
|
||||||
|
|
||||||
|
class DeepSparseService:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: DeepSparseCausalLM
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.cache = BatchCache()
|
||||||
|
|
||||||
|
def ClearCache(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def FilterBatch(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
|
||||||
|
ds_batch = self.cache.pop(batch_id)
|
||||||
|
filtered_ds_batch = ds_batch.filter(request_ids)
|
||||||
|
self.cache.set(filtered_ds_batch)
|
||||||
|
|
||||||
|
return filtered_ds_batch.to_cached_batch()
|
||||||
|
|
||||||
|
def Prefill(self, batch: Batch) -> Tuple[Generation, CachedBatch]:
|
||||||
|
ds_batch = DeepSparseCausalLMBatch.from_batch(
|
||||||
|
batch=batch,
|
||||||
|
tokenizer=self.model.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
generations, next_ds_batch = self.model.generate_token(ds_batch)
|
||||||
|
assert len(generations) == 1
|
||||||
|
self.cache.set(next_ds_batch)
|
||||||
|
|
||||||
|
return generations[0], next_ds_batch.to_cached_batch()
|
||||||
|
|
||||||
|
def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
|
||||||
|
assert len(batches) != 0, "Must provide at least one batch"
|
||||||
|
|
||||||
|
ds_batches = []
|
||||||
|
for cached_batch in batches:
|
||||||
|
ds_batches.append(self.cache.pop(cached_batch.batch_id))
|
||||||
|
|
||||||
|
if len(ds_batches) > 1:
|
||||||
|
ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
|
||||||
|
else:
|
||||||
|
ds_batch = ds_batches[0]
|
||||||
|
|
||||||
|
generations, next_ds_batch = self.model.generate_token(ds_batch)
|
||||||
|
self.cache.set(next_ds_batch)
|
||||||
|
|
||||||
|
return generations, (next_ds_batch.to_cached_batch() if next_ds_batch else None)
|
35
server/deepsparse/utils.py
Normal file
35
server/deepsparse/utils.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from queue import Queue
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Request:
|
||||||
|
id: int
|
||||||
|
prompt: str
|
||||||
|
max_generated_tokens: int
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Batch:
|
||||||
|
id: int
|
||||||
|
requests: List[Request]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedBatch:
|
||||||
|
batch_id: int
|
||||||
|
request_ids: List[int]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.request_ids)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Generation:
|
||||||
|
request_id: int
|
||||||
|
token: Optional[str]
|
||||||
|
token_id: Optional[str]
|
||||||
|
stopped: bool
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerateRequest:
|
||||||
|
prompt: str
|
||||||
|
max_generated_tokens: int
|
||||||
|
response_stream: Queue[Generation]
|
Loading…
Reference in New Issue
Block a user