Mirror of https://github.com/huggingface/text-generation-inference.git
Fix the generate_stream crash in concurrent query (#105)
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent 3d81a80577
commit a4d5c3f40f
@@ -380,8 +380,15 @@ class CausalLMBatch(Batch):
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+
+        parameters = [r.data.parameters for r in flat_requests]
+        if len(flat_requests) < new_bs:
+            for i in range(new_bs - len(flat_requests)):
+                # append the dummy parameters for dummy request
+                parameters.append(parameters[0])
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            [r.data.parameters for r in flat_requests],
+            parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device
         )
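The added logic above is a list-padding step: once the batch has been bucketed up to new_bs, the per-request sampling parameters are repeated so the list is as long as the padded batch before it is handed to HeterogeneousNextTokenChooser.from_pb. A minimal, self-contained sketch of that pattern, using a hypothetical pad_parameters helper and plain strings in place of the real parameter protos:

# A minimal sketch of the padding pattern in the hunk above; pad_parameters
# is a hypothetical helper, and plain strings stand in for parameter protos.
def pad_parameters(parameters, new_bs):
    """Repeat the first entry until len(parameters) == new_bs."""
    padded = list(parameters)
    for _ in range(new_bs - len(padded)):
        padded.append(padded[0])
    return padded

# Example: three real requests bucketed up to a batch size of four.
params = ["p0", "p1", "p2"]
assert pad_parameters(params, 4) == ["p0", "p1", "p2", "p0"]

Reusing the first request's parameters, as the commit does, keeps every padded slot well-formed without having to construct default parameters for the dummy requests.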
@@ -426,13 +433,19 @@ class CausalLMBatch(Batch):
         # TODO: Add support for sparse batches
         top_n_tokens = [r.top_n_tokens for r in pb.requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
 
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
         new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
+        parameters = [r.parameters for r in pb.requests]
+        if len(pb.requests) < new_bs:
+            for i in range(new_bs - len(pb.requests)):
+                # append the dummy parameters for dummy request
+                parameters.append(parameters[0])
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
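For context on this prefill-side hunk, here is a self-contained sketch of the bucketing arithmetic and of how the dummy inputs and dummy parameters stay in lockstep; the value of PREFILL_BATCH_BUCKET_SIZE and the round_up implementation shown are assumptions made only for illustration:

import math

# Assumed bucket size and a plausible round_up, for illustration only.
PREFILL_BATCH_BUCKET_SIZE = 4

def round_up(n, multiple):
    # Round n up to the next multiple, e.g. round_up(5, 4) == 8.
    return int(math.ceil(n / multiple)) * multiple

requests = ["req0", "req1", "req2", "req3", "req4"]          # 5 real requests
parameters = [f"params{i}" for i in range(len(requests))]    # one entry each

new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)  # -> 8
dummy_inputs = ["?"] * (new_bs - len(requests))              # 3 dummy prompts
# Pad the parameters the same way the commit does, one dummy entry per slot.
parameters += [parameters[0]] * (new_bs - len(requests))

assert len(requests) + len(dummy_inputs) == new_bs == len(parameters)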