diff --git a/server/deepsparse/deepsparse_queue.py b/server/deepsparse/deepsparse_queue.py
new file mode 100644
index 00000000..6f0194b6
--- /dev/null
+++ b/server/deepsparse/deepsparse_queue.py
@@ -0,0 +1,55 @@
+from typing import Deque, Optional, Tuple, Dict
+from collections import deque
+from threading import Condition
+from server.deepsparse.deepsparse_requests import Batch, Request
+
+class GenerateRequest:
+    def __init__(
+        self,
+        prompt: str,
+        max_generated_tokens: int
+    ):
+        self.prompt = prompt
+        self.generation = prompt
+        self.max_generated_tokens = max_generated_tokens
+        self.cv = Condition()
+
+class DeepSparseQueue:
+    def __init__(self):
+        self.next_request_id: int = 0
+        self.next_batch_id: int = 0
+        self.queue: Deque[GenerateRequest] = deque()
+
+    def append(self, generate_request: GenerateRequest):
+        self.queue.append(generate_request)
+
+    def is_empty(self):
+        return len(self.queue) == 0
+
+    # (todo): enable multiple prefill requests in a batch
+    def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+        if self.is_empty():
+            return None
+
+        # pop the first generate_request in the queue
+        generate_request = self.queue.popleft()
+        generate_requests = {
+            self.next_request_id: generate_request
+        }
+
+        # format into a request
+        request = Request(
+            id=self.next_request_id,
+            prompt=generate_request.prompt,
+            max_generated_tokens=generate_request.max_generated_tokens
+        )
+        self.next_request_id += 1
+
+        # format into a single-request batch
+        batch = Batch(
+            id=self.next_batch_id,
+            requests=[request]
+        )
+        self.next_batch_id += 1
+
+        return (batch, generate_requests)
\ No newline at end of file
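For reference, a minimal sketch of how a consumer drives `DeepSparseQueue` (not part of the diff; the field names on `Batch` follow the `Request`/`Batch` usage above):

    from server.deepsparse.deepsparse_queue import DeepSparseQueue, GenerateRequest

    queue = DeepSparseQueue()
    queue.append(GenerateRequest(prompt="Hello", max_generated_tokens=16))

    # next_batch() returns (Batch, {request_id: GenerateRequest}), or None when empty
    popped = queue.next_batch()
    if popped is not None:
        batch, generate_requests = popped
        # per the (todo) above, each Batch currently wraps exactly one Request
        assert len(batch.requests) == 1
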
diff --git a/server/deepsparse/deepsparse_router.py b/server/deepsparse/deepsparse_router.py
new file mode 100644
index 00000000..a09b374e
--- /dev/null
+++ b/server/deepsparse/deepsparse_router.py
@@ -0,0 +1,104 @@
+from threading import Condition
+from typing import List, Dict, Optional
+
+from server.deepsparse.deepsparse_service import DeepSparseService
+from server.deepsparse.deepsparse_requests import (
+    CachedBatch, Batch, Generation,
+    PrefillRequest, DecodeRequest, FilterBatchRequest,
+)
+from server.deepsparse.deepsparse_queue import (
+    DeepSparseQueue, GenerateRequest
+)
+
+class DeepSparseRouter:
+    def __init__(self, service: DeepSparseService):
+        self.service: DeepSparseService = service
+        self.queue: DeepSparseQueue = DeepSparseQueue()
+        self.cv: Condition = Condition()
+
+    def generate(self):
+        pass
+
+    def prefill(
+        self,
+        batch: Batch,
+        generation_requests: Dict[int, GenerateRequest]
+    ) -> Optional[CachedBatch]:
+
+        generation, next_batch = self.service.Prefill(
+            PrefillRequest(batch=batch)
+        )
+
+        self.filter_notify_update([generation], generation_requests)
+
+        return self.filter_batch(
+            batch=next_batch,
+            generation_requests=generation_requests
+        )
+
+    def decode(self):
+        pass
+
+    def filter_notify_update(
+        self,
+        generations: List[Generation],
+        generation_requests: Dict[int, GenerateRequest]
+    ):
+        for generation in generations:
+            request_id = generation.request_id
+
+            # if we hit a stopping criteria
+            if generation.generated_text is None:
+                # remove from active requests and notify the waiting thread
+                stopped_generation_request = generation_requests.pop(request_id)
+                with stopped_generation_request.cv:
+                    stopped_generation_request.cv.notify()
+
+            # otherwise, update the running generation
+            else:
+                generation_requests[request_id].generation += generation.generated_text
+
+    def filter_batch(
+        self,
+        batch: CachedBatch,
+        generation_requests: Dict[int, GenerateRequest]
+    ) -> Optional[CachedBatch]:
+
+        # no need to filter
+        if len(batch) == len(generation_requests):
+            return batch
+
+        # retain only requests that are still active
+        batch.request_ids = [
+            request_id for request_id in batch.request_ids
+            if request_id in generation_requests
+        ]
+
+        # if all requests are complete, clear the cache and return None
+        if len(batch) == 0:
+            self.service.ClearCache()
+            return None
+
+        # otherwise, call the filter batch service
+        return self.service.FilterBatch(
+            FilterBatchRequest(
+                batch_id=batch.batch_id,
+                request_ids=batch.request_ids,
+            )
+        )
+
+    def batching_task(self):
+        while True:
+            with self.cv:
+                while self.queue.is_empty():
+                    self.cv.wait()
+
+            # loop until the queue is empty
+            next_batch = self.queue.next_batch()
+            while next_batch is not None:
+                cached_batch = self.prefill(*next_batch)
+
+                # (todo): run the decode loop on cached_batch until all
+                # requests in the batch hit a stopping criteria
+
+                next_batch = self.queue.next_batch()
\ No newline at end of file
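
Two pieces are still stubbed out here (`generate` and `decode`). For orientation, a sketch of what the producer side could look like, assuming `batching_task` runs on a background thread and `filter_notify_update` signals the request's `cv` on completion; the free function below is illustrative, not part of this diff:

    from threading import Thread

    def generate(router: DeepSparseRouter, prompt: str, max_generated_tokens: int) -> str:
        # enqueue the request and wake the batching task
        generate_request = GenerateRequest(
            prompt=prompt,
            max_generated_tokens=max_generated_tokens
        )
        with router.cv:
            router.queue.append(generate_request)
            router.cv.notify()

        # block until filter_notify_update() signals that a stopping criteria
        # was hit; a production version should pair this with a completion
        # flag and cv.wait_for() so a notify that fires early is not missed
        with generate_request.cv:
            generate_request.cv.wait()

        return generate_request.generation

    # batching_task() loops forever, so it belongs on its own thread:
    # Thread(target=router.batching_task, daemon=True).start()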