Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)

fix: adjust batch for bgmv
parent: 8984ce6c69
commit: ad088d51fa
@@ -1205,8 +1205,13 @@ class FlashCausalLM(Model):
                    if prefill_logprobs
                    else speculative_logits
                )
            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
                len(batch)
            )
        else:
            next_token_logits = out
            next_adapter_indices = batch.adapter_meta.adapter_indices

        speculate = get_speculate()
        (
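The new_empty(len(batch)) sizing in the prefill branch reflects that, once decoding starts, every request contributes exactly one token per step, so the per-token adapter index tensor collapses to one entry per request. A minimal sketch of that idea under assumed shapes (the helper below is hypothetical, not code from this commit):

```python
# Hypothetical sketch, not code from this commit: why next_adapter_indices is
# allocated with one entry per request after prefill. Each request's decode
# token keeps the adapter index of that request's last prefill token.
import torch

def collapse_adapter_indices(adapter_indices: torch.Tensor,
                             input_lengths: torch.Tensor) -> torch.Tensor:
    # adapter_indices: one adapter id per prefill token, requests concatenated
    # input_lengths:   number of prefill tokens per request
    last_token_positions = torch.cumsum(input_lengths, dim=0) - 1
    return adapter_indices[last_token_positions]

# Two requests with 3 and 2 prefill tokens, using adapters 0 and 2
adapter_indices = torch.tensor([0, 0, 0, 2, 2], dtype=torch.int32)
input_lengths = torch.tensor([3, 2])
print(collapse_adapter_indices(adapter_indices, input_lengths))  # tensor([0, 2], dtype=torch.int32)
```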
@@ -1228,6 +1233,14 @@ class FlashCausalLM(Model):
        )

        if prefill:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        if len(batch) > 1 and prefill_logprobs:
            # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
            # When batch == 1, we will just use the batch.input_ids values directly
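The segments are recomputed under if prefill: because boundaries derived from prefill token counts no longer match once every request is a single token long, as the comment above notes. The sketch below shows the assumed behaviour of a find_segments-style helper (an illustration, not the repository's implementation): it returns the start offsets of runs of equal adapter indices plus the adapter id of each run, which is the layout grouped LoRA kernels consume.

```python
# Assumed behaviour of a find_segments-style helper (illustration only):
# boundaries of runs of equal adapter indices, and the adapter id of each run.
from typing import List, Tuple
import torch

def find_segments_sketch(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    indices = adapter_indices.tolist()  # assumed non-empty
    segments = [0]
    segment_adapters = []
    for i in range(1, len(indices)):
        if indices[i] != indices[i - 1]:
            segments.append(i)
            segment_adapters.append(indices[i - 1])
    segments.append(len(indices))
    segment_adapters.append(indices[-1])
    return segments, segment_adapters

# [0, 0, 1, 1, 1, 2] -> segment offsets [0, 2, 5, 6] for adapters [0, 1, 2]
print(find_segments_sketch(torch.tensor([0, 0, 1, 1, 1, 2])))
```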
@@ -1297,6 +1310,7 @@ class FlashCausalLM(Model):
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids
        batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill and prefill_logprobs:
            # Get prefill logprobs
@@ -239,28 +239,24 @@ def serve(
                max_input_tokens,
            )

-            # TODO: avoid hacky hardcoded adapter id
-            adapter_parameters = AdapterParameters(
-                adapter_ids=lora_adapter_ids,
-                weights=[
-                    # TODO: fill with actual weights
-                    torch.tensor([1.0], dtype=torch.float32)
-                ],
-                merge_strategy=0,
-                density=0.0,
-                majority_sign_method=0,
-            )
-            adapter_source = None
-            adapter_index = 0
-            api_token = None
-
-            model.load_adapter(
-                adapter_parameters,
-                adapter_source,
-                adapter_index,
-                api_token,
-                False,
-            )
+            if len(lora_adapter_ids) > 0:
+                for index, adapter_id in enumerate(lora_adapter_ids):
+                    # TODO: avoid hacky hardcoded adapter id
+                    adapter_parameters = AdapterParameters(
+                        adapter_ids=[adapter_id],
+                        weights=[],
+                        merge_strategy=0,
+                        density=1.0,
+                        majority_sign_method=0,
+                    )
+                    adapter_index = index
+                    model.load_adapter(
+                        adapter_parameters,
+                        None,  # adapter_source
+                        adapter_index,
+                        None,  # api_token
+                        False,  # dynamic
+                    )

        except Exception:
            logger.exception("Error when initializing model")
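In serve(), the removed block merged every LoRA id into a single AdapterParameters with a placeholder weight and loaded it at index 0; the new loop registers each adapter under its own index, so the adapter_indices tracked in FlashCausalLM above can route each request to a different adapter. A small illustration of the resulting index-to-adapter mapping, with hypothetical adapter ids:

```python
# Illustration only (adapter ids are made up): the per-adapter loading loop
# assigns indices 0..N-1 to the adapters passed in lora_adapter_ids.
lora_adapter_ids = ["org/adapter-a", "org/adapter-b"]
adapter_index_to_id = {index: adapter_id for index, adapter_id in enumerate(lora_adapter_ids)}
print(adapter_index_to_id)  # {0: 'org/adapter-a', 1: 'org/adapter-b'}
```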