implemented a basic naive router

rsnm2 2023-08-23 19:54:31 +00:00
parent e7ec2ff282
commit 1f87c7762f
5 changed files with 345 additions and 46 deletions

View File

@@ -13,31 +13,242 @@
},
{
"cell_type": "markdown",
- "id": "7d43c041-2c79-4276-9104-2f224b2f8af6",
+ "id": "a19786b8-e72c-43c1-964f-45d92fd171e9",
"metadata": {},
"source": [
- "## Example Interacting With The Service"
+ "## Example Interacting With The Router"
]
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "f23bc085-94db-44b6-af42-fc8a05f2cf6a",
+ "execution_count": 2,
+ "id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d",
"metadata": {},
"outputs": [
{
- "ename": "SyntaxError",
- "evalue": "invalid syntax (260114089.py, line 2)",
- "output_type": "error",
- "traceback": [
- "\u001b[0;36m Cell \u001b[0;32mIn[12], line 2\u001b[0;36m\u001b[0m\n\u001b[0;31m b = (a = 5) < 5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
]
}
],
"source": [
- "a = None\n",
- "b = (a = 5) < 5\n",
- "print(b)\n"
+ "from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n",
+ "from server.deepsparse.deepsparse_service import DeepSparseService\n",
+ "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "78acf813-3688-483d-9148-5c0df5d6b8e3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using pad_token, but it is not set yet.\n",
"2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"deepsparse.engine.Engine:\n",
"\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"\tbatch_size: 1\n",
"\tnum_cores: 8\n",
"\tnum_streams: 1\n",
"\tscheduler: Scheduler.default\n",
"\tfraction_of_supported_ops: 1.0\n",
"\tcpu_avx_type: avx2\n",
"\tcpu_vnni: False\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"deepsparse.engine.Engine:\n",
"\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"\tbatch_size: 1\n",
"\tnum_cores: 8\n",
"\tnum_streams: 1\n",
"\tscheduler: Scheduler.default\n",
"\tfraction_of_supported_ops: 1.0\n",
"\tcpu_avx_type: avx2\n",
"\tcpu_vnni: False\n"
]
}
],
"source": [
"tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
"onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
"\n",
"model = DeepSparseCausalLM(\n",
" tokenizer_path=tokenizer_path,\n",
" model_path=onnx_path\n",
")\n",
"\n",
"service = DeepSparseService(model=model)\n",
"router = DeepSparseRouter(service=service)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e93bac63-8924-4cf4-8683-81ce9333a2f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n-1) + fib(n-2)\n",
"\n",
"# Driver function to test above function\n",
"n = int(input(\"Enter the number: \"))\n",
"print(fib(n))\n",
"\n",
"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
"\n",
"\n",
"\n",
"Write a function for filtering a list of integers to include only positive numbers:\n",
"\n",
"def filter(lst):\n",
" return [x for x in lst if x > 0]\n",
"\n",
"# Test\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1,\n",
"\n",
"\n",
"Write a function for checking if a word if a palindrome:\n",
"\n",
"def is_palindrome(word):\n",
" return word == word[::-1]\n",
"\n",
"# Test\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(\n",
"\n",
"\n",
"Write a function for reversing a string:\n",
"\n",
"def reverse_string(s):\n",
" return s[::-1]\n",
"\n",
"# Test\n",
"print(reverse_string(\"hello\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"a\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\n",
"\n",
"\n",
"Write a function for sorting an array of integers:\n",
"\n",
"def merge_sort(arr):\n",
" if len(arr) <= 1:\n",
" return arr\n",
" mid = len(arr) // 2\n",
" left = arr[:mid]\n",
" right = arr[mid:]\n",
" left = merge_sort(left)\n",
" right = merge_sort(right)\n",
" return merge(left, right)\n",
"\n",
"def merge(left, right):\n",
" result = []\n",
" while len(left) > 0 and len(right) > 0:\n",
" if left[0]\n",
"\n",
"\n",
"stop\n",
"\n",
"\n"
]
}
],
"source": [
"from threading import Thread\n",
"import time\n",
"\n",
"batching_thread = Thread(target=batching_task, args=[router])\n",
"batching_thread.start()\n",
"\n",
"prompts = [\n",
" \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
" \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n",
" \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n",
" \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n",
" \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n",
"]\n",
"\n",
"def generate_task(prompt):\n",
" result = router.generate(prompt=prompt)\n",
" print(result)\n",
" print(\"\\n\")\n",
"\n",
"generate_threads = [\n",
" Thread(target=generate_task, args=[prompt]) for prompt in prompts\n",
"]\n",
"\n",
"# print(len(generate_threads))\n",
"\n",
"for gt in generate_threads:\n",
" gt.start()\n",
" time.sleep(0.5)\n",
"\n",
"for gt in generate_threads:\n",
" gt.join()\n",
"\n",
"\n",
"generate_task(\"stop\")\n",
"batching_thread.join()"
]
},
{
"cell_type": "markdown",
"id": "7d43c041-2c79-4276-9104-2f224b2f8af6",
"metadata": {},
"source": [
"## Example Interacting With The Service"
]
},
{
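The notebook above drives the router from several client threads: each `router.generate(prompt)` call blocks until the batching thread has finished that request, and the literal prompt "stop" acts as a sentinel that shuts the batching loop down. The same interaction pattern, reduced to a runnable toy (the `ToyRouter` class below is a hypothetical stand-in for illustration only, not code from this repo):

import queue
import threading
import time

class ToyRouter:
    # stand-in for DeepSparseRouter: generate() blocks until the worker answers
    def __init__(self):
        self.requests = queue.Queue()

    def generate(self, prompt: str) -> str:
        if prompt == "stop":
            self.requests.put(("stop", None))
            return "stop"
        done = threading.Event()
        result = {}
        self.requests.put((prompt, (done, result)))
        done.wait()  # block the calling thread until the worker completes this request
        return result["text"]

def batching_task(router: ToyRouter) -> None:
    # worker loop: pull requests until the "stop" sentinel arrives
    while True:
        prompt, ctx = router.requests.get()
        if prompt == "stop":
            return
        done, result = ctx
        result["text"] = prompt + " ...generated text"
        done.set()

router = ToyRouter()
worker = threading.Thread(target=batching_task, args=[router])
worker.start()

prompts = ["def fib(n):", "def is_palindrome(word):"]
clients = [threading.Thread(target=lambda p=p: print(router.generate(p))) for p in prompts]
for t in clients:
    t.start()
    time.sleep(0.1)
for t in clients:
    t.join()

router.generate("stop")  # sentinel unblocks and terminates the worker
worker.join()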

View File

@@ -205,11 +205,11 @@ class DeepSparseCausalLM:
        logits, past_key_values = self.model(input_ids, past_key_values)

        # sample token
-        # simple for now --- should use NextTokenChooser
+        # todo: simple for now --- should use NextTokenChooser
        generated_token_id = self.sample_token(logits)

        # check stopping criteria
-        # simple for now --- should use StoppingCriteria
+        # todo: simple for now --- should use StoppingCriteria
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == 1
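The two todos above refer to sampling and stopping logic that is still a placeholder in this commit. One plausible minimal version of what that placeholder amounts to, assuming greedy decoding over logits of shape [batch=1, seq_len, vocab] and a simple EOS / max-token cutoff (the `should_stop` helper and the exact shapes are assumptions for illustration, not code from this diff):

import numpy as np

def sample_token(logits: np.ndarray) -> int:
    # greedy sampling: take the argmax over the vocabulary at the last position
    return int(np.argmax(logits[0, -1, :]))

def should_stop(generated_ids: list, max_generated_tokens: int, eos_token_id: int) -> bool:
    # stop once the model emits EOS or the generation budget is exhausted
    return generated_ids[-1] == eos_token_id or len(generated_ids) >= max_generated_tokens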

View File

@@ -13,7 +13,9 @@ class GenerateRequest:
        self.generation = prompt
        self.max_generated_tokens = max_generated_tokens
        self.cv = Condition()
+        self.is_stopped = False

+# todo: implement logic for maximum memory usage
class DeepSparseQueue:
    def __init__(self):
        self.next_request_id: int = 0
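The new `is_stopped` flag lets a caller that arrives late detect a request the batching thread has already finished: the caller checks the flag under the request's condition lock and only waits if the notification has not fired yet, so a `notify()` issued earlier is never lost. A condensed sketch of the two sides of that handshake (it mirrors the router code shown below and is not additional code from the commit):

from threading import Condition

class GenerateRequest:
    # simplified: only the fields involved in the wait/notify handshake
    def __init__(self, prompt: str, max_generated_tokens: int):
        self.prompt = prompt
        self.generation = prompt
        self.max_generated_tokens = max_generated_tokens
        self.cv = Condition()
        self.is_stopped = False

def wait_for_completion(req: GenerateRequest) -> str:
    # caller side (DeepSparseRouter.generate): wait only if not already finished
    with req.cv:
        if not req.is_stopped:
            req.cv.wait()
    return req.generation

def mark_finished(req: GenerateRequest) -> None:
    # batching side (filter_notify_update): mark done, then wake the caller
    with req.cv:
        req.is_stopped = True
        req.cv.notify()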

View File

@@ -16,59 +16,102 @@ class DeepSparseRouter:
        self.queue: DeepSparseQueue = DeepSparseQueue()
        self.cv: Condition = Condition()

-    def generate(self):
-        pass
+    def generate(self, prompt: str) -> str:
+        generate_request = GenerateRequest(
+            prompt=prompt,
+            max_generated_tokens=100
+        )
+
+        with self.cv:
+            # print("router: acquired cv")
+            self.queue.append(generate_request)
+            self.cv.notify()
+
+        if prompt == "stop":
+            return "stop"
+
+        with generate_request.cv:
+            # print("generate_request: acquired cv")
+            if not generate_request.is_stopped:
+                # print("generate_request: waiting")
+                generate_request.cv.wait()
+                # print("generate_request: done waiting")
+
+        return generate_request.generation

    def prefill(
        self,
        batch: Batch,
-        generation_requests: Dict[int, GenerateRequest]
+        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
+        # print("prefill")
        generation, next_batch = self.service.Prefill(
            PrefillRequest(batch=batch)
        )
-        self.filter_notify_update([generation], generation_requests)
+        self.filter_notify_update([generation], generate_requests)
        return self.filter_batch(
            batch=next_batch,
-            generation_requests=generation_requests
+            generate_requests=generate_requests
        )

-    def decode(self):
-        pass
+    def decode(
+        self,
+        batches: List[CachedBatch],
+        generate_requests: Dict[int, GenerateRequest]
+    ) -> Optional[CachedBatch]:
+        # print("decode")
+        generations, next_batch = self.service.Decode(
+            DecodeRequest(batches=batches)
+        )
+        self.filter_notify_update(generations, generate_requests)
+        return self.filter_batch(
+            batch=next_batch,
+            generate_requests=generate_requests
+        )

    def filter_notify_update(
        self,
        generations: List[Generation],
-        generation_requests: Dict[int, GenerateRequest]
+        generate_requests: Dict[int, GenerateRequest]
    ):
+        # print("filter_notify_update")
        for generation in generations:
            request_id = generation.request_id

            # if we hit a stopping criteria
            if generation.generated_text is None:
                # remove from active requests and notify
-                stopped_generation_request = generation_requests.pop()
-                stopped_generation_request[request_id].cv.notify()
+                stopped_generate_request = generate_requests.pop(request_id)
+                with stopped_generate_request.cv:
+                    stopped_generate_request.is_stopped = True
+                    stopped_generate_request.cv.notify()

            # otherwise, update generation
            else:
-                generation_requests[request_id].generation += generation.generated_text
+                generate_requests[request_id].generation += generation.generated_text

    def filter_batch(
        self,
-        batch: CachedBatch,
-        generation_requests: Dict[int, GenerateRequest]
+        batch: Optional[CachedBatch],
+        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
+        # print("filter_batch")
+
+        # batch is already done
+        if batch is None:
+            return batch
+
        # no need to filter
-        if len(batch) == len(generation_requests):
+        if len(batch) == len(generate_requests):
            return batch

        # retain only requests that are still in active generation requests
-        batch.request_ids = [id for id in batch.request_ids if id in generation_requests]
+        batch.request_ids = [id for id in batch.request_ids if id in generate_requests]

        # if all requests complete, clear cache and return None
        if len(batch) == 0:

@@ -83,18 +126,59 @@ class DeepSparseRouter:
            )
        )

-    def batching_task(self):
-        while True:
-            with self.cv:
-                while self.queue.is_empty():
-                    self.cv.wait()
+def batching_task(
+    router: DeepSparseRouter
+) -> bool:
+    # infinite_loop
+    while True:
+        # block while the queue is empty
+        # print("batching_task: about to acquire cv")
+        with router.cv:
+            while router.queue.is_empty():
+                # print(f"batching_task cv: waiting")
+                router.cv.wait()
+                # print(f"batching_task: done waiting")

-            # loop until the queue is empty
-            next_batch = self.queue.next_batch()
+        # loop until all batches in the queue are processed
+        next_batch = router.queue.next_batch()
        while next_batch is not None:
-                cached_batch = self.prefill(*next_batch)
+            batch, generate_requests = next_batch
+
+            # hack to break out of the cycle
+            if batch.requests[0].prompt == "stop":
+                assert router.queue.is_empty()
+                assert len(router.service.cache) == 0
+                return True
+
+            cached_batch = router.prefill(
+                batch=batch,
+                generate_requests=generate_requests
+            )

-                next_batch = self.queue.next_batch()
+            # loop until we do not receive any cached batch from the service (== until
+            # all requests have met their stopping criteria)
+            while cached_batch is not None:
+                # print(f"batch_size = {len(cached_batch)}")
+                batches = [cached_batch]
+
+                # try to get a new batch and run prefill on this batch
+                next_batch = router.queue.next_batch()
+                if next_batch is not None:
+                    new_batch, new_generate_requests = next_batch
+                    new_cached_batch = router.prefill(
+                        batch=new_batch,
+                        generate_requests=new_generate_requests
+                    )
+                    if new_cached_batch is not None:
+                        batches.append(new_cached_batch)
+                        assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
+                        generate_requests.update(new_generate_requests)
+
+                # run decode
+                cached_batch = router.decode(
+                    batches=batches,
+                    generate_requests=generate_requests
+                )
+
+            next_batch = router.queue.next_batch()
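Taken together, batching_task implements a simple form of continuous batching: a queued request is prefilled on its own, folded into the running batch so it decodes alongside requests already in flight, and finished requests are filtered out between steps. A toy trace of that scheduling policy (illustrative only, with made-up step counts and no DeepSparse objects):

from collections import deque

waiting = deque(["A", "B"])      # queued prompts
steps_left = {"A": 3, "B": 2}    # toy stopping criteria per request
running = []                     # the merged decode batch

def prefill(req):
    print(f"prefill {req}")
    return req

while waiting or running:
    if not running and waiting:
        running.append(prefill(waiting.popleft()))
    # one decode step over the whole running batch
    print(f"decode batch {running}")
    for req in list(running):
        steps_left[req] -= 1
        if steps_left[req] == 0:
            print(f"  {req} finished, filtered out of the batch")
            running.remove(req)
    # merge a newly arrived request into the batch before the next decode step
    if waiting:
        running.append(prefill(waiting.popleft()))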

View File

@@ -7,7 +7,7 @@ from server.deepsparse.deepsparse_requests import (
    Generation, CachedBatch
)

-class BatchCache:
+class Cache:
    def __init__(self):
        self.cache: Dict[int, DeepSparseCausalLMBatch] = {}

@@ -37,7 +37,7 @@ class DeepSparseService:
        model: DeepSparseCausalLM
    ):
        self.model = model
-        self.cache = BatchCache()
+        self.cache = Cache()

    def ClearCache(self):
        self.cache.clear()

@@ -46,6 +46,7 @@ class DeepSparseService:
        self,
        request: FilterBatchRequest
    ) -> CachedBatch:
        ds_batch = self.cache.pop(request.batch_id)
        assert ds_batch is not None, "Batch ID {request.batch_id} not found in cache."
        filtered_batch = ds_batch.filter(request.request_ids)

@@ -57,6 +58,7 @@ class DeepSparseService:
        self,
        request: PrefillRequest
    ) -> [Generation, CachedBatch]:
        ds_batch = DeepSparseCausalLMBatch.from_batch(
            batch=request.batch,
            tokenizer=self.model.tokenizer

@@ -88,4 +90,4 @@ class DeepSparseService:
        generations, next_ds_batch = self.model.generate_token(ds_batch)
        self.cache.set(next_ds_batch)
-        return generations, next_ds_batch.to_batch() if next_ds_batch else None
+        return generations, (next_ds_batch.to_batch() if next_ds_batch else None)
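The parenthesized return on the last line is a readability change only: a conditional expression already binds tighter than the tuple comma, so both the old and the new form evaluate to the pair (generations, batch-or-None). A quick check of that precedence rule:

x = None
old = 1, x.bit_length() if x else None    # parses as (1, (x.bit_length() if x else None))
new = 1, (x.bit_length() if x else None)
assert old == new == (1, None)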