Merge pull request #3 from rsnm2/dev-ux

Dev ux
Robert Shaw 2023-08-25 07:41:11 -06:00 committed by GitHub
commit 02952c511f
11 changed files with 954 additions and 406 deletions


@ -61,4 +61,11 @@ python3 server/text_generation_server/cli.py serve bigscience/bloom-560m
 Launch Router
 ```shell
 make router-dev
+```
+
+Install FastAPI/Uvicorn
+```shell
+pip install fastapi
+pip install "uvicorn[standard]"
 ```

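A quick sanity check that the new dependencies resolved (a trivial sketch, not part of the repo):

```python
# confirm fastapi/uvicorn are importable after the pip installs above
import fastapi
import uvicorn

print("fastapi", fastapi.__version__)
print("uvicorn", uvicorn.__version__)
```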
deepsparse/main.py (new file, 68 lines)

@ -0,0 +1,68 @@
import fastapi, uvicorn
from contextlib import asynccontextmanager
from threading import Thread
from queue import Queue
from router import DeepSparseRouter, batching_task
from utils import GenerateRequest

TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"

def serve(
    model_path=MODEL_PATH,
    tokenizer_path=TOKENIZER_PATH,
    host="0.0.0.0",
    port=5543
):
    router = None

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # build the router and start the background batching loop before
        # serving requests; tear both down on shutdown
        nonlocal router

        print("\n-------------------- Building Router --------------------\n")
        router = DeepSparseRouter(
            model_path=model_path,
            tokenizer_path=tokenizer_path
        )

        print("\n-------------------- Starting Batching Task --------------------\n")
        batching_thread = Thread(target=batching_task, args=[router])
        batching_thread.start()

        print("\n-------------------- Launching App --------------------\n")
        yield

        print("\n-------------------- Joining Batching Task --------------------\n")
        router.stop_batching_task()
        batching_thread.join()

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.get("/generate/{prompt}")
    async def generate(prompt: str):
        # submit the request to the router and block until the response
        # stream reports a stopped generation
        response_stream = Queue()
        router.submit_request(
            GenerateRequest(
                prompt=prompt,
                max_generated_tokens=100,
                response_stream=response_stream
            )
        )

        response_string = prompt
        generation = response_stream.get()
        while not generation.stopped:
            response_string += generation.token
            generation = response_stream.get()

        return response_string

    uvicorn.run(
        app,
        host=host,
        port=port,
        workers=1
    )

if __name__ == "__main__":
    serve()

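The new main.py exposes a single GET /generate/{prompt} endpoint. A minimal client sketch against it (assumes the server above is running on its default host/port; `requests` is not a dependency added by this PR, and the prompt value is just an illustration):

```python
from urllib.parse import quote

import requests

prompt = "def fib(n):"
# the prompt travels in the URL path, so it must be URL-encoded
resp = requests.get(f"http://127.0.0.1:5543/generate/{quote(prompt, safe='')}")
resp.raise_for_status()
print(resp.json())  # the endpoint returns the prompt plus the generated text
```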
deepsparse/router.py (new file, 191 lines)

@ -0,0 +1,191 @@
from queue import Queue
from typing import List, Dict, Optional, Tuple
from service.service import DeepSparseService
from service.causal_lm import DeepSparseCausalLM
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request

class DeepSparseRouter:
    def __init__(
        self,
        service: Optional[DeepSparseService] = None,
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None
    ):
        assert (
            service is not None or
            (model_path is not None and tokenizer_path is not None)
        )

        if service is not None:
            self.service = service
        else:
            self.service = DeepSparseService(
                model=DeepSparseCausalLM(
                    model_path=model_path,
                    tokenizer_path=tokenizer_path
                )
            )

        self.queue: DeepSparseQueue = DeepSparseQueue()
        self.batching_task_should_stop: bool = False

    def submit_request(self, generate_request: GenerateRequest):
        self.queue.append(generate_request)

    def stop_batching_task(self):
        # tell the batching task to stop
        self.batching_task_should_stop = True

        # unblock the batching task with a dummy request if it is blocked on the queue
        self.queue.append(GenerateRequest(
            prompt="dummy",
            max_generated_tokens=1,
            response_stream=Queue()
        ))

    def prefill(
        self,
        batch: Batch,
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        generation, next_batch = self.service.Prefill(batch=batch)
        active_generate_request_ids = self.filter_send_generations([generation], generate_requests)
        return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)

    def decode(
        self,
        batches: List[CachedBatch],
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        generations, next_batch = self.service.Decode(batches=batches)
        active_generate_request_ids = self.filter_send_generations(generations, generate_requests)
        return self.filter_batch(batch=next_batch, active_generate_request_ids=active_generate_request_ids)

    def filter_send_generations(
        self,
        generations: List[Generation],
        generate_requests: Dict[int, GenerateRequest]
    ) -> List[int]:
        active_request_ids = []
        for generation in generations:
            # send the generation to the response stream
            generate_requests[generation.request_id].response_stream.put(generation)

            # remove the request from active requests if it stopped
            if generation.stopped:
                generate_requests.pop(generation.request_id)
            else:
                active_request_ids.append(generation.request_id)

        return active_request_ids

    def filter_batch(
        self,
        batch: Optional[CachedBatch],
        active_generate_request_ids: List[int]
    ) -> Optional[CachedBatch]:
        # if the batch is done OR there is nothing to filter
        if batch is None or len(batch) == len(active_generate_request_ids):
            return batch

        # keep only the active request_ids
        batch.request_ids = active_generate_request_ids

        # if all requests are complete, clear the cache
        if len(batch) == 0:
            self.service.ClearCache()
            return None

        return self.service.FilterBatch(batch_id=batch.batch_id, request_ids=batch.request_ids)

# TODO: update to do more sophisticated logic as to when to do a prefill
def batching_task(router: DeepSparseRouter):
    # while not signaled to stop
    while not router.batching_task_should_stop:
        # loop until there are no requests to process (note: this blocks if the queue is empty)
        next_batch = router.queue.next_batch(block=True)
        while next_batch is not None:
            batch, generate_requests = next_batch

            # run prefill
            cached_batch = router.prefill(
                batch=batch,
                generate_requests=generate_requests
            )

            # loop until we do not receive a cached batch from the service
            # == until all active requests have met their stopping criteria
            while cached_batch is not None:
                batches = [cached_batch]

                # try to get a new batch and run prefill on this batch
                next_batch = router.queue.next_batch(block=False)
                if next_batch is not None:
                    new_batch, new_generate_requests = next_batch
                    new_cached_batch = router.prefill(
                        batch=new_batch,
                        generate_requests=new_generate_requests
                    )

                    if new_cached_batch is not None:
                        batches.append(new_cached_batch)
                        assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
                        generate_requests.update(new_generate_requests)

                # run decode
                cached_batch = router.decode(
                    batches=batches,
                    generate_requests=generate_requests
                )

            next_batch = router.queue.next_batch(block=False)

# TODO: implement logic for maximum size of the queue based on memory usage
class DeepSparseQueue:
    def __init__(self):
        self.next_request_id: int = 0
        self.next_batch_id: int = 0
        self.queue: Queue[GenerateRequest] = Queue()

    def append(self, generate_request: GenerateRequest):
        self.queue.put(generate_request)

    # TODO: enable multiple prefill requests in a batch
    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
        # if not blocking, return None if the queue is empty
        if not block and self.queue.empty():
            return None

        # if block == True, this call blocks until an item is ready
        # if block == False, the queue has data (otherwise an exception would be raised);
        # queue.empty() == False alone does not guarantee data, but this queue is only
        # consumed by one thread (batching_task is the only caller of next_batch)
        generate_request = self.queue.get(block=block)
        generate_requests = {self.next_request_id: generate_request}

        # format into a request
        request = Request(
            id=self.next_request_id,
            prompt=generate_request.prompt,
            max_generated_tokens=generate_request.max_generated_tokens
        )
        self.next_request_id += 1

        # format into a batch
        batch = Batch(
            id=self.next_batch_id,
            requests=[request]
        )
        self.next_batch_id += 1

        # return batch, generate_requests
        return (batch, generate_requests)

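Putting the pieces above together, the intended lifecycle is: start batching_task on a background thread, submit GenerateRequests, drain each response_stream until a stopped Generation arrives, then stop the task. A sketch that mirrors main.py and the notebook below (the model and tokenizer paths are placeholders):

```python
from queue import Queue
from threading import Thread

from router import DeepSparseRouter, batching_task
from utils import GenerateRequest

# placeholder paths; point these at a real ONNX model and tokenizer deployment
router = DeepSparseRouter(
    model_path="/path/to/model.onnx",
    tokenizer_path="/path/to/deployment",
)

# run the continuous-batching loop in the background
batching_thread = Thread(target=batching_task, args=[router])
batching_thread.start()

# submit one request and stream its generations back through a Queue
response_stream = Queue()
router.submit_request(GenerateRequest(
    prompt="def fib(n):",
    max_generated_tokens=100,
    response_stream=response_stream,
))

text = "def fib(n):"
generation = response_stream.get()
while not generation.stopped:
    text += generation.token
    generation = response_stream.get()
print(text)

# shut down the batching loop
router.stop_batching_task()
batching_thread.join()
```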

@ -1,15 +1,10 @@
-import numpy as np
 from dataclasses import dataclass
 from typing import List, Dict, Optional
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
+import numpy as np
-from server.deepsparse.deepsparse_model import (
-    DeepSparsePastKeyValues, DeepSparseDecoderModel
-)
-from server.deepsparse.deepsparse_requests import (
-    Request, Batch, CachedBatch, Generation
-)
+from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
+from utils import Request, Batch, CachedBatch, Generation

 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@ -62,7 +57,7 @@ class DeepSparseCausalLMBatch:
             past_key_values_list=[None] * len(batch.requests),
         )

-    def to_batch(self) -> CachedBatch:
+    def to_cached_batch(self) -> CachedBatch:
         return CachedBatch(
             batch_id = self.batch_id,
             request_ids=[r.id for r in self.requests],
@ -156,15 +151,19 @@ class DeepSparseCausalLM:
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
         )

-    # TODO (@rsnm2): switch to NextTokenChooser
+    # TODO: switch to NextTokenChooser
     def sample_token(
         self,
         logits: np.ndarray
     ):
-        assert(logits.shape[0] == 1)    # assert b=1 for now
-        return np.argmax(logits[0,-1,:])    # grab logits for the last item in the sequence
+        # assert b=1 for now
+        assert(logits.shape[0] == 1)
+        # grab logits for the last item in the sequence
+        # shape == (batch, seq, vocabulary_size)
+        return np.argmax(logits[0,-1,:])

-    # TODO (@rsnm2): switch to StoppingCriteria
+    # TODO: switch to StoppingCriteria
     def should_stop(
         self,
         num_tokens_processed: int,
@ -184,56 +183,45 @@ class DeepSparseCausalLM:
         generations: List[Generation] = []
         all_stopped = True

-        # if we supported continuous batching, we would do batched inference here
-        # logits, past_key_values = self.model(batch)

         # for each member of the batch:
         #   a) run inference
         #   b) sample and check stopping criteria
-        #   c) create generation + update batch
+        #   c) create generation
+        #   d) update batch
-        for i, (
-            request,
-            input_ids,
-            past_key_values,
-        ) in enumerate(zip(
-            batch.requests,
-            batch.input_ids_list,
-            batch.past_key_values_list
-        )):
-            # run inference
-            logits, past_key_values = self.model(input_ids, past_key_values)
-
-            # sample token
-            # todo: simple for now --- should use NextTokenChooser
-            generated_token_id = self.sample_token(logits)
-
-            # check stopping criteria
-            # todo: simple for now --- should use StoppingCriteria
+        for i, (request, input_ids, past_key_values,) in enumerate(
+            zip(
+                batch.requests,
+                batch.input_ids_list,
+                batch.past_key_values_list
+            )
+        ):
             assert len(input_ids.shape) == 2
             assert input_ids.shape[0] == 1
+
+            # a) run inference
+            logits, past_key_values = self.model(input_ids, past_key_values)
+
+            # b) sample token and check stopping criteria
+            # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
+            generated_token_id = self.sample_token(logits)
+            generated_token = self.tokenizer.decode(generated_token_id)
             stop = self.should_stop(
                 num_tokens_processed=input_ids.shape[1] + 1,
                 generated_token_id = generated_token_id
             )
-            # if not stopped, convert token id to text
-            generated_text = None
             if not stop:
                 all_stopped = False
-                generated_text = self.tokenizer.decode(
-                    generated_token_id,
-                    skip_special_tokens=True,
-                    clean_up_tokenization_spaces=False
-                )

+            # c) make generation
             generations.append(Generation(
                 request_id=request.id,
-                generated_text=generated_text
+                token=generated_token,
+                token_id=generated_token_id,
+                stopped=stop
             ))

-            # update values in the batch
-            # bad --- this does not occur in place
+            # d) update batch
+            # TODO: this does not occur in place
             assert len(batch.input_ids_list[i].shape) == 2
             assert batch.input_ids_list[i].shape[0] == 1
             batch.input_ids_list[i] = np.append(
@ -243,7 +231,7 @@ class DeepSparseCausalLM:
             )
             batch.past_key_values_list[i] = past_key_values

-        # if all elements of the batch are done, return generation + null for batch
+        # if all elements of the batch are done, return null for batch
         if all_stopped:
             return generations, None

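The reworked sample_token above is plain greedy decoding: take the arg-max over the vocabulary at the last sequence position. A tiny standalone illustration (shapes invented for the example):

```python
import numpy as np

# logits shaped (batch, sequence, vocabulary); batch is asserted to be 1 above
logits = np.random.rand(1, 4, 8)

# greedy sampling over the last position, as sample_token does
next_token_id = np.argmax(logits[0, -1, :])
print(int(next_token_id))
```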

@ -0,0 +1,76 @@
from typing import Dict, List, Tuple
from service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
from utils import Generation, CachedBatch, Batch

class BatchCache:
    def __init__(self):
        self.cache: Dict[int, DeepSparseCausalLMBatch] = {}

    def pop(self, batch_id: int) -> DeepSparseCausalLMBatch:
        batch = self.cache.pop(batch_id, None)
        assert batch is not None, f"Batch ID {batch_id} not found in cache."
        return batch

    def set(self, entry: DeepSparseCausalLMBatch):
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        batch = self.pop(batch_id)
        if batch is not None:
            del batch

    def clear(self):
        keys = list(self.cache.keys())
        for k in keys:
            self.delete(k)

    def __len__(self):
        return len(self.cache.keys())

class DeepSparseService:
    def __init__(
        self,
        model: DeepSparseCausalLM
    ):
        self.model = model
        self.cache = BatchCache()

    def ClearCache(self):
        self.cache.clear()

    def FilterBatch(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        ds_batch = self.cache.pop(batch_id)
        filtered_ds_batch = ds_batch.filter(request_ids)
        self.cache.set(filtered_ds_batch)

        return filtered_ds_batch.to_cached_batch()

    def Prefill(self, batch: Batch) -> Tuple[Generation, CachedBatch]:
        ds_batch = DeepSparseCausalLMBatch.from_batch(
            batch=batch,
            tokenizer=self.model.tokenizer
        )

        generations, next_ds_batch = self.model.generate_token(ds_batch)
        assert len(generations) == 1
        self.cache.set(next_ds_batch)

        return generations[0], next_ds_batch.to_cached_batch()

    def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        assert len(batches) != 0, "Must provide at least one batch"

        ds_batches = []
        for cached_batch in batches:
            ds_batches.append(self.cache.pop(cached_batch.batch_id))

        if len(ds_batches) > 1:
            ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
        else:
            ds_batch = ds_batches[0]

        generations, next_ds_batch = self.model.generate_token(ds_batch)
        self.cache.set(next_ds_batch)

        return generations, (next_ds_batch.to_cached_batch() if next_ds_batch else None)

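The service-level contract the router relies on is: Prefill turns a Batch into a first Generation plus a CachedBatch, and Decode then advances that cached batch one token per call until it returns None. A sketch of driving the service directly (the model and tokenizer paths are placeholders; Batch and Request come from utils and are built the same way DeepSparseQueue.next_batch builds them):

```python
from service.service import DeepSparseService
from service.causal_lm import DeepSparseCausalLM
from utils import Batch, Request

# placeholder paths; point these at a real ONNX model and tokenizer deployment
model = DeepSparseCausalLM(
    model_path="/path/to/model.onnx",
    tokenizer_path="/path/to/deployment",
)
service = DeepSparseService(model=model)

# one single-request batch, as DeepSparseQueue.next_batch() would build it
batch = Batch(
    id=0,
    requests=[Request(id=0, prompt="def fib(n):", max_generated_tokens=100)],
)

# prefill once, then decode until the service reports the batch is finished
generation, cached_batch = service.Prefill(batch=batch)
tokens = [generation.token]
while cached_batch is not None:
    generations, cached_batch = service.Decode(batches=[cached_batch])
    tokens.extend(g.token for g in generations)

print("".join(t for t in tokens if t is not None))
```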

@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from queue import Queue
 from typing import List, Optional

 @dataclass
@ -23,17 +24,12 @@ class CachedBatch:
 @dataclass
 class Generation:
     request_id: int
-    generated_text: Optional[str]
+    token: Optional[str]
+    token_id: Optional[str]
+    stopped: bool

-@dataclass
-class PrefillRequest:
-    batch: Batch

 @dataclass
-class DecodeRequest:
-    batches: List[CachedBatch]
+class GenerateRequest:
+    prompt: str
+    max_generated_tokens: int
+    response_stream: Queue[Generation]

-@dataclass
-class FilterBatchRequest:
-    batch_id: int
-    request_ids: List[int]

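For reference, the two dataclasses touched here read as follows after the change (reconstructed from the diff above; the unchanged Request/Batch/CachedBatch definitions are omitted):

```python
from dataclasses import dataclass
from queue import Queue
from typing import Optional


@dataclass
class Generation:
    request_id: int
    token: Optional[str]
    token_id: Optional[str]   # typed as in the diff
    stopped: bool


@dataclass
class GenerateRequest:
    prompt: str
    max_generated_tokens: int
    response_stream: Queue    # Queue[Generation] in the source (Python 3.9+)
```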

@ -11,6 +11,199 @@
"%autoreload 2" "%autoreload 2"
] ]
}, },
{
"cell_type": "code",
"execution_count": 10,
"id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
]
}
],
"source": [
"!curl 127.0.0.1:5543/generate \\\n",
" -X POST \\\n",
" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
" -H 'Content-Type: application/json'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'\\n'\n",
"b' '\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 0'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 0'\n",
"b'\\n'\n",
"b' '\n",
"b'el'\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 1'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 1'\n",
"b'\\n'\n",
"b' '\n",
"b'else'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'1'\n",
"b')'\n",
"b' +'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'2'\n",
"b')'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' Driver'\n",
"b' function'\n",
"b' to'\n",
"b' test'\n",
"b' above'\n",
"b' function'\n",
"b'\\n'\n",
"b'n'\n",
"b' ='\n",
"b' int'\n",
"b'('\n",
"b'input'\n",
"b'(\"'\n",
"b'Enter'\n",
"b' the'\n",
"b' number'\n",
"b':'\n",
"b' \"'\n",
"b'))'\n",
"b'\\n'\n",
"b'print'\n",
"b'('\n",
"b'f'\n",
"b'ib'\n",
"b'('\n",
"b'n'\n",
"b'))'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' This'\n",
"b' code'\n",
"b' is'\n",
"b' contributed'\n",
"b' by'\n",
"b' Nik'\n",
"b'h'\n",
"b'il'\n",
"b' Kumar'\n",
"b' Singh'\n",
"b'('\n",
"b'nick'\n",
"b'z'\n",
"b'uck'\n",
"b'_'\n",
"b'007'\n",
"b')'\n",
"b'\\n'\n"
]
}
],
"source": [
"import requests\n",
"\n",
"url = \"http://127.0.0.1:5543/generate_stream\"\n",
"obj = {\n",
" \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
" \"max_generated_tokens\":100\n",
"}\n",
"\n",
"with requests.post(url, json=obj, stream=True) as r:\n",
" for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
" print(chunk)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dfb0e1de-17e2-4ff2-aecb-22b6f607ef9b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n-1) + fib(n-2)\n",
"\n",
"# Driver function to test above function\n",
"n = int(input(\"Enter the number: \"))\n",
"print(fib(n))\n",
"\n",
"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n"
]
}
],
"source": [
"!curl 127.0.0.1:5543/generate_stream \\\n",
" -X POST \\\n",
" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
" -H 'Content-Type: application/json'"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6948b65b-8851-40f3-bf67-d7db2cf1955e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\""
]
}
],
"source": [
"!curl http://127.0.0.1:5543/generate/def%20fib%28n%29%3A"
]
},
 {
 "cell_type": "markdown",
 "id": "a19786b8-e72c-43c1-964f-45d92fd171e9",
@ -21,7 +214,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 3,
 "id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d",
 "metadata": {},
 "outputs": [
@ -29,19 +222,19 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
+"2023-08-24 17:46:45 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
 ]
 }
 ],
 "source": [
-"from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n",
-"from server.deepsparse.deepsparse_service import DeepSparseService\n",
-"from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM"
+"from server.deepsparse.router import DeepSparseRouter, batching_task\n",
+"from server.deepsparse.service.service import DeepSparseService\n",
+"from server.deepsparse.service.causal_lm import DeepSparseCausalLM"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "id": "78acf813-3688-483d-9148-5c0df5d6b8e3",
 "metadata": {},
 "outputs": [
@ -50,7 +243,7 @@
 "output_type": "stream",
 "text": [
 "Using pad_token, but it is not set yet.\n",
-"2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+"2023-08-24 17:47:02 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
 "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
 ]
 },
@ -73,7 +266,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+"2023-08-24 17:47:27 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
 ]
 },
 {
@ -99,12 +292,376 @@
 "model = DeepSparseCausalLM(\n",
 "    tokenizer_path=tokenizer_path,\n",
 "    model_path=onnx_path\n",
-")\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "c011f167-f150-4a46-b24d-3074e8564151",
+"metadata": {},
+"outputs": [],
+"source": [
+"from server.deepsparse.utils import GenerateRequest\n",
+"from queue import Queue \n",
+"from threading import Thread\n",
+"import time\n",
 "\n",
 "service = DeepSparseService(model=model)\n",
-"router = DeepSparseRouter(service=service)"
+"router = DeepSparseRouter(service=service)\n",
+"batching_thread = Thread(target=batching_task, args=[router])\n",
+"batching_thread.start()"
 ]
 },
{
"cell_type": "code",
"execution_count": 8,
"id": "538ba71e-9562-4325-899b-c67a4cf74075",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"starting request\n",
"starting request\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n-1) + fib(n-2)\n",
"\n",
"# Driver function to test above function\n",
"n = int(input(\"Enter the number: \"))\n",
"print(fib(n))\n",
"\n",
"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
"\n",
"Write a function for filtering a list of integers to include only positive numbers:\n",
"\n",
"def filter(lst):\n",
" return [x for x in lst if x > 0]\n",
"\n",
"# Test\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
"print(filter([1,\n"
]
}
],
"source": [
"prompts = [\n",
" \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
" \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n",
" \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n",
" \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n",
" \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n",
"]\n",
"\n",
"def generate_request(prompt):\n",
" print(\"starting request\")\n",
" response_stream = Queue()\n",
" \n",
" g_request = GenerateRequest(\n",
" prompt=prompt,\n",
" max_generated_tokens=100,\n",
" response_stream=response_stream\n",
" )\n",
"\n",
" router.generate(g_request)\n",
"\n",
" str = prompt\n",
" generation = response_stream.get()\n",
" while not generation.stopped:\n",
" str += generation.token\n",
" generation = response_stream.get()\n",
"\n",
" print(str)\n",
"\n",
"generate_threads = [\n",
" Thread(target=generate_request, args=[prompt]) for prompt in prompts[:2]\n",
"]\n",
"\n",
"# print(len(generate_threads))\n",
"\n",
"for gt in generate_threads:\n",
" gt.start()\n",
" time.sleep(1)\n",
"\n",
"for gt in generate_threads:\n",
" gt.join()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d9a1c1cc-e1d1-4915-8c3a-c1c9f2aa5147",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m service \u001b[38;5;241m=\u001b[39m DeepSparseService(model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2\u001b[0m router \u001b[38;5;241m=\u001b[39m DeepSparseRouter(service\u001b[38;5;241m=\u001b[39mservice)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mbatching_thread\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/threading.py:1060\u001b[0m, in \u001b[0;36mThread.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot join current thread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1059\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1060\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait_for_tstate_lock\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1062\u001b[0m \u001b[38;5;66;03m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[1;32m 1063\u001b[0m \u001b[38;5;66;03m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait_for_tstate_lock(timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mmax\u001b[39m(timeout, \u001b[38;5;241m0\u001b[39m))\n",
"File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/threading.py:1080\u001b[0m, in \u001b[0;36mThread._wait_for_tstate_lock\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 1077\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1079\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1080\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 1081\u001b[0m lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 1082\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stop()\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"service = DeepSparseService(model=model)\n",
"router = DeepSparseRouter(service=service)\n",
"batching_thread.join()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a762fbbb-9711-4bf4-aeb9-3083a1f5b1f8",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9a556e2-d626-4733-84eb-e7e40f244981",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "aff1a7dd-de34-443a-bf37-6ef039bd25c8",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 54,
"id": "2c50cb8f-bc1b-487d-94e2-90cee1cabdf4",
"metadata": {},
"outputs": [],
"source": [
"batch, generate_requests = router.queue.next_batch(block=True)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "8c048933-7e50-4f6a-b360-41a3ae7ae8c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CachedBatch(batch_id=0, request_ids=[0])\n"
]
}
],
"source": [
"cached_batch = router.prefill(batch, generate_requests)\n",
"print(cached_batch)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "43363249-5256-4259-b465-b99fc4d7f42c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"None\n"
]
}
],
"source": [
"next_batch = router.queue.next_batch(block=False)\n",
"print(next_batch)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "1ebb411e-c9ff-471a-9cd0-77fcdb2299df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n",
"CachedBatch(batch_id=0, request_ids=[0])\n"
]
}
],
"source": [
"for i in range(20):\n",
" cached_batch = router.decode(batches=[cached_batch], generate_requests=generate_requests)\n",
" print(cached_batch)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "1fb10db6-54bb-4fcb-8662-e12a65fc0fa8",
"metadata": {},
"outputs": [],
"source": [
"router.generate(g_request1)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "7bf59798-3325-4171-9a4b-fab283c8194b",
"metadata": {},
"outputs": [],
"source": [
"next_batch = router.queue.next_batch(block=False)\n",
"if next_batch is not None:\n",
" new_batch, new_generate_requests = next_batch\n",
"\n",
"new_cached_batch = router.prefill(\n",
" batch=new_batch,\n",
" generate_requests=new_generate_requests\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "9e9e5654-beb0-4fba-8bf6-49824d335925",
"metadata": {},
"outputs": [],
"source": [
"batches = [cached_batch]\n",
"batches.append(new_cached_batch)\n",
"generate_requests.update(new_generate_requests)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "764448bb-7841-4910-ab6b-56c3c7a2cde8",
"metadata": {},
"outputs": [],
"source": [
"for i in range(20):\n",
" cached_batch = router.decode(batches=batches, generate_requests=generate_requests)\n",
" batches = [cached_batch]"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "1536911c-8504-41b3-aa96-cc33024446f6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{0: GenerateRequest(prompt='Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):', max_generated_tokens=100, response_stream=<queue.Queue object at 0x7f552c1b3910>),\n",
" 1: GenerateRequest(prompt='Write a function for reversing a string:\\n\\ndef reverse_string(s):', max_generated_tokens=100, response_stream=<queue.Queue object at 0x7f5557c30ac0>)}"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_requests"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "d1dfb9ed-8bf0-4a38-a7f6-9cd22212e663",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n",
"Generation(request_id=1, token=' ', token_id=50284, stopped=False)\n",
"Generation(request_id=1, token='return', token_id=7783, stopped=False)\n",
"Generation(request_id=1, token=' s', token_id=264, stopped=False)\n",
"Generation(request_id=1, token='[', token_id=58, stopped=False)\n",
"Generation(request_id=1, token='::', token_id=3712, stopped=False)\n",
"Generation(request_id=1, token='-', token_id=12, stopped=False)\n",
"Generation(request_id=1, token='1', token_id=16, stopped=False)\n",
"Generation(request_id=1, token=']', token_id=60, stopped=False)\n",
"Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n",
"Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n",
"Generation(request_id=1, token='#', token_id=2, stopped=False)\n",
"Generation(request_id=1, token=' Test', token_id=6208, stopped=False)\n",
"Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n",
"Generation(request_id=1, token='print', token_id=4798, stopped=False)\n",
"Generation(request_id=1, token='(', token_id=7, stopped=False)\n",
"Generation(request_id=1, token='reverse', token_id=50188, stopped=False)\n",
"Generation(request_id=1, token='_', token_id=62, stopped=False)\n",
"Generation(request_id=1, token='string', token_id=8841, stopped=False)\n",
"Generation(request_id=1, token='(\"', token_id=7203, stopped=False)\n"
]
}
],
"source": [
"for i in range(20):\n",
" generation = response_stream1.get(block=False)\n",
" print(generation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41ead58c-e59b-4a9c-8b35-a1a658f22a7f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "b57985b2-75f8-4460-9b77-e531a12bd12c",
"metadata": {},
"outputs": [],
"source": []
},
 {
 "cell_type": "code",
 "execution_count": 5,
@ -239,7 +796,7 @@
 "    gt.join()\n",
 "\n",
 "\n",
-"generate_task(\"stop\")\n",
+"generate(\"stop\")\n",
 "batching_thread.join()"
 ]
 },


@ -1,58 +0,0 @@
from typing import Deque, Optional, Tuple, Dict
from collections import deque
from threading import Condition
from server.deepsparse.deepsparse_requests import Batch, Request

class GenerateRequest:
    def __init__(
        self,
        prompt: str,
        max_generated_tokens: int
    ):
        self.prompt = prompt
        self.generation = prompt
        self.max_generated_tokens = max_generated_tokens
        self.cv = Condition()
        self.is_stopped = False

# todo: implement logic for maximum memory usage
class DeepSparseQueue:
    def __init__(self):
        self.next_request_id: int = 0
        self.next_batch_id: int = 0
        self.queue: Deque[GenerateRequest] = deque()

    def append(self, generate_request: GenerateRequest):
        self.queue.append(generate_request)

    def is_empty(self):
        return len(self.queue) == 0

    # (todo): enable multiple prefill requests in a batch
    def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
        if self.is_empty():
            return None

        # pop first generate_request in the queue
        generate_request = self.queue.popleft()
        generate_requests = {
            self.next_request_id: generate_request
        }

        # format into request
        request = Request(
            id=self.next_request_id,
            prompt=generate_request.prompt,
            max_generated_tokens=generate_request.max_generated_tokens
        )
        self.next_request_id += 1

        # format into batch
        batch = Batch(
            id=self.next_batch_id,
            requests=[request]
        )
        self.next_batch_id += 1

        # return batch, generate_requests
        return (batch, generate_requests)


@ -1,184 +0,0 @@
from threading import Condition
from typing import List, Dict, Optional
from server.deepsparse.deepsparse_service import DeepSparseService
from server.deepsparse.deepsparse_requests import (
    CachedBatch, Batch, Generation,
    PrefillRequest, DecodeRequest, FilterBatchRequest,
)
from server.deepsparse.deepsparse_queue import (
    DeepSparseQueue, GenerateRequest
)

class DeepSparseRouter:
    def __init__(self, service: DeepSparseService):
        self.service: DeepSparseService = service
        self.queue: DeepSparseQueue = DeepSparseQueue()
        self.cv: Condition = Condition()

    def generate(self, prompt: str) -> str:
        generate_request = GenerateRequest(
            prompt=prompt,
            max_generated_tokens=100
        )

        with self.cv:
            # print("router: acquired cv")
            self.queue.append(generate_request)
            self.cv.notify()

        if prompt == "stop":
            return "stop"

        with generate_request.cv:
            # print("generate_request: acquired cv")
            if not generate_request.is_stopped:
                # print("generate_request: waiting")
                generate_request.cv.wait()
                # print("generate_request: done waiting")

        return generate_request.generation

    def prefill(
        self,
        batch: Batch,
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        # print("prefill")
        generation, next_batch = self.service.Prefill(
            PrefillRequest(batch=batch)
        )

        self.filter_notify_update([generation], generate_requests)

        return self.filter_batch(
            batch=next_batch,
            generate_requests=generate_requests
        )

    def decode(
        self,
        batches: List[CachedBatch],
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        # print("decode")
        generations, next_batch = self.service.Decode(
            DecodeRequest(batches=batches)
        )

        self.filter_notify_update(generations, generate_requests)

        return self.filter_batch(
            batch=next_batch,
            generate_requests=generate_requests
        )

    def filter_notify_update(
        self,
        generations: List[Generation],
        generate_requests: Dict[int, GenerateRequest]
    ):
        # print("filter_notify_update")
        for generation in generations:
            request_id = generation.request_id

            # if we hit a stopping criteria
            if generation.generated_text is None:
                # remove from active requests and notify
                stopped_generate_request = generate_requests.pop(request_id)
                with stopped_generate_request.cv:
                    stopped_generate_request.is_stopped = True
                    stopped_generate_request.cv.notify()

            # otherwise, update generation
            else:
                generate_requests[request_id].generation += generation.generated_text

    def filter_batch(
        self,
        batch: Optional[CachedBatch],
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
        # print("filter_batch")

        # batch is already done
        if batch is None:
            return batch

        # no need to filter
        if len(batch) == len(generate_requests):
            return batch

        # retain only requests that are still in active generation requests
        batch.request_ids = [id for id in batch.request_ids if id in generate_requests]

        # if all requests complete, clear cache and return None
        if len(batch) == 0:
            self.service.ClearCache()
            return None

        # otherwise call the filter batch service
        return self.service.FilterBatch(
            FilterBatchRequest(
                batch_id=batch.batch_id,
                request_ids=batch.request_ids,
            )
        )

def batching_task(
    router: DeepSparseRouter
) -> bool:
    # infinite loop
    while True:
        # block while the queue is empty
        # print("batching_task: about to acquire cv")
        with router.cv:
            while router.queue.is_empty():
                # print(f"batching_task cv: waiting")
                router.cv.wait()
                # print(f"batching_task: done waiting")

        # loop until all batches in the queue are processed
        next_batch = router.queue.next_batch()
        while next_batch is not None:
            batch, generate_requests = next_batch

            # hack to break out of the cycle
            if batch.requests[0].prompt == "stop":
                assert router.queue.is_empty()
                assert len(router.service.cache) == 0
                return True

            cached_batch = router.prefill(
                batch=batch,
                generate_requests=generate_requests
            )

            # loop until we do not receive any cached batch from the service (== until
            # all requests have met their stopping criteria)
            while cached_batch is not None:
                # print(f"batch_size = {len(cached_batch)}")
                batches = [cached_batch]

                # try to get a new batch and run prefill on this batch
                next_batch = router.queue.next_batch()
                if next_batch is not None:
                    new_batch, new_generate_requests = next_batch
                    new_cached_batch = router.prefill(
                        batch=new_batch,
                        generate_requests=new_generate_requests
                    )

                    if new_cached_batch is not None:
                        batches.append(new_cached_batch)
                        assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
                        generate_requests.update(new_generate_requests)

                # run decode
                cached_batch = router.decode(
                    batches=batches,
                    generate_requests=generate_requests
                )

            next_batch = router.queue.next_batch()


@ -1,93 +0,0 @@
from typing import Optional, Dict, List
from server.deepsparse.deepsparse_causal_lm import (
    DeepSparseCausalLM, DeepSparseCausalLMBatch
)
from server.deepsparse.deepsparse_requests import (
    PrefillRequest, DecodeRequest, FilterBatchRequest,
    Generation, CachedBatch
)

class Cache:
    def __init__(self):
        self.cache: Dict[int, DeepSparseCausalLMBatch] = {}

    def pop(self, batch_id: int) -> Optional[DeepSparseCausalLMBatch]:
        return self.cache.pop(batch_id, None)

    def set(self, entry: DeepSparseCausalLMBatch):
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        batch = self.pop(batch_id)
        if batch is not None:
            del batch

    def clear(self):
        keys = list(self.cache.keys())
        for k in keys:
            self.delete(k)

    def __len__(self):
        return len(self.cache.keys())

class DeepSparseService:
    def __init__(
        self,
        model: DeepSparseCausalLM
    ):
        self.model = model
        self.cache = Cache()

    def ClearCache(self):
        self.cache.clear()

    def FilterBatch(
        self,
        request: FilterBatchRequest
    ) -> CachedBatch:
        ds_batch = self.cache.pop(request.batch_id)
        assert ds_batch is not None, "Batch ID {request.batch_id} not found in cache."

        filtered_batch = ds_batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return filtered_batch.to_batch()

    def Prefill(
        self,
        request: PrefillRequest
    ) -> [Generation, CachedBatch]:
        ds_batch = DeepSparseCausalLMBatch.from_batch(
            batch=request.batch,
            tokenizer=self.model.tokenizer
        )

        generations, next_ds_batch = self.model.generate_token(ds_batch)
        assert len(generations) == 1
        self.cache.set(next_ds_batch)

        return generations[0], next_ds_batch.to_batch()

    def Decode(
        self,
        request: DecodeRequest
    ) -> [List[Generation], CachedBatch]:
        assert len(request.batches) != 0, "Must provide at least one batch"

        ds_batches = []
        for batch in request.batches:
            ds_batch = self.cache.pop(batch.batch_id)
            assert batch is not None, "Batch ID {batch.id} not found in cache."
            ds_batches.append(ds_batch)

        if len(ds_batches) > 1:
            ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
        else:
            ds_batch = ds_batches[0]

        generations, next_ds_batch = self.model.generate_token(ds_batch)
        self.cache.set(next_ds_batch)

        return generations, (next_ds_batch.to_batch() if next_ds_batch else None)