Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
updated ipynb for simulating interaction with the server

This commit is contained in:
parent a972a7870f
commit 010739bba1
@@ -90,7 +90,7 @@ class DeepSparseCausalLMBatch:
             old_idx = self.requests_idx_mapping[request_id]
             requests.append(self.requests[old_idx])
             input_ids_list.append(self.input_ids_list[old_idx])
-            past_key_values_list.append(self.past_key_values[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])
 
         # update batch state
         self.requests = requests
@@ -112,7 +112,7 @@ class DeepSparseCausalLMBatch:
 
         start_index = 0
         for i, batch in enumerate(batches):
-            assert batch.past_key_values_list is None, "only concatenate prefilled batches"
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
 
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
@@ -129,7 +129,7 @@ class DeepSparseCausalLMBatch:
             start_index += len(batch)
 
         return cls(
-            batch_id= batches[0].id,
+            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
@@ -210,8 +210,11 @@ class DeepSparseCausalLM:
 
             # check stopping criteria
             # simple for now --- should use StoppingCriteria
+            assert len(input_ids.shape) == 2
+            assert input_ids.shape[0] == 1
 
             stop = self.should_stop(
-                num_tokens_processed=len(input_ids) + 1,
+                num_tokens_processed=input_ids.shape[1] + 1,
                 generated_token_id = generated_token_id
             )
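A side note on the num_tokens_processed fix above: input_ids is two-dimensional, [batch, seq_len], with a batch size of 1 (hence the new asserts), so len(input_ids) returns the batch size rather than the number of tokens. A minimal sketch of the difference, assuming a NumPy-style array:

import numpy as np

input_ids = np.array([[101, 42, 7, 13]])  # shape [1, 4]: one sequence, four tokens

print(len(input_ids))          # 1 -> batch size (the old, buggy count)
print(input_ids.shape[1] + 1)  # 5 -> tokens processed, counting the newly generated token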
@@ -31,4 +31,5 @@ class DecodeRequest:
 
 @dataclass
 class FilterBatchRequest:
     batch_id: int
+    request_ids: List[int]
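For illustration only, a sketch of how the extended dataclass might be populated once request_ids is added; the fields follow the diff, and the values are made up:

from dataclasses import dataclass
from typing import List

@dataclass
class FilterBatchRequest:
    batch_id: int
    request_ids: List[int]

# keep only requests 0 and 2 from batch 3 (hypothetical values)
filter_request = FilterBatchRequest(batch_id=3, request_ids=[0, 2])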
@@ -56,16 +56,17 @@ class DeepSparseService:
     def Prefill(
         self,
         request: PrefillRequest
-    ) -> [List[Generation], CachedBatch]:
+    ) -> [Generation, CachedBatch]:
         ds_batch = DeepSparseCausalLMBatch.from_batch(
             batch=request.batch,
             tokenizer=self.model.tokenizer
         )
 
         generations, next_ds_batch = self.model.generate_token(ds_batch)
+        assert len(generations) == 1
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
+        return generations[0], next_ds_batch.to_batch()
 
     def Decode(
         self,
@@ -75,16 +76,16 @@ class DeepSparseService:
 
         ds_batches = []
         for batch in request.batches:
-            ds_batch = self.cache.pop(batch.id)
+            ds_batch = self.cache.pop(batch.batch_id)
             assert batch is not None, "Batch ID {batch.id} not found in cache."
             ds_batches.append(ds_batch)
 
         if len(ds_batches) > 1:
             ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
         else:
-            batch = ds_batches[0]
+            ds_batch = ds_batches[0]
 
-        generations, next_ds_batch = self.model.generate_token(ds_batches)
+        generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch()
+        return generations, next_ds_batch.to_batch() if next_ds_batch else None