Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
updated ipynb for simulating interacting with the server
This commit is contained in:
parent a972a7870f
commit 010739bba1
@@ -90,7 +90,7 @@ class DeepSparseCausalLMBatch:
             old_idx = self.requests_idx_mapping[request_id]
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
-            past_key_values_list.append(self.past_key_values[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])
 
         # update batch state
         self.requests = requests
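The filter path above rebuilds each per-request list by the request's old index, and the fix reads the cached values from past_key_values_list rather than a non-existent past_key_values attribute. Below is a minimal, self-contained sketch of that rebuild-by-old-index pattern; the data values are made up for illustration and are not the repo's actual batch contents.

# toy data standing in for the batch's parallel lists (hypothetical values)
requests_idx_mapping = {"req-a": 0, "req-b": 1, "req-c": 2}
requests = ["req-a", "req-b", "req-c"]
input_ids_list = [[1, 2], [3, 4, 5], [6]]
past_key_values_list = ["kv-a", "kv-b", "kv-c"]

keep_ids = ["req-c", "req-a"]  # request ids that survive the filter
new_requests, new_input_ids, new_kvs = [], [], []
for request_id in keep_ids:
    old_idx = requests_idx_mapping[request_id]   # look up the request's old position
    new_requests.append(requests[old_idx])
    new_input_ids.append(input_ids_list[old_idx])
    new_kvs.append(past_key_values_list[old_idx])

# KV caches stay aligned with their surviving requests
assert new_kvs == ["kv-c", "kv-a"]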
@@ -112,7 +112,7 @@ class DeepSparseCausalLMBatch:
 
         start_index = 0
         for i, batch in enumerate(batches):
-            assert batch.past_key_values_list is None, "only concatenate prefilled batches"
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
 
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
@@ -129,7 +129,7 @@ class DeepSparseCausalLMBatch:
             start_index += len(batch)
 
         return cls(
-            batch_id= batches[0].id,
+            batch_id=batches[0].batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
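Taken together, the two hunks above tighten concatenate: each input batch must already be prefilled (its past_key_values_list populated), and the merged batch reuses batches[0].batch_id instead of a non-existent id attribute. A rough, runnable sketch of that merge under those assumptions, using a simplified stand-in class rather than the real DeepSparseCausalLMBatch:

from dataclasses import dataclass
from typing import List

@dataclass
class TinyBatch:  # simplified stand-in for illustration only
    batch_id: int
    requests: List[str]
    past_key_values_list: List[str]

def concatenate(batches: List[TinyBatch]) -> TinyBatch:
    requests, past_key_values_list = [], []
    for batch in batches:
        # only prefilled batches carry KV caches that can be merged
        assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
        requests.extend(batch.requests)
        past_key_values_list.extend(batch.past_key_values_list)
    # the merged batch keeps the first batch's id, as in the diff above
    return TinyBatch(batch_id=batches[0].batch_id,
                     requests=requests,
                     past_key_values_list=past_key_values_list)

merged = concatenate([TinyBatch(0, ["a"], ["kv-a"]), TinyBatch(1, ["b"], ["kv-b"])])
assert merged.batch_id == 0 and merged.requests == ["a", "b"]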
@@ -210,8 +210,11 @@ class DeepSparseCausalLM:
 
         # check stopping criteria
         # simple for now --- should use StoppingCriteria
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
 
         stop = self.should_stop(
-            num_tokens_processed=len(input_ids) + 1,
+            num_tokens_processed=input_ids.shape[1] + 1,
+            generated_token_id = generated_token_id
         )
 
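The stopping-criteria fix hinges on input_ids being a 2D array of shape [batch, seq_len] with batch fixed at 1: len(input_ids) only ever returns 1 (the batch dimension), while input_ids.shape[1] is the number of tokens processed so far. A quick sketch of the difference, assuming numpy arrays as a stand-in for the engine's inputs:

import numpy as np

input_ids = np.array([[101, 42, 7, 13]])  # shape [1, 4]: one sequence, four tokens

assert len(input_ids) == 1        # the old count: just the batch dimension
assert input_ids.shape[1] == 4    # the new count: tokens in the sequence

# including the token just generated, the processed count is seq_len + 1
num_tokens_processed = input_ids.shape[1] + 1
assert num_tokens_processed == 5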
@@ -32,3 +32,4 @@ class DecodeRequest:
 @dataclass
 class FilterBatchRequest:
     batch_id: int
+    request_ids: List[int]
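FilterBatchRequest names the cached batch to prune and the request ids that should survive, matching the filter logic in the first hunk. A small constructed example, assuming only the two fields shown above:

from dataclasses import dataclass
from typing import List

@dataclass
class FilterBatchRequest:
    batch_id: int
    request_ids: List[int]

# keep only requests 2 and 5 from cached batch 0 (hypothetical ids)
req = FilterBatchRequest(batch_id=0, request_ids=[2, 5])
assert req.request_ids == [2, 5]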
@@ -56,16 +56,17 @@ class DeepSparseService:
     def Prefill(
         self,
         request: PrefillRequest
-    ) -> [List[Generation], CachedBatch]:
+    ) -> [Generation, CachedBatch]:
         ds_batch = DeepSparseCausalLMBatch.from_batch(
             batch=request.batch,
             tokenizer=self.model.tokenizer
         )
 
         generations, next_ds_batch = self.model.generate_token(ds_batch)
+        assert len(generations) == 1
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
+        return generations[0], next_ds_batch.to_batch()
 
     def Decode(
         self,
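Because a prefill batch carries exactly one request, the service now asserts len(generations) == 1 and hands back the single Generation rather than a one-element list, which is what the annotation change reflects. A tiny sketch of the unwrapping, with placeholder values rather than real Generation objects:

generations = ["generation-for-the-only-request"]  # what generate_token returns for prefill
assert len(generations) == 1

old_return = (generations, "cached-batch")      # old: callers had to index [0] themselves
new_return = (generations[0], "cached-batch")   # new: the single Generation directly

assert new_return[0] == old_return[0][0]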
@@ -75,16 +76,16 @@ class DeepSparseService:
 
         ds_batches = []
         for batch in request.batches:
-            ds_batch = self.cache.pop(batch.id)
+            ds_batch = self.cache.pop(batch.batch_id)
             assert batch is not None, "Batch ID {batch.id} not found in cache."
             ds_batches.append(ds_batch)
 
         if len(ds_batches) > 1:
             ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
         else:
-            batch = ds_batches[0]
+            ds_batch = ds_batches[0]
 
-        generations, next_ds_batch = self.model.generate_token(ds_batches)
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
+        return generations, next_ds_batch.to_batch() if next_ds_batch else None
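The Decode path pops each referenced batch out of the cache by batch_id, concatenates when more than one is pending, runs a generation step on the merged batch, and re-caches the result (guarding the return when nothing is left to cache). A toy end-to-end sketch of that prefill-then-decode loop, in the spirit of the simulation notebook this commit updates; every class here is a simplified stand-in, not the repo's real service:

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ToyBatch:  # stand-in for a cached DeepSparse batch
    batch_id: int
    tokens: List[int]

class ToyCache:  # stand-in for the service's batch cache
    def __init__(self):
        self._store: Dict[int, ToyBatch] = {}
    def set(self, batch: Optional[ToyBatch]):
        if batch is not None:
            self._store[batch.batch_id] = batch
    def pop(self, batch_id: int) -> Optional[ToyBatch]:
        return self._store.pop(batch_id, None)

class ToyService:  # mirrors the Prefill/Decode control flow above
    def __init__(self):
        self.cache = ToyCache()

    def Prefill(self, batch: ToyBatch):
        batch.tokens.append(1)       # pretend to generate one token
        self.cache.set(batch)        # cache the batch for later decode steps
        return batch.tokens[-1], batch

    def Decode(self, batch_ids: List[int]):
        ds_batches = [self.cache.pop(bid) for bid in batch_ids]
        assert all(b is not None for b in ds_batches), "batch not found in cache"
        if len(ds_batches) > 1:
            # "concatenate": merge token lists under the first batch's id
            ds_batch = ToyBatch(ds_batches[0].batch_id,
                                [t for b in ds_batches for t in b.tokens])
        else:
            ds_batch = ds_batches[0]
        ds_batch.tokens.append(2)    # pretend to generate the next token
        self.cache.set(ds_batch)
        return ds_batch.tokens[-1], ds_batch if ds_batch else None

service = ToyService()
_, b0 = service.Prefill(ToyBatch(0, [101]))
_, b1 = service.Prefill(ToyBatch(1, [102]))
token, merged = service.Decode([0, 1])
assert merged.batch_id == 0 and token == 2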