Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-07 10:22:09 +00:00)
fix: adjust adapter_segments logic when in batch
commit c927376725 (parent ad088d51fa)
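For orientation: the "adapter segments" touched by this commit describe contiguous runs of rows in the flattened batch that share the same LoRA adapter, and find_segments recomputes those runs from the per-row adapter indices. Below is a minimal, illustrative sketch of that grouping written only as a reading aid; it is an assumption, not the repository's actual find_segments (which operates on tensors).

    from typing import List, Sequence, Tuple

    def find_segments(adapter_indices: Sequence[int]) -> Tuple[List[int], List[int]]:
        # Group consecutive equal adapter indices into segments.
        # Returns segment boundaries (length == num_segments + 1) and the
        # adapter index of each segment. Illustrative sketch only.
        segments = [0]
        segment_indices: List[int] = []
        for i, adapter_idx in enumerate(adapter_indices):
            if i == 0 or adapter_idx != adapter_indices[i - 1]:
                if i != 0:
                    segments.append(i)
                segment_indices.append(adapter_idx)
        segments.append(len(adapter_indices))
        return segments, segment_indices

    # Prefill: requests of lengths 5 and 2 using adapters 0 and 1
    print(find_segments([0, 0, 0, 0, 0, 1, 1]))  # ([0, 5, 7], [0, 1])
    # Decode: every request contributes exactly one row
    print(find_segments([0, 1]))                 # ([0, 1, 2], [0, 1])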
@@ -698,6 +698,9 @@ class FlashCausalLMBatch(Batch):
             )
             cumulative_adapter_indices_size = adapter_end_index
             adapter_set.update(batch.adapter_meta.adapter_set)
+            adapter_segment_builder.concat(
+                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+            )

             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
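The added concat call reuses each incoming batch's already-computed segments when batches are merged, rather than rebuilding them from the full concatenated index tensor. A rough sketch of what such a builder could look like follows; the SegmentConcatBuilder name and its internals are illustrative assumptions and may not match the repository's actual adapter_segment_builder.

    from typing import List, Sequence, Tuple

    class SegmentConcatBuilder:
        # Illustrative accumulator for adapter segments across concatenated
        # batches; assumption only, the real builder may differ.

        def __init__(self) -> None:
            self.segments: List[int] = [0]        # boundary offsets into the merged batch
            self.segment_indices: List[int] = []  # adapter index per segment

        def concat(self, segments: Sequence[int], segment_indices: Sequence[int]) -> None:
            offset = self.segments[-1]
            if self.segment_indices and self.segment_indices[-1] == segment_indices[0]:
                # Same adapter on both sides of the seam: extend the previous
                # segment through the first incoming one instead of splitting.
                self.segments[-1] = offset + segments[1]
                self.segments.extend(offset + s for s in segments[2:])
                self.segment_indices.extend(segment_indices[1:])
            else:
                self.segments.extend(offset + s for s in segments[1:])
                self.segment_indices.extend(segment_indices)

        def build(self) -> Tuple[List[int], List[int]]:
            return self.segments, self.segment_indices

    builder = SegmentConcatBuilder()
    builder.concat([0, 3], [0])        # first batch: 3 rows, adapter 0
    builder.concat([0, 2, 4], [0, 1])  # second batch: 2 rows adapter 0, 2 rows adapter 1
    print(builder.build())             # ([0, 5, 7], [0, 1])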
@@ -742,7 +745,7 @@ class FlashCausalLMBatch(Batch):
             else None
         )

-        _adapter_segments, _adapter_segment_indices = adapter_segment_builder.build()
+        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

         return cls(
             batch_id=batches[0].batch_id,
@@ -774,6 +777,12 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
+            adapter_meta=AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            ),
         )

     def __len__(self):
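With the builder output wired through, the merged batch now carries the same adapter metadata as its inputs. As a reading aid, the fields passed to AdapterBatchMetadata above could be typed roughly as follows; this dataclass sketch is inferred only from the names in the diff and is not the repository's definition.

    from dataclasses import dataclass
    from typing import List, Set

    import torch

    @dataclass
    class AdapterBatchMetadata:
        # Sketch inferred from the fields used in the diff; the real class may
        # differ in types or carry additional fields.
        adapter_indices: torch.Tensor   # adapter index for every row of the batch
        adapter_set: Set[int]           # all adapter indices active in the batch
        adapter_segments: torch.Tensor  # boundaries of runs sharing one adapter
        segment_indices: List[int]      # adapter index associated with each segment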
@@ -1233,14 +1242,6 @@ class FlashCausalLM(Model):
         )

         if prefill:
-            # adjust segment lengths to account for all request lengths being 1 during decoding
-            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
-            batch.adapter_meta.adapter_segments = torch.tensor(
-                adapter_segments,
-                dtype=torch.int32,
-                device=batch.adapter_meta.adapter_segments.device,
-            )
-
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
@@ -1285,6 +1286,12 @@ class FlashCausalLM(Model):
                 # In decode, we do not need this as we can just increment position ids
                 next_position_ids[i] = batch.position_ids[end_index - 1]

+                # Initialize adapter indices
+                # In decode, we only have one token per row in the batch, so grab last index
+                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                    end_index - 1
+                ]
+
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
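The new decode-side lines mirror the position-id handling directly above: during decode each request keeps exactly one row, so the adapter index of the request's last prefill token is carried forward into that row. A tiny worked example of the indexing, using hypothetical values rather than the repository's code:

    import torch

    # Prefill layout: two requests of lengths 3 and 2 using adapters 0 and 1
    adapter_indices = torch.tensor([0, 0, 0, 1, 1])
    cumulative_lengths = [0, 3, 5]  # hypothetical per-request end offsets

    next_adapter_indices = torch.empty(2, dtype=adapter_indices.dtype)
    for i, end_index in enumerate(cumulative_lengths[1:]):
        # grab the adapter of the request's last token, as in the added lines
        next_adapter_indices[i] = adapter_indices[end_index - 1]

    print(next_adapter_indices)  # tensor([0, 1])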
@@ -1312,6 +1319,15 @@ class FlashCausalLM(Model):
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices

+        if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
         if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
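Recomputing the segments here, after adapter_indices has been collapsed to one row per request, appears to be the substance of this fix: the identical block previously ran earlier in the prefill path (the removal in the hunk at old line 1233 above), before next_adapter_indices was assigned, so the stored segment lengths still reflected prefill token counts. Using the illustrative find_segments sketched near the top of this page, the difference looks like this:

    # Prefill: two requests with prompt lengths 3 and 2, adapters 0 and 1
    prefill_indices = [0, 0, 0, 1, 1]
    print(find_segments(prefill_indices))  # ([0, 3, 5], [0, 1])

    # After batch.adapter_meta.adapter_indices = next_adapter_indices,
    # each request occupies a single row
    decode_indices = [0, 1]
    print(find_segments(decode_indices))   # ([0, 1, 2], [0, 1])

    # Recomputing before that assignment (the old placement) would leave the
    # prefill-length segments [0, 3, 5] attached to a batch of only 2 rows.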
@@ -241,10 +241,11 @@ def serve(

         if len(lora_adapter_ids) > 0:
             for index, adapter_id in enumerate(lora_adapter_ids):
-                # TODO: avoid hacky hardcoded adapter id
+                # TODO: improve non merged adapter loading and long term
+                # improve adapter loading as a whole
                 adapter_parameters = AdapterParameters(
                     adapter_ids=[adapter_id],
-                    weights=[],
+                    weights=None,  # will be set to 1
                     merge_strategy=0,
                     density=1.0,
                     majority_sign_method=0,
@@ -162,8 +162,6 @@ def load_module_map(
     api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    print("adapter_id", adapter_id)
-
     revision = "main"

     adapter_config = LoraConfig.load(adapter_id, api_token)