added all the files

rsnm2 2023-08-22 18:12:24 +00:00
parent e7f3eac8d5
commit e7ec2ff282
2 changed files with 156 additions and 0 deletions

server/deepsparse/deepsparse_queue.py View File

@@ -0,0 +1,56 @@
from typing import Deque, Optional, Tuple, Dict
from collections import deque
from threading import Condition

from server.deepsparse.deepsparse_requests import Batch, Request


class GenerateRequest:
    def __init__(
        self,
        prompt: str,
        max_generated_tokens: int
    ):
        self.prompt = prompt
        # generation accumulates the output text, seeded with the prompt
        self.generation = prompt
        self.max_generated_tokens = max_generated_tokens
        # notified by the router once this request finishes generating
        self.cv = Condition()
class DeepSparseQueue:
    def __init__(self):
        self.next_request_id: int = 0
        self.next_batch_id: int = 0
        self.queue: Deque[GenerateRequest] = deque()

    def append(self, generate_request: GenerateRequest):
        self.queue.append(generate_request)

    def is_empty(self) -> bool:
        return len(self.queue) == 0

    # (todo): enable multiple prefill requests in a batch
    def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
        if self.is_empty():
            return None

        # pop the first generate_request in the queue
        generate_request = self.queue.popleft()
        generate_requests = {
            self.next_request_id: generate_request
        }

        # format into a request
        request = Request(
            id=self.next_request_id,
            prompt=generate_request.prompt,
            max_generated_tokens=generate_request.max_generated_tokens
        )
        self.next_request_id += 1

        # format into a single-request batch
        batch = Batch(
            id=self.next_batch_id,
            requests=[request]
        )
        self.next_batch_id += 1

        return (batch, generate_requests)
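
For reference, a minimal sketch of how this queue is expected to be driven (the call pattern mirrors DeepSparseRouter.batching_task below; Batch and Request come from deepsparse_requests):

    queue = DeepSparseQueue()
    queue.append(GenerateRequest(prompt="def fib(n):", max_generated_tokens=16))

    popped = queue.next_batch()
    if popped is not None:
        batch, generate_requests = popped
        # batch wraps exactly one Request for now (see the todo above);
        # generate_requests maps request_id -> GenerateRequest so generated
        # text can be routed back to the caller that submitted the prompt.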

View File

@@ -0,0 +1,100 @@
from threading import Condition
from typing import List, Dict, Optional

from server.deepsparse.deepsparse_service import DeepSparseService
from server.deepsparse.deepsparse_requests import (
    CachedBatch, Batch, Generation,
    PrefillRequest, DecodeRequest, FilterBatchRequest,
)
from server.deepsparse.deepsparse_queue import (
    DeepSparseQueue, GenerateRequest
)


class DeepSparseRouter:
    def __init__(self, service: DeepSparseService):
        self.service: DeepSparseService = service
        self.queue: DeepSparseQueue = DeepSparseQueue()
        self.cv: Condition = Condition()

    def generate(self):
        pass

    def prefill(
        self,
        batch: Batch,
        generation_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        generation, next_batch = self.service.Prefill(
            PrefillRequest(batch=batch)
        )

        self.filter_notify_update([generation], generation_requests)

        return self.filter_batch(
            batch=next_batch,
            generation_requests=generation_requests
        )

    def decode(self):
        pass
    def filter_notify_update(
        self,
        generations: List[Generation],
        generation_requests: Dict[int, GenerateRequest]
    ):
        for generation in generations:
            request_id = generation.request_id

            # if we hit a stopping criteria
            if generation.generated_text is None:
                # remove from active requests and notify the waiting caller
                stopped_generation_request = generation_requests.pop(request_id)
                with stopped_generation_request.cv:
                    stopped_generation_request.cv.notify()

            # otherwise, append the newly generated text
            else:
                generation_requests[request_id].generation += generation.generated_text
    def filter_batch(
        self,
        batch: CachedBatch,
        generation_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        # no need to filter
        if len(batch) == len(generation_requests):
            return batch

        # retain only requests that are still in the active generation requests
        batch.request_ids = [
            request_id for request_id in batch.request_ids
            if request_id in generation_requests
        ]

        # if all requests are complete, clear the cache and return None
        if len(batch) == 0:
            self.service.ClearCache()
            return None

        # otherwise, call the FilterBatch service
        return self.service.FilterBatch(
            FilterBatchRequest(
                batch_id=batch.batch_id,
                request_ids=batch.request_ids,
            )
        )
    def batching_task(self):
        while True:
            # block until there is at least one request in the queue
            with self.cv:
                while self.queue.is_empty():
                    self.cv.wait()

            # loop until the queue is empty
            next_batch = self.queue.next_batch()
            while next_batch is not None:
                cached_batch = self.prefill(*next_batch)
                next_batch = self.queue.next_batch()
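
For context, generate() is still a stub in this commit. One plausible shape for it, inferred from how batching_task waits on self.cv and how filter_notify_update notifies a request's own cv (the body below is illustrative, not part of this commit):

    def generate(self, prompt: str, max_generated_tokens: int) -> str:
        generate_request = GenerateRequest(
            prompt=prompt,
            max_generated_tokens=max_generated_tokens
        )

        # enqueue the request and wake the batching task
        with self.cv:
            self.queue.append(generate_request)
            self.cv.notify()

        # wait until filter_notify_update signals that generation finished
        with generate_request.cv:
            generate_request.cv.wait()

        return generate_request.generation

A robust version would pair the condition with a completion flag so a notify that fires before the wait is not lost; that bookkeeping is omitted in this sketch.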