Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)
Fix the generate_stream crash in concurrent query (#105)
Signed-off-by: yuanwu <yuan.wu@intel.com>
Commit a4d5c3f40f (parent 3d81a80577)
```diff
@@ -380,8 +380,15 @@ class CausalLMBatch(Batch):
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
+        parameters = [r.data.parameters for r in flat_requests]
+        if len(flat_requests) < new_bs:
+            for i in range(new_bs - len(flat_requests)):
+                # append the dummy parameters for dummy request
+                parameters.append(parameters[0])
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            [r.data.parameters for r in flat_requests],
+            parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device
         )
```
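In both hunks the fix keeps the per-request `parameters` list in step with the bucket-padded batch: when the batch size is rounded up to `new_bs`, the first request's parameters are repeated to fill the dummy slots before `HeterogeneousNextTokenChooser.from_pb` is called. A minimal sketch of that padding pattern, assuming `round_up` behaves like the helper referenced in the second hunk; `pad_to_bucket` is a hypothetical name, not part of the repository:

```python
import math

def round_up(n: int, multiple: int) -> int:
    # Round n up to the next multiple (the bucketed batch size).
    return math.ceil(n / multiple) * multiple

def pad_to_bucket(parameters: list, bucket_size: int) -> list:
    # Repeat the first entry so len(parameters) matches the bucketed
    # batch size; the extra entries stand in for dummy requests.
    # Assumes at least one real request, as the diff does.
    new_bs = round_up(len(parameters), bucket_size)
    return parameters + [parameters[0]] * (new_bs - len(parameters))

# Example: 3 real requests, bucket size 4 -> one dummy parameter set.
padded = pad_to_bucket(["p0", "p1", "p2"], bucket_size=4)
assert len(padded) == 4 and padded[3] == "p0"
```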
```diff
@@ -426,13 +433,19 @@ class CausalLMBatch(Batch):
         # TODO: Add support for sparse batches
         top_n_tokens = [r.top_n_tokens for r in pb.requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
 
         # TODO: by tokenizing all inputs at once we lose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
         new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
+        parameters = [r.parameters for r in pb.requests]
+        if len(pb.requests) < new_bs:
+            for i in range(new_bs - len(pb.requests)):
+                # append the dummy parameters for dummy request
+                parameters.append(parameters[0])
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
```
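Before this change, the prefill path built the chooser from exactly `len(pb.requests)` parameter sets while `dummy_inputs` padded the batch itself up to `new_bs`, so a partially filled bucket left the chooser shorter than the batch it served, which is presumably the mismatch behind the `generate_stream` crash under concurrent queries. A sketch of the invariant the fix restores; the names and bucket size below are made up for illustration:

```python
import math

PREFILL_BATCH_BUCKET_SIZE = 4          # assumed bucket size, for illustration
requests = ["a", "b", "c"]             # three real requests in the batch
new_bs = math.ceil(len(requests) / PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE

inputs = requests + ["?"] * (new_bs - len(requests))   # padded, as in the diff

parameters = ["p_a", "p_b", "p_c"]
# Before the fix the chooser saw 3 parameter sets for a batch of 4;
# after the fix the list is padded the same way as the inputs.
parameters += [parameters[0]] * (new_bs - len(parameters))
assert len(parameters) == len(inputs) == new_bs
```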