mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
* feat: first draft load multiple lora * feat: load weights within layer and refactor lora pass * fix: refactor and reduce lora math * feat: baseline impl single request multi lora support * feat: prefer lorax implementation and port loading logic * fix: prefer adapter_data and refactors * feat: perfer loraxs custom punica kernels and add mlp loras * fix: adjust batch for bgmv * fix: adjust adapter_segments logic when in batch * fix: refactor and move changes to v3 proto * fix: pass model_id for all flash causal lms * fix: pass model_id for all causal and seq2seq lms * fix: add model_id to model test * feat: add lora support to mistral and refactors * feat: prefer model id in request * fix: include rust code for adapter id * feat: bump launcher and add new lora docs * feat: support base model generation and refactors * fix: rename doc to retry ci build * feat: support if vlm models * fix: add adapter_data param and avoid missing layers * fix: add adapter_data param to phi and neox * fix: update all models forwards to include adapter_data * fix: add model_id to IdeficsCausalLM * Update lora.md Fixed a typo * Update lora.md Fixing spam image * fix: add lora kernel to dockerfile, support running without kernels and refactors * fix: avoid dockerfile conflict * fix: refactors and adjust flash llama lora logic * fix: skip llama test due to CI issue (temp) * fix: skip llama test CI (temp) 2 * fix: revert skips and prefer updated ci token for tests * fix: refactors and helpful comments * fix: add noop in TensorParallelAdapterRowLinear too * fix: refactor and move shard_lora_weights logic * fix: exit early if no adapter_data --------- Co-authored-by: Derek <datavistics@gmail.com>
67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
# Origin: https://github.com/predibase/lorax
|
|
# Path: lorax/server/lorax_server/utils/segments.py
|
|
# License: Apache License Version 2.0, January 2004
|
|
|
|
from typing import List, Tuple, Union
|
|
|
|
import torch
|
|
|
|
|
|
def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    """Split a sequence of adapter indices into runs of equal consecutive values.

    Args:
        adapter_indices: One adapter index per position, given either as a
            list of ints or as a tensor (converted with ``.cpu().tolist()``;
            presumably 1-D -- confirm against callers).

    Returns:
        A ``(segments, segment_indices)`` tuple. ``segments`` always starts
        with 0 and lists the run boundaries, so run ``k`` covers positions
        ``[segments[k], segments[k + 1])``. ``segment_indices[k]`` is the
        adapter index shared by every position of run ``k``. An empty input
        yields ``([0], [])``.
    """
    if isinstance(adapter_indices, torch.Tensor):
        # Calling .item() repeatedly on a CUDA tensor is very slow, so the
        # whole tensor is moved to CPU and converted to a Python list once.
        adapter_indices = adapter_indices.cpu().tolist()

    segments = [0]
    segment_indices = []

    # Walk adjacent pairs; every change of value closes the run that ends
    # just before position `boundary`.
    pairs = zip(adapter_indices, adapter_indices[1:])
    for boundary, (previous, current) in enumerate(pairs, start=1):
        if current != previous:
            segments.append(boundary)
            segment_indices.append(previous)

    # Close the last run. Only an empty input has no run left to close.
    if len(adapter_indices) > 0:
        segments.append(len(adapter_indices))
        segment_indices.append(adapter_indices[-1])

    return segments, segment_indices
|
|
|
|
|
|
class SegmentConcatBuilder:
    """Accumulate per-batch adapter segments into one concatenated description.

    Each call to :meth:`concat` adds one batch's segment boundaries and
    per-segment adapter indices. :meth:`build` returns the boundaries as a
    single tensor together with the combined list of adapter indices.
    """

    def __init__(self):
        # Adapter index of every segment recorded so far, in order.
        self.adapter_segment_indices: List[int] = []
        # One boundary tensor per recorded batch, already offset so that
        # concatenating them yields the combined boundaries.
        self.adapter_segment_tensors: List[torch.Tensor] = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        """Append one batch's segments after those already recorded.

        Args:
            adapter_segments: Segment boundaries of the batch; the code treats
                the first element as the batch's own start position (expected
                to be 0) and drops it for every batch after the first.
            segment_indices: Adapter index of each segment in the batch.
        """
        if self.adapter_segment_tensors:
            # `len(...) == 0` rather than truthiness so that a tensor-like
            # argument is not forced through an ambiguous bool() conversion.
            if len(segment_indices) == 0:
                # A batch without segments adds nothing after the first batch.
                # Skipping it avoids indexing `segment_indices[0]` below and
                # avoids leaving an empty tensor that would break the
                # `[-1][-1]` offset lookup on the next call.
                return

            # Because we have already processed at least one batch, remove the
            # 0 start index from this batch denoting the beginning of the
            # segment, then offset all segment positions by the value of the
            # last segment in the previous batch to account for the
            # concatenation.
            previous_end = self.adapter_segment_tensors[-1][-1]
            adapter_segments = adapter_segments[1:] + previous_end

        if (
            self.adapter_segment_indices
            and self.adapter_segment_indices[-1] == segment_indices[0]
        ):
            # If the last segment in the previous batch is the same as the
            # first segment in this batch, then we merge them together into a
            # single segment. In effect, this means removing it from the
            # segment indices of this batch, and extending the segment span by
            # removing the segment end index from the previous batch.
            segment_indices = segment_indices[1:]
            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        """Return the concatenated boundary tensor and the adapter index list."""
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
|