diff --git a/server/deepsparse/router.py b/server/deepsparse/router.py
new file mode 100644
index 00000000..950bea68
--- /dev/null
+++ b/server/deepsparse/router.py
@@ -0,0 +1,161 @@
+from queue import Queue
+from typing import List, Dict, Optional, Tuple
+from server.deepsparse.service.service import DeepSparseService
+from server.deepsparse.utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+
+# TODO: implement logic for maximum size of the queue based on memory usage
+class DeepSparseQueue:
+    def __init__(self):
+        self.next_request_id: int = 0
+        self.next_batch_id: int = 0
+        self.queue: Queue[GenerateRequest] = Queue()
+
+    def append(self, generate_request: GenerateRequest):
+        self.queue.put(generate_request)
+
+    # TODO: enable multiple prefill requests in a batch
+    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+
+        # if not blocking, return None if empty
+        if not block and self.queue.empty():
+            return None
+
+        # if block = True, this blocks until something is ready
+        # if block = False, the queue has data (if not, an exception is raised)
+        #   while queue.empty() == False does not guarantee data in general,
+        #   the queue is only consumed by one thread (this one),
+        #   since batching_task is the only function that calls next_batch
+        generate_request = self.queue.get(block=block)
+        generate_requests = {self.next_request_id: generate_request}
+
+        # format into request
+        request = Request(
+            id=self.next_request_id,
+            prompt=generate_request.prompt,
+            max_generated_tokens=generate_request.max_generated_tokens
+        )
+        self.next_request_id += 1
+
+        # format into batch
+        batch = Batch(
+            id=self.next_batch_id,
+            requests=[request]
+        )
+        self.next_batch_id += 1
+
+        # return batch, generate_requests
+        return (batch, generate_requests)
+
+class DeepSparseRouter:
+    def __init__(self, service: DeepSparseService):
+        self.service: DeepSparseService = service
+        self.queue: DeepSparseQueue = DeepSparseQueue()
+
+    def generate(self, generate_request: GenerateRequest):
+        self.queue.append(generate_request)
+
+    def prefill(
+        self,
+        batch: Batch,
+        generate_requests: Dict[int, GenerateRequest]
+    ) -> Optional[CachedBatch]:
+
+        generation, next_batch = self.service.Prefill(batch=batch)
+        active_generate_request_ids = self.filter_send_generations([generation], generate_requests)
+        return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)
+
+    def decode(
+        self,
+        batches: List[CachedBatch],
+        generate_requests: Dict[int, GenerateRequest]
+    ) -> Optional[CachedBatch]:
+
+        generations, next_batch = self.service.Decode(batches=batches)
+        active_generate_request_ids = self.filter_send_generations(generations, generate_requests)
+        return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)
+
+    def filter_send_generations(
+        self,
+        generations: List[Generation],
+        generate_requests: Dict[int, GenerateRequest]
+    ) -> List[int]:
+
+        active_request_ids = []
+        for generation in generations:
+            # send generation to the response stream
+            generate_requests[generation.request_id].response_stream.put(generation)
+
+            # remove request from active requests if stopped
+            if generation.stopped:
+                generate_requests.pop(generation.request_id)
+            else:
+                active_request_ids.append(generation.request_id)
+
+        return active_request_ids
+
+    def filter_batch(
+        self,
+        batch: Optional[CachedBatch],
+        active_generate_request_ids: List[int]
+    ) -> Optional[CachedBatch]:
+
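+        # Reconcile the service-side cache with the requests that are still active
+        # after sending generations back to their response streams. Three outcomes:
+        #   - the service returned no batch (everything finished): pass None through
+        #   - every request is still active: return the batch unchanged
+        #   - some requests finished: filter the service's cached batch down to the
+        #     active request ids, or clear the cache entirely if none remain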
+        # if batch done OR nothing to filter
+        if batch is None or len(batch) == len(active_generate_request_ids):
+            return batch
+
+        # active request_ids
+        batch.request_ids = active_generate_request_ids
+
+        # if all requests complete, clear cache
+        if len(batch) == 0:
+            self.service.ClearCache()
+            return None
+
+        return self.service.FilterBatch(batch_id=batch.batch_id, request_ids=batch.request_ids)
+
+
+# TODO: update to do more sophisticated logic as to when to do a prefill
+def batching_task(router: DeepSparseRouter):
+    while True:
+        # loop until no requests to process (note: this blocks if queue is empty)
+        next_batch = router.queue.next_batch(block=True)
+        while next_batch is not None:
+            batch, generate_requests = next_batch
+
+            # HACK for development --- breaks out of the cycle
+            if batch.requests[0].prompt == "stop":
+                return
+
+            # run prefill
+            cached_batch = router.prefill(
+                batch=batch,
+                generate_requests=generate_requests
+            )
+
+            # loop until we do not receive any cached batch from the service
+            # == until all active requests have met their stopping criteria
+            while cached_batch is not None:
+                batches = [cached_batch]
+
+                # try to get a new batch and run prefill on this batch
+                next_batch = router.queue.next_batch(block=False)
+
+                if next_batch is not None:
+                    new_batch, new_generate_requests = next_batch
+                    new_cached_batch = router.prefill(
+                        batch=new_batch,
+                        generate_requests=new_generate_requests
+                    )
+
+                    if new_cached_batch is not None:
+                        batches.append(new_cached_batch)
+                        assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
+                        generate_requests.update(new_generate_requests)
+
+                # run decode
+                cached_batch = router.decode(
+                    batches=batches,
+                    generate_requests=generate_requests
+                )
+
+            next_batch = router.queue.next_batch(block=False)
\ No newline at end of file
diff --git a/server/deepsparse/server.py b/server/deepsparse/server.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/deepsparse/service/causal_lm.py b/server/deepsparse/service/causal_lm.py
new file mode 100644
index 00000000..4e5cde2f
--- /dev/null
+++ b/server/deepsparse/service/causal_lm.py
@@ -0,0 +1,239 @@
+from dataclasses import dataclass
+from typing import List, Dict, Optional
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+import numpy as np
+
+from server.deepsparse.service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
+from server.deepsparse.utils import Request, Batch, CachedBatch, Generation
+
+DEEPSPARSE_SEQUENCE_LENGTH = 128
+DEEPSPARSE_MULTITOKEN_LENGTH = 4
+
+@dataclass
+class DeepSparseCausalLMBatch:
+    batch_id: int
+    requests: List[Request]
+    requests_idx_mapping: Dict[int, int]
+    input_ids_list: List[np.ndarray]
+    past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
+
+    @classmethod
+    def from_batch(
+        cls,
+        batch: Batch,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> "DeepSparseCausalLMBatch":
+
+        # parse batch
+        requests_idx_mapping = {}
+        input_ids_list = []
+
+        # setup tokenizer for deepsparse left padding
+        tokenizer.padding_side = "left"
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = tokenizer.eos_token
+        padding, truncation = "longest", False
+
+        # loop through items in the batch
+        for idx, r in enumerate(batch.requests):
+            requests_idx_mapping[r.id] = idx
+
+            # setup inputs_ids, past_key_values
+            tokenized_inputs = tokenizer(
+                r.prompt,
+                return_tensors="np",
+                padding=padding,
+                truncation=truncation,
+                return_token_type_ids=False,
+                max_length=DEEPSPARSE_SEQUENCE_LENGTH
+            )
+            input_ids_list.append(tokenized_inputs["input_ids"])
+
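+        # note: each request is tokenized on its own, so padding="longest" is a
+        # no-op here and every entry of input_ids_list keeps its unpadded length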
+        return cls(
+            batch_id=batch.id,
+            requests=batch.requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids_list=input_ids_list,
+            past_key_values_list=[None] * len(batch.requests),
+        )
+
+    def to_cached_batch(self) -> CachedBatch:
+        return CachedBatch(
+            batch_id=self.batch_id,
+            request_ids=[r.id for r in self.requests],
+        )
+
+    # length of the batch
+    def __len__(self):
+        return len(self.requests)
+
+    # pass list of request ids, returns batch with only those request ids
+    def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
+        assert(len(request_ids) > 0)
+
+        requests_idx_mapping = {}
+        requests = []
+        input_ids_list = []
+        past_key_values_list = []
+
+        # loop through requests, keep ones that should remain
+        for new_idx, request_id in enumerate(request_ids):
+            assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch"
+
+            requests_idx_mapping[request_id] = new_idx
+
+            old_idx = self.requests_idx_mapping[request_id]
+            requests.append(self.requests[old_idx])
+            input_ids_list.append(self.input_ids_list[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])
+
+        # update batch state
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids_list = input_ids_list
+        self.past_key_values_list = past_key_values_list
+
+        return self
+
+    # combine two batches into one
+    @classmethod
+    def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
+        assert len(batches) > 1, "must have more than 1 batch to concatenate"
+
+        requests_idx_mapping = {}
+        requests = []
+        input_ids_list = []
+        past_key_values_list = []
+
+        start_index = 0
+        for i, batch in enumerate(batches):
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
+
+            # concatenate request, input_ids, and past_key_values lists
+            requests.extend(batch.requests)
+            input_ids_list.extend(batch.input_ids_list)
+            past_key_values_list.extend(batch.past_key_values_list)
+
+            # merge the request_id to index mapping
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
+            start_index += len(batch)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids_list=input_ids_list,
+            past_key_values_list=past_key_values_list
+        )
+
+class DeepSparseCausalLM:
+    def __init__(
+        self,
+        model_path: str,
+        tokenizer_path: str,
+    ):
+        # setup tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer.padding_side = "left"
+        if not self.tokenizer.pad_token:
+            assert self.tokenizer.eos_token
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # setup model
+        self.model = DeepSparseDecoderModel(
+            onnx_file_path=model_path,
+            sequence_length=DEEPSPARSE_SEQUENCE_LENGTH,
+            multitoken_length=DEEPSPARSE_MULTITOKEN_LENGTH,
+        )
+
+    # TODO: switch to NextTokenChooser
+    def sample_token(
+        self,
+        logits: np.ndarray
+    ):
+        # assert b=1 for now
+        assert(logits.shape[0] == 1)
+
+        # grab logits for the last item in the sequence
+        # shape == (batch, seq, vocabulary_size)
+        return np.argmax(logits[0, -1, :])
+
+    # TODO: switch to StoppingCriteria
+    def should_stop(
+        self,
+        num_tokens_processed: int,
+        generated_token_id: int
+    ):
+        if num_tokens_processed >= self.model.sequence_length:
+            return True
+        if generated_token_id == self.tokenizer.eos_token_id:
+            return True
+        return False
+
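+    # runs one forward pass for every request in the batch; returns one Generation
+    # per request plus the updated batch (or None once every request has stopped)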
+    def generate_token(
+        self,
+        batch: DeepSparseCausalLMBatch,
+    ) -> (List[Generation], Optional[DeepSparseCausalLMBatch]):
+
+        generations: List[Generation] = []
+        all_stopped = True
+
+        # for each member of the batch:
+        #   a) run inference
+        #   b) sample and check stopping criteria
+        #   c) create generation
+        #   d) update batch
+        for i, (request, input_ids, past_key_values,) in enumerate(
+            zip(
+                batch.requests,
+                batch.input_ids_list,
+                batch.past_key_values_list
+            )
+        ):
+            assert len(input_ids.shape) == 2
+            assert input_ids.shape[0] == 1
+
+            # a) run inference
+            logits, past_key_values = self.model(input_ids, past_key_values)
+
+            # b) sample token and check stopping criteria
+            # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
+            generated_token_id = self.sample_token(logits)
+            generated_token = self.tokenizer.decode(generated_token_id)
+            stop = self.should_stop(
+                num_tokens_processed=input_ids.shape[1] + 1,
+                generated_token_id=generated_token_id
+            )
+            if not stop:
+                all_stopped = False
+
+            # c) make generation
+            generations.append(Generation(
+                request_id=request.id,
+                token=generated_token,
+                token_id=generated_token_id,
+                stopped=stop
+            ))
+
+            # d) update batch
+            # TODO: this does not occur in place
+            assert len(batch.input_ids_list[i].shape) == 2
+            assert batch.input_ids_list[i].shape[0] == 1
+            batch.input_ids_list[i] = np.append(
+                batch.input_ids_list[i],
+                np.array([[generated_token_id]]),
+                axis=1
+            )
+            batch.past_key_values_list[i] = past_key_values
+
+        # if all elements of the batch are done, return null for batch
+        if all_stopped:
+            return generations, None
+
+        # return generation + updated batch
+        return generations, batch
\ No newline at end of file
diff --git a/server/deepsparse/service/model.py b/server/deepsparse/service/model.py
new file mode 100644
index 00000000..9b0082bc
--- /dev/null
+++ b/server/deepsparse/service/model.py
@@ -0,0 +1,241 @@
+import os
+os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
+
+import numpy as np
+from typing import Optional, List, Dict
+
+from deepsparse import Context
+from deepsparse.engine import LIB
+from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
+from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask
+
+PAST_KEY_VALUES_NAME = "past_key_values"
+
+class DeepSparsePastKeyValues:
+    def __init__(self):
+        prev_num_tokens = 0
+        num_frozen_tokens = 1
+        self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
+
+class DeepSparseDecoderEngine:
+    def __init__(
+        self,
+        onnx_file_path: str,
+        sequence_length: int = 1024,
+        input_ids_length: int = 1,
+        engine_context: Optional[Context] = None,
+    ):
+
+        # setup ONNX graph
+        onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
+            onnx_file_path=onnx_file_path,
+            batch_size=1,
+            sequence_length=sequence_length,
+            input_ids_length=input_ids_length,
+        )
+
+        # compile engine
+        self.engine = create_engine(
+            onnx_file_path=onnx_file_path,
+            engine_type=DEEPSPARSE_ENGINE,
+            engine_args={"cached_outputs": cached_outputs},
+            context=engine_context,
+        )
+        print(self.engine)
+
+        # save utilities
+        self.past_key_value_dtype = data_type
+        self.onnx_inputs = self.engine.input_names
+        self.empty_past_key_values = self.make_empty_past_key_values()
+
+    # forward function
+    def __call__(
+        self,
+        engine_inputs: Dict[str, np.ndarray],
+        past_key_values: DeepSparsePastKeyValues,
+        val_inputs: bool = True
+    ):
+        # format input into lists (we pass empty past key values)
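+        # note: the past_key_values ONNX inputs are only zero-filled placeholders
+        # with the right shapes; the real KV cache lives in the LIB.kv_cache object
+        # wrapped by DeepSparsePastKeyValues and is updated in place by
+        # execute_list_out below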
+        inputs = [
+            self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
+            else engine_inputs[name] for name in self.engine.input_names
+        ]
+
+        # validate inputs formatted correctly
+        if val_inputs:
+            self.engine._validate_inputs(inputs)
+
+        # run inference, updates past_key_values internally
+        output = self.engine._eng_net.execute_list_out(
+            inputs,
+            past_key_values.internal_past_key_values
+        )
+        logits = output[0]
+        return logits, past_key_values
+
+    # empty past kvs (dummy values to be passed around)
+    def make_empty_past_key_values(self):
+        past_key_values = {}
+        for idx, name in enumerate(self.onnx_inputs):
+            if name.startswith(PAST_KEY_VALUES_NAME):
+                past_key_values[name] = np.zeros(
+                    self.engine.input_shapes[idx],
+                    dtype=self.past_key_value_dtype
+                )
+
+        return past_key_values
+
+class DeepSparseDecoderModel:
+    def __init__(
+        self,
+        onnx_file_path: str,
+        sequence_length: int = 1024,
+        multitoken_length: int = 16,
+        engine_context: Optional[Context] = None,
+    ):
+        self.sequence_length = sequence_length
+        self.multitoken_length = multitoken_length
+
+        # compile decode engine
+        self.singletoken_engine = DeepSparseDecoderEngine(
+            onnx_file_path=onnx_file_path,
+            engine_context=engine_context,
+            sequence_length=sequence_length,
+            input_ids_length=1,
+        )
+
+        # compile prefill engine
+        self.multitoken_engine = DeepSparseDecoderEngine(
+            onnx_file_path=onnx_file_path,
+            engine_context=engine_context,
+            sequence_length=sequence_length,
+            input_ids_length=self.multitoken_length,
+        )
+
+        assert "input_ids" in self.singletoken_engine.onnx_inputs
+        assert "attention_mask" in self.singletoken_engine.onnx_inputs
+        assert "causal_mask" in self.singletoken_engine.onnx_inputs
+        assert "positions" in self.singletoken_engine.onnx_inputs
+
+    def engine_inputs_for_prefill(
+        self,
+        input_ids: np.ndarray,
+    ):
+        # split batch into N token_batches
+        num_batches = input_ids.shape[1] // self.multitoken_length
+        token_batches = [
+            input_ids[:, i * self.multitoken_length:(i + 1) * self.multitoken_length]
+            for i in range(0, num_batches)
+        ]
+
+        # format inputs for each of the N token_batches
+        for idx, token_batch in enumerate(token_batches):
+            num_processed_tokens = self.multitoken_length * idx
+
+            engine_inputs = {}
+            engine_inputs["input_ids"] = token_batch
+
+            # make attention mask from the right
+            engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
+            engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
+
+            # make positions (building from the right)
+            # TODO: handle case when multitoken_length is 1
+            assert self.multitoken_length > 1
+            engine_inputs["positions"] = np.arange(
+                num_processed_tokens, num_processed_tokens + self.multitoken_length
+            ).reshape(1, -1).astype(np.int64)
+
+            # make causal mask (building from the right)
+            engine_inputs["causal_mask"] = create_causal_mask(
+                input_ids=engine_inputs["input_ids"],
+                attention_mask=engine_inputs["attention_mask"]
+            )
+            yield engine_inputs
+
+    def engine_inputs_for_decode(
+        self,
+        input_ids: np.ndarray,
+    ):
+        engine_inputs = {}
+        engine_inputs["input_ids"] = input_ids[:, -1:]
+        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
+        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
+
+        engine_inputs["causal_mask"] = create_causal_mask(
+            engine_inputs["input_ids"],
+            engine_inputs["attention_mask"]
+        )
+        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
+
+        return engine_inputs
+
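+    # illustrative walk-through (assuming sequence_length=8, multitoken_length=4)
+    # for a 6-token prompt:
+    #   - prefill chunk 0 covers tokens 0-3: the last 4 slots of attention_mask are
+    #     set and positions = [0, 1, 2, 3]
+    #   - the 2 leftover tokens go through the single-token decode path with
+    #     positions 4 and 5 (last 5, then 6, slots of attention_mask set)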
+    def decode(
+        self,
+        input_ids: np.ndarray,
+        past_key_values: DeepSparsePastKeyValues
+    ) -> (np.ndarray, DeepSparsePastKeyValues):
+
+        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_length
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
+        assert input_ids.shape[1] < self.sequence_length
+
+        engine_inputs = self.engine_inputs_for_decode(input_ids)
+        logits, past_key_values = self.singletoken_engine(
+            engine_inputs,
+            past_key_values
+        )
+
+        return logits, past_key_values
+
+    def prefill(
+        self,
+        input_ids: np.ndarray,
+    ) -> (np.ndarray, DeepSparsePastKeyValues):
+
+        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_length
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
+        assert input_ids.shape[1] < self.sequence_length
+
+        tokens_processed = 0
+
+        # setup empty past key values
+        past_key_values = DeepSparsePastKeyValues()
+
+        # loop through chunks, run inference w/ multitoken engine
+        for engine_inputs in self.engine_inputs_for_prefill(input_ids):
+            logits, past_key_values = self.multitoken_engine(
+                engine_inputs,
+                past_key_values
+            )
+            tokens_processed += self.multitoken_length
+
+        # if anything left over, run inference w/ singletoken engine
+        while tokens_processed < input_ids.shape[1]:
+            logits, past_key_values = self.decode(
+                input_ids=input_ids[:, :tokens_processed + 1],
+                past_key_values=past_key_values
+            )
+            tokens_processed += 1
+            # print(logits[:, -1:, :])
+
+        return logits, past_key_values
+
+    def forward(
+        self,
+        input_ids: np.ndarray,
+        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+    ):
+        if past_key_values is None:
+            return self.prefill(input_ids)
+        else:
+            return self.decode(input_ids, past_key_values)
+
+    def __call__(
+        self,
+        input_ids: np.ndarray,
+        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+    ):
+        return self.forward(input_ids, past_key_values)
\ No newline at end of file
diff --git a/server/deepsparse/service/service.py b/server/deepsparse/service/service.py
new file mode 100644
index 00000000..5c23bd64
--- /dev/null
+++ b/server/deepsparse/service/service.py
@@ -0,0 +1,76 @@
+from typing import Dict, List, Tuple
+from server.deepsparse.service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
+from server.deepsparse.utils import Generation, CachedBatch, Batch
+
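+# in-memory store of prefilled DeepSparseCausalLMBatch objects, keyed by batch_id,
+# so that state carries across Prefill / Decode / FilterBatch calls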
+class BatchCache:
+    def __init__(self):
+        self.cache: Dict[int, DeepSparseCausalLMBatch] = {}
+
+    def pop(self, batch_id: int) -> DeepSparseCausalLMBatch:
+        batch = self.cache.pop(batch_id, None)
+        assert batch is not None, f"Batch ID {batch_id} not found in cache."
+        return batch
+
+    def set(self, entry: DeepSparseCausalLMBatch):
+        if entry is not None:
+            self.cache[entry.batch_id] = entry
+
+    def delete(self, batch_id: int):
+        batch = self.pop(batch_id)
+        if batch is not None:
+            del batch
+
+    def clear(self):
+        keys = list(self.cache.keys())
+        for k in keys:
+            self.delete(k)
+
+    def __len__(self):
+        return len(self.cache.keys())
+
+class DeepSparseService:
+    def __init__(
+        self,
+        model: DeepSparseCausalLM
+    ):
+        self.model = model
+        self.cache = BatchCache()
+
+    def ClearCache(self):
+        self.cache.clear()
+
+    def FilterBatch(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
+        ds_batch = self.cache.pop(batch_id)
+        filtered_ds_batch = ds_batch.filter(request_ids)
+        self.cache.set(filtered_ds_batch)
+
+        return filtered_ds_batch.to_cached_batch()
+
+    def Prefill(self, batch: Batch) -> Tuple[Generation, CachedBatch]:
+        ds_batch = DeepSparseCausalLMBatch.from_batch(
+            batch=batch,
+            tokenizer=self.model.tokenizer
+        )
+
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
+        assert len(generations) == 1
+        self.cache.set(next_ds_batch)
+
+        return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)
+
+    def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
+        assert len(batches) != 0, "Must provide at least one batch"
+
+        ds_batches = []
+        for cached_batch in batches:
+            ds_batches.append(self.cache.pop(cached_batch.batch_id))
+
+        if len(ds_batches) > 1:
+            ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
+        else:
+            ds_batch = ds_batches[0]
+
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
+        self.cache.set(next_ds_batch)
+
+        return generations, (next_ds_batch.to_cached_batch() if next_ds_batch else None)
\ No newline at end of file
diff --git a/server/deepsparse/utils.py b/server/deepsparse/utils.py
new file mode 100644
index 00000000..32c898b3
--- /dev/null
+++ b/server/deepsparse/utils.py
@@ -0,0 +1,35 @@
+from dataclasses import dataclass
+from queue import Queue
+from typing import List, Optional
+
+@dataclass
+class Request:
+    id: int
+    prompt: str
+    max_generated_tokens: int
+
+@dataclass
+class Batch:
+    id: int
+    requests: List[Request]
+
+@dataclass
+class CachedBatch:
+    batch_id: int
+    request_ids: List[int]
+
+    def __len__(self):
+        return len(self.request_ids)
+
+@dataclass
+class Generation:
+    request_id: int
+    token: Optional[str]
+    token_id: Optional[int]
+    stopped: bool
+
+@dataclass
+class GenerateRequest:
+    prompt: str
+    max_generated_tokens: int
+    response_stream: Queue[Generation]
\ No newline at end of file