mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-28 13:32:10 +00:00

fix: adjust batch for bgmv
parent 8984ce6c69 · commit ad088d51fa
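For context beyond the commit message: "bgmv" presumably refers to the Punica-style batched gather matrix-vector kernels used for multi-LoRA serving, which apply a potentially different LoRA adapter to every token in a flattened batch. A naive, non-fused reference of that operation (function name and shapes are illustrative, not from the repo):

```python
import torch

def bgmv_reference(x, lora_weights, adapter_indices):
    # y[t] = lora_weights[adapter_indices[t]] @ x[t] for every token t.
    # Shapes: x (tokens, in), lora_weights (adapters, out, in),
    # adapter_indices (tokens,) -> result (tokens, out).
    return torch.einsum("toi,ti->to", lora_weights[adapter_indices], x)
```

Because the kernel needs exactly one adapter index per token, the batch's adapter metadata must be rebuilt when a batch transitions from prefill (many tokens per request) to decode (one token per request). That mismatch is what this commit fixes.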
```diff
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1205,8 +1205,13 @@ class FlashCausalLM(Model):
                     if prefill_logprobs
                     else speculative_logits
                 )
+            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
+                len(batch)
+            )
         else:
             next_token_logits = out
+            next_adapter_indices = batch.adapter_meta.adapter_indices
 
         speculate = get_speculate()
         (
```
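During prefill, `batch.adapter_meta.adapter_indices` holds one entry per prompt token; once decoding starts, each request contributes exactly one token per step. So on prefill the commit allocates a fresh tensor of length `len(batch)` to receive per-request indices, while pure decode steps reuse the existing per-request tensor. A minimal sketch of the collapse, with invented prompt lengths and adapter assignments (the actual filling of the `new_empty` tensor happens elsewhere in this method):

```python
import torch

# Three requests with prompt lengths [4, 2, 3], served by adapters [0, 1, 0].
adapter_indices = torch.tensor([0, 0, 0, 0, 1, 1, 0, 0, 0])  # one entry per prompt token
input_lengths = torch.tensor([4, 2, 3])

# Each request's adapter can be read off its last prompt token; after
# prefill only one token per request remains to decode.
last_token = torch.cumsum(input_lengths, dim=0) - 1           # tensor([3, 5, 8])
next_adapter_indices = adapter_indices[last_token]            # tensor([0, 1, 0])
```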
```diff
@@ -1228,6 +1233,14 @@ class FlashCausalLM(Model):
         )
 
         if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
```
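The kernels also consume `adapter_segments`, the boundaries of contiguous runs of tokens that share an adapter. Segments computed over prompt tokens at prefill time no longer describe the one-token-per-request decode batch, hence the recomputation above. The real `find_segments` helper lives elsewhere in the repo; this sketch assumes it returns run boundaries plus one adapter index per run:

```python
from typing import List, Tuple

def find_segments_sketch(indices: List[int]) -> Tuple[List[int], List[int]]:
    # Boundaries of maximal runs of equal adapter indices, plus the
    # adapter index of each run (an assumed contract, not the repo's code).
    segments, run_adapters = [0], []
    for i, idx in enumerate(indices):
        if i == 0 or idx != indices[i - 1]:
            if i != 0:
                segments.append(i)
            run_adapters.append(idx)
    segments.append(len(indices))
    return segments, run_adapters

print(find_segments_sketch([0, 0, 0, 0, 1, 1, 0, 0, 0]))  # ([0, 4, 6, 9], [0, 1, 0])
print(find_segments_sketch([0, 1, 0]))                    # ([0, 1, 2, 3], [0, 1, 0])
```

The second call shows the decode-time view: with one token per request, every request becomes its own segment.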
```diff
@@ -1297,6 +1310,7 @@ class FlashCausalLM(Model):
         batch.position_ids = next_position_ids + accepted_ids
         batch.input_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
+        batch.adapter_meta.adapter_indices = next_adapter_indices
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
```
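The one added line commits the per-request adapter indices computed earlier, so the next `generate_token` call sees decode-shaped metadata. The neighbouring cursors advance by `accepted_ids` because speculative decoding can accept several tokens per request in a single step; a sketch with invented numbers:

```python
import torch

# Invented values for a batch of three requests.
next_position_ids = torch.tensor([7, 3, 12])   # position cursor per request
accepted_ids = torch.tensor([2, 1, 3])         # tokens accepted this step

position_ids = next_position_ids + accepted_ids  # tensor([ 9,  4, 15])

# However many tokens were accepted, the batch still carries exactly one
# adapter index per request into the next step, e.g. tensor([0, 1, 0]).
```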
```diff
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -239,28 +239,24 @@ def serve(
                 max_input_tokens,
             )
 
-            # TODO: avoid hacky hardcoded adapter id
-            adapter_parameters = AdapterParameters(
-                adapter_ids=lora_adapter_ids,
-                weights=[
-                    # TODO: fill with actual weights
-                    torch.tensor([1.0], dtype=torch.float32)
-                ],
-                merge_strategy=0,
-                density=0.0,
-                majority_sign_method=0,
-            )
-            adapter_source = None
-            adapter_index = 0
-            api_token = None
-
-            model.load_adapter(
-                adapter_parameters,
-                adapter_source,
-                adapter_index,
-                api_token,
-                False,
-            )
+            if len(lora_adapter_ids) > 0:
+                for index, adapter_id in enumerate(lora_adapter_ids):
+                    # TODO: avoid hacky hardcoded adapter id
+                    adapter_parameters = AdapterParameters(
+                        adapter_ids=[adapter_id],
+                        weights=[],
+                        merge_strategy=0,
+                        density=1.0,
+                        majority_sign_method=0,
+                    )
+                    adapter_index = index
+                    model.load_adapter(
+                        adapter_parameters,
+                        None,  # adapter_source
+                        adapter_index,
+                        None,  # api_token
+                        False,  # dynamic
+                    )
 
         except Exception:
             logger.exception("Error when initializing model")
```
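The serve-time change replaces a single merged-style `load_adapter` call (all ids packed into one `AdapterParameters` with placeholder weights and `density=0.0`) with one call per adapter, each registered at its own index, so requests can later address adapters individually through the per-token `adapter_indices` above. A self-contained sketch of the new loop, using a stand-in dataclass for the real (proto-generated) `AdapterParameters`; field names come from the diff, everything else is assumed:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class AdapterParameters:
    # Stand-in for the generated message type; defaults mirror the diff.
    adapter_ids: List[str]
    weights: List[object] = field(default_factory=list)
    merge_strategy: int = 0
    density: float = 1.0
    majority_sign_method: int = 0

def load_all_adapters(model, lora_adapter_ids: List[str]) -> None:
    # Mirrors the loop added by this commit: one load_adapter call per
    # adapter id, each registered at its own index.
    for index, adapter_id in enumerate(lora_adapter_ids):
        adapter_parameters = AdapterParameters(adapter_ids=[adapter_id])
        model.load_adapter(
            adapter_parameters,
            None,   # adapter_source
            index,  # adapter_index
            None,   # api_token
            False,  # dynamic
        )
```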