fix: adjust batch for bgmv

This commit is contained in:
drbh 2024-06-06 17:45:08 +00:00
parent 8984ce6c69
commit ad088d51fa
2 changed files with 32 additions and 22 deletions

View File

@ -1205,8 +1205,13 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
else:
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
speculate = get_speculate()
(
@ -1228,6 +1233,14 @@ class FlashCausalLM(Model):
)
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
@ -1297,6 +1310,7 @@ class FlashCausalLM(Model):
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill and prefill_logprobs:
# Get prefill logprobs

View File

@ -239,28 +239,24 @@ def serve(
max_input_tokens,
)
# TODO: avoid hacky hardcoded adapter id
adapter_parameters = AdapterParameters(
adapter_ids=lora_adapter_ids,
weights=[
# TODO: fill with actual weights
torch.tensor([1.0], dtype=torch.float32)
],
merge_strategy=0,
density=0.0,
majority_sign_method=0,
)
adapter_source = None
adapter_index = 0
api_token = None
model.load_adapter(
adapter_parameters,
adapter_source,
adapter_index,
api_token,
False,
)
if len(lora_adapter_ids) > 0:
for index, adapter_id in enumerate(lora_adapter_ids):
# TODO: avoid hacky hardcoded adapter id
adapter_parameters = AdapterParameters(
adapter_ids=[adapter_id],
weights=[],
merge_strategy=0,
density=1.0,
majority_sign_method=0,
)
adapter_index = index
model.load_adapter(
adapter_parameters,
None, # adapter_source
adapter_index,
None, # api_token
False, # dynamic
)
except Exception:
logger.exception("Error when initializing model")