mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
commit 83f6461bb9

server-dev.ipynb (397)
@@ -13,30 +13,36 @@
},
{
"cell_type": "markdown",
"id": "7d43c041-2c79-4276-9104-2f224b2f8af6",
"id": "a19786b8-e72c-43c1-964f-45d92fd171e9",
"metadata": {},
"source": [
"## Example Interacting With The Service"
"## Example Interacting With The Router"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "631e94eb-cca0-438e-8936-6e8a87166d63",
"execution_count": 2,
"id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
]
}
],
"source": [
"from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLMBatch, DeepSparseCausalLM\n",
"from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n",
"from server.deepsparse.deepsparse_service import DeepSparseService\n",
"from server.deepsparse.deepsparse_requests import (\n",
" PrefillRequest, DecodeRequest, FilterBatchRequest, Request\n",
")"
"from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c9c39557-2898-443f-aae8-443ef1171123",
"id": "78acf813-3688-483d-9148-5c0df5d6b8e3",
"metadata": {},
"outputs": [
{
@@ -44,7 +50,7 @@
"output_type": "stream",
"text": [
"Using pad_token, but it is not set yet.\n",
"2023-08-22 03:09:19 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
]
},
@@ -67,7 +73,241 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-22 03:09:45 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
"2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"deepsparse.engine.Engine:\n",
"\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"\tbatch_size: 1\n",
"\tnum_cores: 8\n",
"\tnum_streams: 1\n",
"\tscheduler: Scheduler.default\n",
"\tfraction_of_supported_ops: 1.0\n",
"\tcpu_avx_type: avx2\n",
"\tcpu_vnni: False\n"
]
}
],
"source": [
"tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
"onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
"\n",
"model = DeepSparseCausalLM(\n",
" tokenizer_path=tokenizer_path,\n",
" model_path=onnx_path\n",
")\n",
"\n",
"service = DeepSparseService(model=model)\n",
"router = DeepSparseRouter(service=service)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e93bac63-8924-4cf4-8683-81ce9333a2f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n-1) + fib(n-2)\n",
"\n",
"# Driver function to test above function\n",
"n = int(input(\"Enter the number: \"))\n",
"print(fib(n))\n",
"\n",
"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
"\n",
"\n",
"\n",
"Write a function for filtering a list of integers to include only positive numbers:\n",
"\n",
"def filter(lst):\n",
" return [x for x in lst if x > 0]\n",
"\n",
"# Test\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1,\n",
"\n",
"\n",
"Write a function for checking if a word if a palindrome:\n",
"\n",
"def is_palindrome(word):\n",
" return word == word[::-1]\n",
"\n",
"# Test\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(\n",
"\n",
"\n",
"Write a function for reversing a string:\n",
"\n",
"def reverse_string(s):\n",
" return s[::-1]\n",
"\n",
"# Test\n",
"print(reverse_string(\"hello\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"a\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\n",
"\n",
"\n",
"Write a function for sorting an array of integers:\n",
"\n",
"def merge_sort(arr):\n",
" if len(arr) <= 1:\n",
" return arr\n",
" mid = len(arr) // 2\n",
" left = arr[:mid]\n",
" right = arr[mid:]\n",
" left = merge_sort(left)\n",
" right = merge_sort(right)\n",
" return merge(left, right)\n",
"\n",
"def merge(left, right):\n",
" result = []\n",
" while len(left) > 0 and len(right) > 0:\n",
" if left[0]\n",
"\n",
"\n",
"stop\n",
"\n",
"\n"
]
}
],
"source": [
"from threading import Thread\n",
"import time\n",
"\n",
"batching_thread = Thread(target=batching_task, args=[router])\n",
"batching_thread.start()\n",
"\n",
"prompts = [\n",
" \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
" \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n",
" \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n",
" \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n",
" \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n",
"]\n",
"\n",
"def generate_task(prompt):\n",
" result = router.generate(prompt=prompt)\n",
" print(result)\n",
" print(\"\\n\")\n",
"\n",
"generate_threads = [\n",
" Thread(target=generate_task, args=[prompt]) for prompt in prompts\n",
"]\n",
"\n",
"# print(len(generate_threads))\n",
"\n",
"for gt in generate_threads:\n",
" gt.start()\n",
" time.sleep(0.5)\n",
"\n",
"for gt in generate_threads:\n",
" gt.join()\n",
"\n",
"\n",
"generate_task(\"stop\")\n",
"batching_thread.join()"
]
},
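The cell above drives the router concurrently: it starts batching_task on a background thread, issues one router.generate() call per prompt from its own thread, and finally submits the "stop" prompt so the batching loop exits. Below is a hedged sketch of the same client pattern written with a thread pool; it is not part of this commit and assumes the router and the running batching_thread exist exactly as in the cell above.

# Hedged sketch (not in this commit): the same fan-out with a thread pool,
# reusing `router` and the already-started `batching_thread` from the cell above.
from concurrent.futures import ThreadPoolExecutor

prompts = [
    "def fib(n):",
    "def is_palindrome(word):",
]

with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
    # each call blocks until the router notifies that request's condition variable
    for generation in pool.map(lambda p: router.generate(prompt=p), prompts):
        print(generation)

router.generate(prompt="stop")  # sentinel that lets batching_task return
batching_thread.join()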
{
"cell_type": "markdown",
"id": "7d43c041-2c79-4276-9104-2f224b2f8af6",
"metadata": {},
"source": [
"## Example Interacting With The Service"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "631e94eb-cca0-438e-8936-6e8a87166d63",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-22 14:26:39 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
]
}
],
"source": [
"from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLMBatch, DeepSparseCausalLM\n",
"from server.deepsparse.deepsparse_service import DeepSparseService\n",
"from server.deepsparse.deepsparse_requests import (\n",
" PrefillRequest, DecodeRequest, FilterBatchRequest, Request\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c9c39557-2898-443f-aae8-443ef1171123",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using pad_token, but it is not set yet.\n",
"2023-08-22 14:26:56 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"deepsparse.engine.Engine:\n",
"\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
"\tbatch_size: 1\n",
"\tnum_cores: 8\n",
"\tnum_streams: 1\n",
"\tscheduler: Scheduler.default\n",
"\tfraction_of_supported_ops: 1.0\n",
"\tcpu_avx_type: avx2\n",
"\tcpu_vnni: False\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-22 14:27:21 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
]
},
{
@@ -100,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 4,
"id": "85ce9aab-1a56-4b6f-a82b-4e91d52290b7",
"metadata": {},
"outputs": [],
@@ -136,10 +376,25 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 5,
"id": "d2441753-fe2a-45c0-ad80-135b6207947d",
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "NameError",
"evalue": "name 'Batch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m service\u001b[38;5;241m.\u001b[39mClearCache()\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# prefill queue\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m prefill_queue \u001b[38;5;241m=\u001b[39m \u001b[43mPrefillQueue\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# cached batches\u001b[39;00m\n\u001b[1;32m 7\u001b[0m cached_batches \u001b[38;5;241m=\u001b[39m []\n",
"Cell \u001b[0;32mIn[4], line 17\u001b[0m, in \u001b[0;36mPrefillQueue.__init__\u001b[0;34m(self, prompts)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, prompts):\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqueue \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 18\u001b[0m idx: PrefillRequest(batch\u001b[38;5;241m=\u001b[39mmake_batch(\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39midx, prompt\u001b[38;5;241m=\u001b[39mprompt))\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, prompt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(prompts)\n\u001b[1;32m 20\u001b[0m }\n",
"Cell \u001b[0;32mIn[4], line 18\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, prompts):\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqueue \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m---> 18\u001b[0m idx: PrefillRequest(batch\u001b[38;5;241m=\u001b[39m\u001b[43mmake_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, prompt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(prompts)\n\u001b[1;32m 20\u001b[0m }\n",
"Cell \u001b[0;32mIn[4], line 10\u001b[0m, in \u001b[0;36mmake_batch\u001b[0;34m(id, prompt)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmake_batch\u001b[39m(\u001b[38;5;28mid\u001b[39m, prompt):\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mBatch\u001b[49m(\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mid\u001b[39m,\n\u001b[1;32m 12\u001b[0m requests\u001b[38;5;241m=\u001b[39m[Request(\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mid\u001b[39m, prompt\u001b[38;5;241m=\u001b[39mprompt)]\n\u001b[1;32m 13\u001b[0m )\n",
"\u001b[0;31mNameError\u001b[0m: name 'Batch' is not defined"
]
}
],
"source": [
"service.ClearCache()\n",
"\n",
@@ -213,118 +468,10 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": null,
"id": "dd6bcc43-63ef-4f92-a960-74e33b86dc97",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Request 0 is done!\n",
"Request 1 is done!\n",
"Request 3 is done!\n",
"Request 2 is done!\n",
"All Requests Done!\n",
"\n",
"\n",
"INDEX = 0:\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
" fib(n):\n",
"\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n-1) + fib(n-2)\n",
"\n",
"# Call the function.\n",
"print(fib(5))\n",
"\n",
"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
"\n",
"\n",
"\n",
"INDEX = 1:\n",
"Write a function for filtering a list of integers to include only positive numbers:\n",
"\n",
"filter(lst):\n",
"\n",
"lst = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
"\n",
"def filter_positive(lst):\n",
" return [num for num in lst if num > 0]\n",
"\n",
"print(filter_positive(lst))\n",
"\n",
"# filter_positive([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n",
"\n",
"# filter_positive([1, 2, 3, 4, 5\n",
"\n",
"\n",
"INDEX = 2:\n",
"Write a function for reversing a string:\n",
"\n",
"def reverse_string(s):\n",
" return s[::-1]\n",
"\n",
"# Test\n",
"print(reverse_string(\"hello\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"a\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\"))\n",
"print(reverse_string(\"\n",
"\n",
"\n",
"INDEX = 3:\n",
"Write a function for checking if a word if a palindrome:\n",
"\n",
"def is_palindrome(word):\n",
" return word == word[::-1]\n",
"\n",
"# Test\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(is_palindrome(\"racecar\"))\n",
"print(\n",
"\n",
"\n",
"INDEX = 4:\n",
"Write a function for sorting an array of integers:\n",
"\n",
"def merge_sort(arr):\n",
" if len(arr) <= 1:\n",
" return arr\n",
" mid = len(arr) // 2\n",
" left = arr[:mid]\n",
" right = arr[mid:]\n",
" left = merge_sort(left)\n",
" right = merge_sort(right)\n",
" return merge(left, right)\n",
"\n",
"def merge(left, right):\n",
" result = []\n",
" while len(left) > 0 and len(right) > 0:\n",
" if left[0]\n",
"\n",
"\n",
"[CachedBatch(batch_id=0, request_ids=[4])]\n"
]
}
],
"outputs": [],
"source": [
"# run a few decodes\n",
"for _ in range(100):\n",
server/deepsparse/deepsparse_causal_lm.py
@@ -205,11 +205,11 @@ class DeepSparseCausalLM:
        logits, past_key_values = self.model(input_ids, past_key_values)

        # sample token
        # simple for now --- should use NextTokenChooser
        # todo: simple for now --- should use NextTokenChooser
        generated_token_id = self.sample_token(logits)

        # check stopping criteria
        # simple for now --- should use StoppingCriteria
        # todo: simple for now --- should use StoppingCriteria
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == 1
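The hunk above marks token selection and stopping as temporary ("should use NextTokenChooser" / "should use StoppingCriteria"). For orientation only, a greedy stand-in for sample_token might look like the sketch below; the (1, sequence_length, vocab_size) logits shape is an assumption taken from the single-sequence asserts in the hunk, and the function body is ours, not the repository's.

# Hedged sketch: a greedy stand-in for NextTokenChooser.
import numpy as np

def sample_token(logits: np.ndarray) -> int:
    # assumes logits of shape (1, sequence_length, vocab_size)
    assert logits.ndim == 3 and logits.shape[0] == 1
    # pick the highest-scoring token id at the last position
    return int(np.argmax(logits[0, -1, :]))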
server/deepsparse/deepsparse_queue.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from typing import Deque, Optional, Tuple, Dict
from collections import deque
from threading import Condition
from server.deepsparse.deepsparse_requests import Batch, Request

class GenerateRequest:
    def __init__(
        self,
        prompt: str,
        max_generated_tokens: int
    ):
        self.prompt = prompt
        self.generation = prompt
        self.max_generated_tokens = max_generated_tokens
        self.cv = Condition()
        self.is_stopped = False

# todo: implement logic for maximum memory usage
class DeepSparseQueue:
    def __init__(self):
        self.next_request_id: int = 0
        self.next_batch_id: int = 0
        self.queue: Deque[GenerateRequest] = deque()

    def append(self, generate_request: GenerateRequest):
        self.queue.append(generate_request)

    def is_empty(self):
        return len(self.queue) == 0

    # (todo): enable multiple prefill requests in a batch
    def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
        if self.is_empty():
            return None

        # pop first generate_request in the queue
        generate_request = self.queue.popleft()
        generate_requests = {
            self.next_request_id: generate_request
        }

        # format into request
        request = Request(
            id=self.next_request_id,
            prompt=generate_request.prompt,
            max_generated_tokens=generate_request.max_generated_tokens
        )
        self.next_request_id += 1

        # format into batch
        batch = Batch(
            id = self.next_batch_id,
            requests=[request]
        )
        self.next_batch_id += 1

        # return batch, generate_requests
        return (batch, generate_requests)
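A hedged usage sketch for the queue above (not part of the commit): append a GenerateRequest, then pop it back as a single-request Batch together with the id-keyed map of active requests.

# Hedged usage sketch; assumes this module and deepsparse_requests are importable.
from server.deepsparse.deepsparse_queue import DeepSparseQueue, GenerateRequest

queue = DeepSparseQueue()
queue.append(GenerateRequest(prompt="def fib(n):", max_generated_tokens=64))

item = queue.next_batch()
if item is not None:
    batch, generate_requests = item
    # one Request per Batch for now; generate_requests is keyed by request id
    print(batch.id, [r.prompt for r in batch.requests], list(generate_requests.keys()))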
server/deepsparse/deepsparse_requests.py
@@ -5,6 +5,7 @@ from typing import List, Optional
class Request:
    id: int
    prompt: str
    max_generated_tokens: int

@dataclass
class Batch:
@@ -16,6 +17,9 @@ class CachedBatch:
    batch_id: int
    request_ids: List[int]

    def __len__(self):
        return len(self.request_ids)

@dataclass
class Generation:
    request_id: int
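Read together, the hunks above add max_generated_tokens to Request and a __len__ to CachedBatch. A hedged reconstruction of the touched dataclasses follows; decorators and fields outside the hunks are inferred from how the queue and router construct these objects, not copied from the file.

# Hedged reconstruction of the request types after this change.
from dataclasses import dataclass
from typing import List

@dataclass
class Request:
    id: int
    prompt: str
    max_generated_tokens: int  # added in this commit

@dataclass
class Batch:
    id: int
    requests: List[Request]

@dataclass
class CachedBatch:
    batch_id: int
    request_ids: List[int]

    def __len__(self):  # added in this commit
        return len(self.request_ids)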
server/deepsparse/deepsparse_router.py (new file, 184 lines)
@@ -0,0 +1,184 @@
from threading import Condition
from typing import List, Dict, Optional

from server.deepsparse.deepsparse_service import DeepSparseService
from server.deepsparse.deepsparse_requests import (
    CachedBatch, Batch, Generation,
    PrefillRequest, DecodeRequest, FilterBatchRequest,
)
from server.deepsparse.deepsparse_queue import (
    DeepSparseQueue, GenerateRequest
)

class DeepSparseRouter:
    def __init__(self, service: DeepSparseService):
        self.service: DeepSparseService = service
        self.queue: DeepSparseQueue = DeepSparseQueue()
        self.cv: Condition = Condition()

    def generate(self, prompt:str) -> str:
        generate_request = GenerateRequest(
            prompt=prompt,
            max_generated_tokens=100
        )

        with self.cv:
            # print("router: acquired cv")
            self.queue.append(generate_request)
            self.cv.notify()

        if prompt == "stop":
            return "stop"

        with generate_request.cv:
            # print("generate_request: acquired cv")
            if not generate_request.is_stopped:
                # print("generate_request: waiting")
                generate_request.cv.wait()

            # print("generate_request: done waiting")

        return generate_request.generation

    def prefill(
        self,
        batch: Batch,
        generate_requests: Dict[int,GenerateRequest]
    ) -> Optional[CachedBatch]:
        # print("prefill")
        generation, next_batch = self.service.Prefill(
            PrefillRequest(batch=batch)
        )

        self.filter_notify_update([generation], generate_requests)

        return self.filter_batch(
            batch=next_batch,
            generate_requests=generate_requests
        )

    def decode(
        self,
        batches: List[CachedBatch],
        generate_requests: Dict[int,GenerateRequest]
    ) -> Optional[CachedBatch]:
        # print("decode")
        generations, next_batch = self.service.Decode(
            DecodeRequest(batches=batches)
        )

        self.filter_notify_update(generations, generate_requests)

        return self.filter_batch(
            batch=next_batch,
            generate_requests=generate_requests
        )

    def filter_notify_update(
        self,
        generations: List[Generation],
        generate_requests: Dict[int, GenerateRequest]
    ):
        # print("filter_notify_update")
        for generation in generations:
            request_id = generation.request_id

            # if we hit a stopping criteria
            if generation.generated_text is None:
                # remove from active requests and notify
                stopped_generate_request = generate_requests.pop(request_id)
                with stopped_generate_request.cv:
                    stopped_generate_request.is_stopped = True
                    stopped_generate_request.cv.notify()

            # otherwise, update generation
            else:
                generate_requests[request_id].generation += generation.generated_text

    def filter_batch(
        self,
        batch: Optional[CachedBatch],
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        # print("filter_batch")

        # batch is already done
        if batch is None:
            return batch

        # no need to filter
        if len(batch) == len(generate_requests):
            return batch

        # retain only requests that are still in active generation requests
        batch.request_ids = [id for id in batch.request_ids if id in generate_requests]

        # if all requests complete, clear cache and return None
        if len(batch) == 0:
            self.service.ClearCache()
            return None

        # otherwise call the filter batch service
        return self.service.FilterBatch(
            FilterBatchRequest(
                batch_id=batch.batch_id,
                request_ids=batch.request_ids,
            )
        )

def batching_task(
    router: DeepSparseRouter
) -> bool:
    # infinite_loop
    while True:
        # block while the queue is empty
        # print("batching_task: about to acquire cv")
        with router.cv:
            while router.queue.is_empty():
                # print(f"batching_task cv: waiting")
                router.cv.wait()
            # print(f"batching_task: done waiting")

        # loop until all batches in the queue are processed
        next_batch = router.queue.next_batch()
        while next_batch is not None:
            batch, generate_requests = next_batch

            # hack to break out of the cycle
            if batch.requests[0].prompt == "stop":
                assert router.queue.is_empty()
                assert len(router.service.cache) == 0
                return True

            cached_batch = router.prefill(
                batch=batch,
                generate_requests=generate_requests
            )

            # loop until we do not receive any cached batch from the service (== until
            # all requests have met their stopping criteria
            while cached_batch is not None:
                # print(f"batch_size = {len(cached_batch)}")
                batches = [cached_batch]

                # try to get a new batch and run prefill on this batch
                next_batch = router.queue.next_batch()
                if next_batch is not None:
                    new_batch, new_generate_requests = next_batch
                    new_cached_batch = router.prefill(
                        batch=new_batch,
                        generate_requests=new_generate_requests
                    )

                    if new_cached_batch is not None:
                        batches.append(new_cached_batch)
                        assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
                        generate_requests.update(new_generate_requests)

                # run decode
                cached_batch = router.decode(
                    batches=batches,
                    generate_requests=generate_requests
                )

            next_batch = router.queue.next_batch()
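A minimal launch sketch for the router above, mirroring the notebook earlier in this diff; the model and tokenizer paths are hypothetical placeholders.

# Hedged sketch: wire model -> service -> router, run batching_task in the
# background, generate once, then send the "stop" sentinel to shut down.
from threading import Thread
from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM
from server.deepsparse.deepsparse_service import DeepSparseService
from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task

model = DeepSparseCausalLM(
    tokenizer_path="/path/to/deployment",  # hypothetical path
    model_path="/path/to/model.onnx",      # hypothetical path
)
router = DeepSparseRouter(service=DeepSparseService(model=model))

batching_thread = Thread(target=batching_task, args=[router])
batching_thread.start()

print(router.generate(prompt="def fib(n):"))

router.generate(prompt="stop")  # unblocks batching_task, which returns True
batching_thread.join()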
server/deepsparse/deepsparse_service.py
@@ -7,7 +7,7 @@ from server.deepsparse.deepsparse_requests import (
    Generation, CachedBatch
)

class BatchCache:
class Cache:
    def __init__(self):
        self.cache: Dict[int, DeepSparseCausalLMBatch] = {}

@@ -37,7 +37,7 @@ class DeepSparseService:
        model: DeepSparseCausalLM
    ):
        self.model = model
        self.cache = BatchCache()
        self.cache = Cache()

    def ClearCache(self):
        self.cache.clear()
@@ -46,6 +46,7 @@ class DeepSparseService:
        self,
        request: FilterBatchRequest
    ) -> CachedBatch:

        ds_batch = self.cache.pop(request.batch_id)
        assert ds_batch is not None, "Batch ID {request.batch_id} not found in cache."
        filtered_batch = ds_batch.filter(request.request_ids)
@@ -57,6 +58,7 @@ class DeepSparseService:
        self,
        request: PrefillRequest
    ) -> [Generation, CachedBatch]:

        ds_batch = DeepSparseCausalLMBatch.from_batch(
            batch=request.batch,
            tokenizer=self.model.tokenizer
@@ -88,4 +90,4 @@ class DeepSparseService:
        generations, next_ds_batch = self.model.generate_token(ds_batch)
        self.cache.set(next_ds_batch)

        return generations, next_ds_batch.to_batch() if next_ds_batch else None
        return generations, (next_ds_batch.to_batch() if next_ds_batch else None)
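The hunks above rename BatchCache to Cache, but only its constructor is visible. A hedged sketch of the rest of the class, inferred from the cache calls that appear elsewhere in this diff (clear(), pop(), set(), len(router.service.cache)), might read as follows; the method bodies are assumptions, not the repository's code.

# Hedged sketch of Cache; only __init__ appears in the hunks above.
from typing import Dict, Optional

class Cache:
    def __init__(self):
        self.cache: Dict[int, "DeepSparseCausalLMBatch"] = {}

    def pop(self, batch_id: int) -> Optional["DeepSparseCausalLMBatch"]:
        return self.cache.pop(batch_id, None)

    def set(self, ds_batch) -> None:
        # assumed: cached batches expose the id used to look them up later
        if ds_batch is not None:
            self.cache[ds_batch.batch_id] = ds_batch

    def clear(self) -> None:
        self.cache.clear()

    def __len__(self) -> int:
        return len(self.cache)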