Re-added changed file names

rsnm2 2023-08-24 17:50:01 +00:00
parent 06fc85f93c
commit cd3349f53b
6 changed files with 752 additions and 0 deletions

server/deepsparse/router.py Normal file
@@ -0,0 +1,161 @@
from queue import Queue
from typing import List, Dict, Optional, Tuple
from server.deepsparse.service.service import DeepSparseService
from server.deepsparse.utils import CachedBatch, Batch, Generation, GenerateRequest, Request
# TODO: implement logic for maximum size of the queue based on memory usage
class DeepSparseQueue:
def __init__(self):
self.next_request_id: int = 0
self.next_batch_id: int = 0
self.queue: Queue[GenerateRequest] = Queue()
def append(self, generate_request: GenerateRequest):
self.queue.put(generate_request)
# TODO: enable multiple prefill requests in a batch
def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
# if not blocking, return None if the queue is empty
if not block and self.queue.empty():
return None
# if block=True, this call blocks until an item is ready
# if block=False, the queue is known to have data (otherwise get() raises an exception);
# queue.empty() == False does not guarantee data in general, but this queue is only
# consumed by one thread (this one), since batching_task is the only caller of next_batch
generate_request = self.queue.get(block=block)
generate_requests = {self.next_request_id: generate_request}
# format into request
request = Request(
id=self.next_request_id,
prompt=generate_request.prompt,
max_generated_tokens=generate_request.max_generated_tokens
)
self.next_request_id += 1
# format into batch
batch = Batch(
id = self.next_batch_id,
requests=[request]
)
self.next_batch_id += 1
# return batch, generate_requests
return (batch, generate_requests)
class DeepSparseRouter:
def __init__(self, service: DeepSparseService):
self.service: DeepSparseService = service
self.queue: DeepSparseQueue = DeepSparseQueue()
def generate(self, generate_request: GenerateRequest):
self.queue.append(generate_request)
def prefill(
self,
batch: Batch,
generate_requests: Dict[int, GenerateRequest]
) -> Optional[CachedBatch]:
generation, next_batch = self.service.Prefill(batch=batch)
active_generate_request_ids = self.filter_send_generations([generation], generate_requests)
return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)
def decode(
self,
batches: List[CachedBatch],
generate_requests: Dict[int,GenerateRequest]
) -> Optional[CachedBatch]:
generations, next_batch = self.service.Decode(batches=batches)
active_generate_request_ids = self.filter_send_generations(generations, generate_requests)
return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)
def filter_send_generations(
self,
generations: List[Generation],
generate_requests: Dict[int, GenerateRequest]
) -> List[int]:
active_request_ids = []
for generation in generations:
# send generation to the response stream
generate_requests[generation.request_id].response_stream.put(generation)
# remove request from active requests if stopped
if generation.stopped:
generate_requests.pop(generation.request_id)
else:
active_request_ids.append(generation.request_id)
return active_request_ids
def filter_batch(
self,
batch: Optional[CachedBatch],
active_generate_request_ids: List[int]
) -> Optional[CachedBatch]:
# if batch done OR nothing to filter
if batch is None or len(batch) == len(active_generate_request_ids):
return batch
# active request_ids
batch.request_ids = active_generate_request_ids
# if all requests complete, clear cache
if len(batch) == 0:
self.service.ClearCache()
return None
return self.service.FilterBatch(batch_id=batch.batch_id, request_ids=batch.request_ids)
# TODO: update to do more sophisticated logic as to when to do a prefill
def batching_task(router: DeepSparseRouter):
while True:
# loop until no requests to process (note: this blocks if queue is empty)
next_batch = router.queue.next_batch(block=True)
while next_batch is not None:
batch, generate_requests = next_batch
# HACK for development: a request with prompt == "stop" breaks out of the loop
if batch.requests[0].prompt == "stop":
return
# run prefill
cached_batch = router.prefill(
batch=batch,
generate_requests=generate_requests
)
# loop until we do not receive any cached batch from the service
# == until all active requests have met their stopping criteria
while cached_batch is not None:
batches = [cached_batch]
# try to get a new batch and run prefill on this batch
next_batch = router.queue.next_batch(block=False)
if next_batch is not None:
new_batch, new_generate_requests = next_batch
new_cached_batch = router.prefill(
batch=new_batch,
generate_requests=new_generate_requests
)
if new_cached_batch is not None:
batches.append(new_cached_batch)
assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
generate_requests.update(new_generate_requests)
# run decode
cached_batch = router.decode(
batches=batches,
generate_requests=generate_requests
)
next_batch = router.queue.next_batch(block=False)
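
The snippet below is a minimal sketch of how the pieces in this commit might be wired together: a DeepSparseRouter fed by a background batching_task thread, with a client submitting a GenerateRequest and draining its response_stream. The model and tokenizer paths are placeholders and the client loop is illustrative, not part of this commit; the final "stop" request relies on the development HACK in batching_task above.

from queue import Queue
from threading import Thread

from server.deepsparse.router import DeepSparseRouter, batching_task
from server.deepsparse.service.service import DeepSparseService
from server.deepsparse.service.causal_lm import DeepSparseCausalLM
from server.deepsparse.utils import GenerateRequest

# build the model, service, and router (paths are assumptions)
model = DeepSparseCausalLM(
    model_path="/path/to/model.onnx",
    tokenizer_path="/path/to/tokenizer",
)
router = DeepSparseRouter(service=DeepSparseService(model=model))

# batching_task blocks on the queue, so run it on its own thread
worker = Thread(target=batching_task, args=(router,))
worker.start()

# submit a request and stream generations until it stops
request = GenerateRequest(
    prompt="def fibonacci(n):",
    max_generated_tokens=64,
    response_stream=Queue(),
)
router.generate(request)

text = ""
while True:
    generation = request.response_stream.get()
    text += generation.token or ""
    if generation.stopped:
        break

# shut the worker down via the development-only "stop" prompt
router.generate(GenerateRequest(prompt="stop", max_generated_tokens=1, response_stream=Queue()))
worker.join()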

server/deepsparse/service/causal_lm.py Normal file
@@ -0,0 +1,239 @@
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from transformers import AutoTokenizer, PreTrainedTokenizerBase
import numpy as np
from server.deepsparse.service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
from server.deepsparse.utils import Request, Batch, CachedBatch, Generation
DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@dataclass
class DeepSparseCausalLMBatch:
batch_id: int
requests: List[Request]
requests_idx_mapping: Dict[int,int]
input_ids_list: List[np.ndarray]
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
@classmethod
def from_batch(
cls,
batch: Batch,
tokenizer: PreTrainedTokenizerBase,
) -> "DeepSparseCausalLMBatch":
# parse batch
requests_idx_mapping = {}
input_ids_list = []
# setup tokenizer for deepsparse left padding
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
padding, truncation = "longest", False
# loop through items in the batch
for idx, r in enumerate(batch.requests):
requests_idx_mapping[r.id] = idx
# setup inputs_ids, past_key_values
tokenized_inputs = tokenizer(
r.prompt,
return_tensors="np",
padding=padding,
truncation=truncation,
return_token_type_ids=False,
max_length=DEEPSPARSE_SEQUENCE_LENGTH
)
input_ids_list.append(tokenized_inputs["input_ids"])
return cls(
batch_id=batch.id,
requests=batch.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=[None] * len(batch.requests),
)
def to_cached_batch(self) -> CachedBatch:
return CachedBatch(
batch_id = self.batch_id,
request_ids=[r.id for r in self.requests],
)
# length of the batch
def __len__(self):
return len(self.requests)
# pass list of request ids, returns batch with only those request ids
def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
assert(len(request_ids) > 0)
requests_idx_mapping = {}
requests = []
input_ids_list = []
past_key_values_list = []
# loop through requests, keep ones that should remain
for new_idx, request_id in enumerate(request_ids):
assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch"
requests_idx_mapping[request_id] = new_idx
old_idx = self.requests_idx_mapping[request_id]
requests.append(self.requests[old_idx])
input_ids_list.append(self.input_ids_list[old_idx])
past_key_values_list.append(self.past_key_values_list[old_idx])
# update batch state
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids_list = input_ids_list
self.past_key_values_list = past_key_values_list
return self
# combine two batches into one
@classmethod
def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
assert len(batches) > 1, "must have more than 1 batch to concatenate"
requests_idx_mapping = {}
requests = []
input_ids_list = []
past_key_values_list = []
start_index = 0
for i, batch in enumerate(batches):
assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
# concatenate request, input_ids, and past_key_values lists
requests.extend(batch.requests)
input_ids_list.extend(batch.input_ids_list)
past_key_values_list.extend(batch.past_key_values_list)
# merge the request_id to index mapping
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
start_index += len(batch)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=past_key_values_list
)
class DeepSparseCausalLM:
def __init__(
self,
model_path: str,
tokenizer_path: str,
):
# setup tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.tokenizer.padding_side = "left"
if not self.tokenizer.pad_token:
assert self.tokenizer.eos_token
self.tokenizer.pad_token = self.tokenizer.eos_token
# setup model
self.model = DeepSparseDecoderModel(
onnx_file_path = model_path,
sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
)
# TODO: switch to NextTokenChooser
def sample_token(
self,
logits: np.ndarray
):
# assert b=1 for now
assert(logits.shape[0] == 1)
# grab logits for the last item in the sequence
# shape == (batch, seq, vocabulary_size)
return np.argmax(logits[0,-1,:])
# TODO: switch to StoppingCriteria
def should_stop(
self,
num_tokens_processed: int,
generated_token_id: int
):
if num_tokens_processed >= self.model.sequence_length:
return True
if generated_token_id == self.tokenizer.eos_token_id:
return True
return False
def generate_token(
self,
batch: DeepSparseCausalLMBatch,
) -> Tuple[List[Generation], Optional[DeepSparseCausalLMBatch]]:
generations: List[Generation] = []
all_stopped = True
# for each member of the batch:
# a) run inference
# b) sample and check stopping criteria
# c) create generation
# d) update batch
for i, (request, input_ids, past_key_values,) in enumerate(
zip(
batch.requests,
batch.input_ids_list,
batch.past_key_values_list
)
):
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
# a) run inference
logits, past_key_values = self.model(input_ids, past_key_values)
# b) sample token and check stopping criteria
# TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
generated_token_id = self.sample_token(logits)
generated_token = self.tokenizer.decode(generated_token_id)
stop = self.should_stop(
num_tokens_processed=input_ids.shape[1] + 1,
generated_token_id = generated_token_id
)
if not stop:
all_stopped = False
# c) make generation
generations.append(Generation(
request_id=request.id,
token=generated_token,
token_id=generated_token_id,
stopped=stop
))
# d) update batch
# TODO: this update does not occur in place
assert len(batch.input_ids_list[i].shape) == 2
assert batch.input_ids_list[i].shape[0] == 1
batch.input_ids_list[i] = np.append(
batch.input_ids_list[i],
np.array([[generated_token_id]]),
axis=1
)
batch.past_key_values_list[i] = past_key_values
# if all elements of the batch are done, return null for batch
if all_stopped:
return generations, None
# return generation + updated batch
return generations, batch
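
As a usage sketch (not part of this commit), DeepSparseCausalLM can also be driven directly, without the router or service: build a DeepSparseCausalLMBatch from a Batch, then call generate_token until it returns None for the batch. The model and tokenizer paths below are assumptions, and the module path follows the imports used elsewhere in this commit. Note that stopping currently depends only on the EOS token or the sequence-length limit; max_generated_tokens is not yet enforced (see the TODOs above).

from server.deepsparse.service.causal_lm import (
    DeepSparseCausalLM,
    DeepSparseCausalLMBatch,
)
from server.deepsparse.utils import Batch, Request

model = DeepSparseCausalLM(
    model_path="/path/to/model.onnx",      # assumed path
    tokenizer_path="/path/to/tokenizer",   # assumed path
)

batch = Batch(id=0, requests=[Request(id=0, prompt="Hello", max_generated_tokens=32)])
lm_batch = DeepSparseCausalLMBatch.from_batch(batch=batch, tokenizer=model.tokenizer)

# each call runs one forward pass per request: prefill on the first call
# (past_key_values is None), single-token decode afterwards
tokens = []
while lm_batch is not None:
    generations, lm_batch = model.generate_token(lm_batch)
    tokens.extend(g.token for g in generations if g.token and not g.stopped)

print("".join(tokens))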

server/deepsparse/service/model.py Normal file
@@ -0,0 +1,241 @@
import os
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
import numpy as np
from typing import Optional, List, Dict, Tuple
from deepsparse import Context
from deepsparse.engine import LIB
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask
PAST_KEY_VALUES_NAME = "past_key_values"
class DeepSparsePastKeyValues:
def __init__(self):
prev_num_tokens = 0
num_frozen_tokens = 1
self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
class DeepSparseDecoderEngine:
def __init__ (
self,
onnx_file_path: str,
sequence_length: int = 1024,
input_ids_length: int = 1,
engine_context: Optional[Context] = None,
):
# setup ONNX graph
onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
batch_size=1,
sequence_length=sequence_length,
input_ids_length=input_ids_length,
)
# compile engine
self.engine = create_engine(
onnx_file_path=onnx_file_path,
engine_type=DEEPSPARSE_ENGINE,
engine_args={"cached_outputs": cached_outputs},
context=engine_context,
)
print(self.engine)
# save utilities
self.past_key_value_dtype = data_type
self.onnx_inputs = self.engine.input_names
self.empty_past_key_values = self.make_empty_past_key_values()
# forward function
def __call__(
self,
engine_inputs: Dict[str, np.ndarray],
past_key_values: DeepSparsePastKeyValues,
val_inputs: bool = True
):
# format inputs into a list (dummy empty past key values are passed; the real KV cache is managed internally by the engine)
inputs = [
self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
else engine_inputs[name] for name in self.engine.input_names
]
# validate inputs formatted correctly
if val_inputs:
self.engine._validate_inputs(inputs)
# run inference, updates past_key_values internally
output = self.engine._eng_net.execute_list_out(
inputs,
past_key_values.internal_past_key_values
)
logits = output[0]
return logits, past_key_values
# empty past kvs (dummy values to be passed around)
def make_empty_past_key_values(self):
past_key_values = {}
for idx, name in enumerate(self.onnx_inputs):
if name.startswith(PAST_KEY_VALUES_NAME):
past_key_values[name] = np.zeros(
self.engine.input_shapes[idx],
dtype=self.past_key_value_dtype
)
return past_key_values
class DeepSparseDecoderModel:
def __init__(
self,
onnx_file_path: str,
sequence_length: int = 1024,
multitoken_length: int = 16,
engine_context: Optional[Context] = None,
):
self.sequence_length = sequence_length
self.multitoken_length = multitoken_length
# compile decode engine
self.singletoken_engine = DeepSparseDecoderEngine(
onnx_file_path=onnx_file_path,
engine_context=engine_context,
sequence_length=sequence_length,
input_ids_length=1,
)
# compile prefill engine
self.multitoken_engine = DeepSparseDecoderEngine(
onnx_file_path=onnx_file_path,
engine_context=engine_context,
sequence_length=sequence_length,
input_ids_length=self.multitoken_length,
)
assert "input_ids" in self.singletoken_engine.onnx_inputs
assert "attention_mask" in self.singletoken_engine.onnx_inputs
assert "causal_mask" in self.singletoken_engine.onnx_inputs
assert "positions" in self.singletoken_engine.onnx_inputs
def engine_inputs_for_prefill(
self,
input_ids: np.ndarray,
):
# split batch into N token_batches
num_batches = input_ids.shape[1] // self.multitoken_length
token_batches = [
input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
for i in range(0, num_batches)
]
# format inputs for each of the N token_batches
for idx, token_batch in enumerate(token_batches):
num_processed_tokens = self.multitoken_length * idx
engine_inputs = {}
engine_inputs["input_ids"] = token_batch
# make attention mask from the right
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
# make positions (building from the right)
# TODO: handle case when multitoken engine is 1
assert self.multitoken_length > 1
engine_inputs["positions"] = np.arange(
num_processed_tokens, num_processed_tokens + self.multitoken_length
).reshape(1, -1).astype(np.int64)
# make causal mask (building from the right)
engine_inputs["causal_mask"] = create_causal_mask(
input_ids=engine_inputs["input_ids"],
attention_mask=engine_inputs["attention_mask"]
)
yield engine_inputs
def engine_inputs_for_decode(
self,
input_ids: np.ndarray,
):
engine_inputs = {}
engine_inputs["input_ids"] = input_ids[:,-1:]
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
engine_inputs["causal_mask"] = create_causal_mask(
engine_inputs["input_ids"],
engine_inputs["attention_mask"]
)
engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
return engine_inputs
def decode(
self,
input_ids: np.ndarray,
past_key_values: DeepSparsePastKeyValues
) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
# assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
assert input_ids.shape[1] < self.sequence_length
engine_inputs = self.engine_inputs_for_decode(input_ids)
logits, past_key_values = self.singletoken_engine(
engine_inputs,
past_key_values
)
return logits, past_key_values
def prefill(
self,
input_ids: np.ndarray,
) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
# assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
assert input_ids.shape[1] < self.sequence_length
tokens_processed = 0
# setup empty past key values
past_key_values = DeepSparsePastKeyValues()
# loop through chunks, run inference w/ multitoken engine
for engine_inputs in self.engine_inputs_for_prefill(input_ids):
logits, past_key_values = self.multitoken_engine(
engine_inputs,
past_key_values
)
tokens_processed += self.multitoken_length
# if anything left over, run inference w/ singletoken engine
while tokens_processed < input_ids.shape[1]:
logits, past_key_values = self.decode(
input_ids=input_ids[:,:tokens_processed+1],
past_key_values=past_key_values
)
tokens_processed += 1
# print(logits[:,-1:,:])
return logits, past_key_values
def forward(
self,
input_ids: np.ndarray,
past_key_values: Optional[DeepSparsePastKeyValues] = None,
):
if past_key_values is None:
return self.prefill(input_ids)
else:
return self.decode(input_ids, past_key_values)
def __call__(
self,
input_ids: np.ndarray,
past_key_values: Optional[DeepSparsePastKeyValues] = None,
):
return self.forward(input_ids, past_key_values)
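
The loop below is a minimal, token-level sketch of how DeepSparseDecoderModel might be used on its own: the first call (no past_key_values) dispatches to prefill, and subsequent calls reuse the KV cache through the single-token decode engine. The tokenizer and ONNX model path are assumptions, not part of this commit.

import numpy as np
from transformers import AutoTokenizer

from server.deepsparse.service.model import DeepSparseDecoderModel

tokenizer = AutoTokenizer.from_pretrained("/path/to/tokenizer")  # assumed path
model = DeepSparseDecoderModel(
    onnx_file_path="/path/to/model.onnx",  # assumed path
    sequence_length=128,
    multitoken_length=4,
)

# shape (1, seq_len), matching the batch-size-1 assumption above
input_ids = tokenizer("Hello", return_tensors="np")["input_ids"]

# first call: past_key_values is None, so forward() runs prefill()
logits, past_key_values = model(input_ids)

for _ in range(16):
    next_id = int(np.argmax(logits[0, -1, :]))
    if next_id == tokenizer.eos_token_id:
        break
    input_ids = np.append(input_ids, [[next_id]], axis=1)
    # later calls reuse the KV cache via the single-token decode engine
    logits, past_key_values = model(input_ids, past_key_values)

print(tokenizer.decode(input_ids[0]))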

server/deepsparse/service/service.py Normal file
@@ -0,0 +1,76 @@
from typing import Dict, List, Optional, Tuple
from server.deepsparse.service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
from server.deepsparse.utils import Generation, CachedBatch, Batch
class BatchCache:
def __init__(self):
self.cache: Dict[int, DeepSparseCausalLMBatch] = {}
def pop(self, batch_id: int) -> DeepSparseCausalLMBatch:
batch = self.cache.pop(batch_id, None)
assert batch is not None, f"Batch ID {batch_id} not found in cache."
return batch
def set(self, entry: DeepSparseCausalLMBatch):
if entry is not None:
self.cache[entry.batch_id] = entry
def delete(self, batch_id: int):
batch = self.pop(batch_id)
if batch is not None:
del batch
def clear(self):
keys = list(self.cache.keys())
for k in keys:
self.delete(k)
def __len__(self):
return len(self.cache.keys())
class DeepSparseService:
def __init__(
self,
model: DeepSparseCausalLM
):
self.model = model
self.cache = BatchCache()
def ClearCache(self):
self.cache.clear()
def FilterBatch(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
ds_batch = self.cache.pop(batch_id)
filtered_ds_batch = ds_batch.filter(request_ids)
self.cache.set(filtered_ds_batch)
return filtered_ds_batch.to_cached_batch()
def Prefill(self, batch: Batch) -> Tuple[Generation, Optional[CachedBatch]]:
ds_batch = DeepSparseCausalLMBatch.from_batch(
batch=batch,
tokenizer=self.model.tokenizer
)
generations, next_ds_batch = self.model.generate_token(ds_batch)
assert len(generations) == 1
self.cache.set(next_ds_batch)
return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)
def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], Optional[CachedBatch]]:
assert len(batches) != 0, "Must provide at least one batch"
ds_batches = []
for cached_batch in batches:
ds_batches.append(self.cache.pop(cached_batch.batch_id))
if len(ds_batches) > 1:
ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
else:
ds_batch = ds_batches[0]
generations, next_ds_batch = self.model.generate_token(ds_batch)
self.cache.set(next_ds_batch)
return generations, (next_ds_batch.to_cached_batch() if next_ds_batch else None)
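
For illustration (not part of this commit), the call sequence the router drives can also be exercised against the service directly: Prefill processes the prompt and caches the batch under its batch_id, and each Decode call pops the cached batch, runs one token step, and re-caches the result until every request has stopped. Model and tokenizer paths are assumptions.

from server.deepsparse.service.service import DeepSparseService
from server.deepsparse.service.causal_lm import DeepSparseCausalLM
from server.deepsparse.utils import Batch, Request

service = DeepSparseService(
    model=DeepSparseCausalLM(
        model_path="/path/to/model.onnx",      # assumed path
        tokenizer_path="/path/to/tokenizer",   # assumed path
    )
)

batch = Batch(id=0, requests=[Request(id=0, prompt="Hello", max_generated_tokens=32)])

# prefill the prompt; the returned CachedBatch is a handle into the BatchCache
generation, cached_batch = service.Prefill(batch=batch)
generations = [generation]

# decode one token per call until the service stops returning a cached batch
while cached_batch is not None:
    step_generations, cached_batch = service.Decode(batches=[cached_batch])
    generations.extend(step_generations)

print("".join(g.token for g in generations if g.token and not g.stopped))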

server/deepsparse/utils.py Normal file
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from queue import Queue
from typing import List, Optional
@dataclass
class Request:
id: int
prompt: str
max_generated_tokens: int
@dataclass
class Batch:
id: int
requests: List[Request]
@dataclass
class CachedBatch:
batch_id: int
request_ids: List[int]
def __len__(self):
return len(self.request_ids)
@dataclass
class Generation:
request_id: int
token: Optional[str]
token_id: Optional[int]
stopped: bool
@dataclass
class GenerateRequest:
prompt: str
max_generated_tokens: int
response_stream: Queue[Generation]
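
A small sketch (illustrative only) of how these dataclasses relate across the stack: the router turns a client-facing GenerateRequest into a server-side Request wrapped in a Batch, the service answers with a lightweight CachedBatch that only tracks ids, and per-token Generation results flow back to the client through the request's response_stream.

from queue import Queue

from server.deepsparse.utils import (
    Batch,
    CachedBatch,
    GenerateRequest,
    Generation,
    Request,
)

# a client-facing request plus its private response stream
generate_request = GenerateRequest(
    prompt="Hello",
    max_generated_tokens=32,
    response_stream=Queue(),
)

# the router assigns ids and wraps the request in a Batch (see router.py)
request = Request(
    id=0,
    prompt=generate_request.prompt,
    max_generated_tokens=generate_request.max_generated_tokens,
)
batch = Batch(id=0, requests=[request])

# the service responds with a CachedBatch handle that only tracks ids
cached = CachedBatch(batch_id=batch.id, request_ids=[r.id for r in batch.requests])
assert len(cached) == 1

# per-token results are pushed back through the response stream
generate_request.response_stream.put(
    Generation(request_id=request.id, token="Hi", token_id=17, stopped=False)
)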