Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
feat: load weights within layer and refactor lora pass

commit 0a6ea7fb57 (parent db3d8e6518)
@@ -126,27 +126,51 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
         self.index = index
         self.adapter_weights = {}
-        for adapter_id, adapter_weights in all_adapter_weights.items():
-            filtered_keys = list(
-                filter(
-                    lambda x: x.startswith(
-                        f"base_model.model.model.layers.{index}.self_attn"
-                    ),
-                    adapter_weights.keys(),
-                )
-            )
-            self.adapter_weights[adapter_id] = {
-                key: torch.tensor(
-                    adapter_weights[key],
-                    device=weights.device,
-                    dtype=weights.dtype,
-                ).T
-                for key in filtered_keys
-            }
-
-        self.index_to_key = {
-            i: key for i, key in enumerate(self.adapter_weights.keys())
-        }
+        adapter_names = list(all_adapter_weights.keys())
+
+        self.lora_a_matrix = torch.empty(
+            (len(adapter_names), 2, 4096, 8),
+            device=weights.device,
+            dtype=weights.dtype,
+        )
+        self.lora_b_matrix = torch.empty(
+            (len(adapter_names), 2, 8, 4096),
+            device=weights.device,
+            dtype=weights.dtype,
+        )
+
+        self.pre_multiplied_lora_matrix = torch.empty(
+            (len(adapter_names), 2, 4096, 4096),
+            device=weights.device,
+            dtype=weights.dtype,
+        )
+
+        self.key_to_index = {}
+        self.index_to_key = {}
+
+        lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
+        for adapter_index, adapter_name in enumerate(adapter_names):
+            self.lora_alpha = 16.0
+            self.lora_r = 8.0
+            self.lora_scale = self.lora_alpha / self.lora_r
+            self.key_to_index[adapter_name] = adapter_index
+            self.index_to_key[adapter_index] = adapter_name
+            adapter_weights = all_adapter_weights[adapter_name]
+            for target_index, target in enumerate(["q", "v"]):
+                adapter_weight_a = adapter_weights.get_tensor(
+                    f"{lora_prefix}.{target}_proj.lora_A.weight"
+                )
+                adapter_weight_b = adapter_weights.get_tensor(
+                    f"{lora_prefix}.{target}_proj.lora_B.weight"
+                )
+                pre_multiplied_lora_matrix = torch.matmul(
+                    adapter_weight_a.T * self.lora_scale,
+                    adapter_weight_b.T,
+                ).contiguous()
+
+                self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = (
+                    pre_multiplied_lora_matrix
+                )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
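What this hunk is doing: instead of keeping raw per-key LoRA tensors and multiplying A and B on every forward pass, each adapter's pair is collapsed at load time into a single (hidden x hidden) delta per layer and target ("q"/"v"), so the hot path only needs one matmul. Below is a minimal sketch of that pre-multiplication and why it is equivalent, assuming the hard-coded shapes from the diff (hidden size 4096, rank 8, alpha 16); the tensor orientations are illustrative, not the exact checkpoint layout:

```python
import torch

# Hypothetical shapes matching the hard-coded values in the diff:
# hidden size 4096, LoRA rank 8, alpha 16 -> scale = alpha / r = 2.0
hidden_size, rank = 4096, 8
lora_alpha, lora_r = 16.0, 8.0
scale = lora_alpha / lora_r

# Assumed orientation: A stored as (rank, hidden), B as (hidden, rank),
# so A.T is (hidden, rank) and B.T is (rank, hidden).
lora_a = torch.randn(rank, hidden_size)
lora_b = torch.randn(hidden_size, rank)

# Pre-multiplied delta, computed once at load time.
delta = torch.matmul(lora_a.T * scale, lora_b.T)  # (hidden, hidden)

# At inference time a single matmul applies the adapter to the hidden states,
# and it matches the usual two-step LoRA computation by associativity.
h = torch.randn(3, hidden_size)
out_premultiplied = h @ delta
out_two_step = (h @ lora_a.T * scale) @ lora_b.T
assert torch.allclose(out_premultiplied, out_two_step, rtol=1e-4, atol=1e-4)
```

The trade-off is memory for speed: each pre-multiplied matrix is dense (hidden x hidden) per adapter, layer and target, instead of two small rank-8 factors.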
@@ -159,23 +183,6 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

-    def get_adapter_weights(self, lora_index):
-        adapter_id = self.index_to_key[lora_index]
-        q_proj_lora_a = self.adapter_weights[adapter_id][
-            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
-        ]
-        q_proj_lora_b = self.adapter_weights[adapter_id][
-            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
-        ]
-
-        v_proj_lora_a = self.adapter_weights[adapter_id][
-            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
-        ]
-        v_proj_lora_b = self.adapter_weights[adapter_id][
-            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
-        ]
-        return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b
-
     def forward(
         self,
         hidden_states,
@@ -201,39 +208,42 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
-            self.get_adapter_weights(
-                # TODO: dont just assume the first adapter
-                lora_indices[0].item()
-            )
-        )
-
-        query_adapted = torch.matmul(
-            hidden_states,
-            torch.matmul(
-                q_proj_lora_a,
-                q_proj_lora_b,
-            ),
-        )
-
-        value_adapted = torch.matmul(
-            hidden_states,
-            torch.matmul(
-                v_proj_lora_a,
-                v_proj_lora_b,
-            ),
-        )
-
         batch_size = query.size(0)
-
-        # TODO: improve this to avoid unnecessary work
-        # mask across batch and within lora adapters
-        query[batch_lora_adapter_mask] += query_adapted.view(
-            batch_size, self.num_heads, self.head_size
-        )[batch_lora_adapter_mask]
-        kv[batch_lora_adapter_mask, 1] += value_adapted.view(
-            batch_size, self.num_key_value_heads, self.head_size
-        )[batch_lora_adapter_mask]
+        if not torch.all(lora_indices, -1):
+            lora_mask = lora_indices[lora_indices != -1]
+
+            q_pre_multiplied_batch = torch.ones(
+                (batch_size, 4096, 4096),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+
+            q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
+                lora_mask, 0
+            ]
+
+            v_pre_multiplied_batch = torch.ones(
+                (batch_size, 4096, 4096),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+
+            v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
+                lora_mask, 1
+            ]
+
+            query_adapted = (
+                torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch)
+                .squeeze(1)
+                .view(batch_size, self.num_heads, self.head_size)
+            )
+            value_adapted = (
+                torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch)
+                .squeeze(1)
+                .view(batch_size, self.num_key_value_heads, self.head_size)
+            )
+            query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
+            kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]

         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

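The refactored forward pass selects one pre-multiplied matrix per batch row by adapter index, applies them all with a single torch.bmm, and adds the result only for rows whose lora_indices entry is not -1 (so the old "assume the first adapter" shortcut goes away). A self-contained sketch of that mask-and-bmm pattern with small made-up dimensions, not the literal diff code (which uses the per-layer self.pre_multiplied_lora_matrix buffer and a hard-coded 4096 hidden size):

```python
import torch

hidden_size, num_adapters, batch_size = 16, 2, 4
# Stand-in for self.pre_multiplied_lora_matrix[:, 0] (the "q" target).
pre_multiplied = torch.randn(num_adapters, hidden_size, hidden_size)

hidden_states = torch.randn(batch_size, hidden_size)
# -1 means "no adapter" for that request.
lora_indices = torch.tensor([0, -1, 1, -1])
batch_lora_adapter_mask = lora_indices != -1

# One (hidden, hidden) matrix per batch row; rows without an adapter keep a
# placeholder, since their contribution is discarded by the mask below.
per_row = torch.zeros(batch_size, hidden_size, hidden_size)
per_row[batch_lora_adapter_mask] = pre_multiplied[lora_indices[batch_lora_adapter_mask]]

# Batched matmul: (bs, 1, hidden) x (bs, hidden, hidden) -> (bs, 1, hidden).
adapted = torch.bmm(hidden_states.unsqueeze(1), per_row).squeeze(1)

# Only rows that actually requested an adapter receive the delta.
output = hidden_states.clone()
output[batch_lora_adapter_mask] += adapted[batch_lora_adapter_mask]
```

The bmm runs over every row, adapter or not; as the diff itself hints, skipping rows without an adapter (or gathering only masked rows) would avoid that wasted work.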
@@ -503,6 +513,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )

+    def get_lora_index(self, adapter_id):
+        return self.model.layers[0].self_attn.key_to_index[adapter_id]
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1064,11 +1064,11 @@ class FlashCausalLM(Model):
         cuda_graph = None

         batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)

         for i, r in enumerate(batch.requests):
             if r.adapter_id:
-                lora_index = int(r.adapter_id)
+                lora_index = self.model.get_lora_index(r.adapter_id)
                 lora_indices[i] = lora_index
                 batch_lora_adapter_mask[i] = True

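Together with get_lora_index above, batch preparation now resolves a request's string adapter_id to the dense index built at load time instead of parsing it as an integer, and slots without an adapter stay at -1 so the attention layer can tell them apart. A rough sketch of that lookup flow, with a hypothetical key_to_index mapping standing in for the one the attention layer builds:

```python
import torch

# Hypothetical stand-ins: key_to_index is built per layer at load time in the
# diff; get_lora_index on the causal LM just reads it off the first layer.
key_to_index = {"adapter-a": 0, "adapter-b": 1}

def get_lora_index(adapter_id: str) -> int:
    return key_to_index[adapter_id]

# Simulated batch: the second request carries no adapter_id.
request_adapter_ids = ["adapter-b", "", "adapter-a"]

bs = len(request_adapter_ids)
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool)
lora_indices = torch.full((bs,), -1, dtype=torch.int32)  # -1 == "no adapter"

for i, adapter_id in enumerate(request_adapter_ids):
    if adapter_id:
        lora_indices[i] = get_lora_index(adapter_id)
        batch_lora_adapter_mask[i] = True

print(lora_indices.tolist())             # [1, -1, 0]
print(batch_lora_adapter_mask.tolist())  # [True, False, True]
```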
@@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]


+def _cached_adapter_weight_files(
+    adapter_id: str, revision: Optional[str], extension: str
+) -> List[str]:
+    """Guess weight files from the cached revision snapshot directory"""
+    d = _get_cached_revision_directory(adapter_id, revision)
+    if not d:
+        return []
+    filenames = _adapter_weight_files_from_dir(d, extension)
+    return filenames
+
+
 def _cached_weight_files(
     model_id: str, revision: Optional[str], extension: str
 ) -> List[str]:
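_cached_adapter_weight_files follows the same shape as the existing _cached_weight_files: resolve the adapter's cached snapshot directory, return an empty list if nothing is cached, otherwise scan it for weight files. A minimal sketch of that guard-then-scan flow with a stubbed directory lookup (the real code delegates to _get_cached_revision_directory and _adapter_weight_files_from_dir; the cache layout and adapter name here are hypothetical):

```python
from pathlib import Path
from typing import List, Optional

def cached_adapter_weight_files(
    adapter_id: str,
    revision: Optional[str],
    extension: str,
    cache_root: Path,
) -> List[str]:
    # Stub for _get_cached_revision_directory: look for a local folder named
    # after the adapter under a made-up cache root (revision ignored here).
    d = cache_root / adapter_id.replace("/", "--")
    if not d.is_dir():
        # Nothing cached for this adapter -> no candidate weight files.
        return []
    # Depth-1 scan for files with the requested extension.
    return [str(p) for p in d.iterdir() if p.is_file() and p.name.endswith(extension)]

# Usage: an uncached adapter simply yields an empty list.
print(cached_adapter_weight_files(
    "some-org/some-adapter", None, ".safetensors", Path("/tmp/does-not-exist")
))
# []
```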
@@ -60,6 +71,21 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
     return filenames


+def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
+    # os.walk: do not iterate, just scan for depth 1, not recursively
+    # see _weight_files_from_dir, that's also what is done there
+    root, _, files = next(os.walk(str(d)))
+    filenames = [
+        os.path.join(root, f)
+        for f in files
+        if f.endswith(extension)
+        and "arguments" not in f
+        and "args" not in f
+        and "training" not in f
+    ]
+    return filenames
+
+
 def _get_cached_revision_directory(
     model_id: str, revision: Optional[str]
 ) -> Optional[Path]:
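_adapter_weight_files_from_dir only looks one directory level deep and drops files whose names suggest training artifacts. A small illustration of that filename filter against a hypothetical snapshot directory (file names are made up for the example):

```python
import os
import tempfile
from pathlib import Path
from typing import List

def adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
    # Same filtering rule as the diff: depth-1 scan, keep files with the
    # requested extension, drop training artifacts.
    root, _, files = next(os.walk(str(d)))
    return [
        os.path.join(root, f)
        for f in files
        if f.endswith(extension)
        and "arguments" not in f
        and "args" not in f
        and "training" not in f
    ]

# Hypothetical snapshot contents for an adapter repo.
with tempfile.TemporaryDirectory() as tmp:
    for name in [
        "adapter_model.safetensors",   # kept
        "training_args.safetensors",   # dropped ("training"/"args")
        "adapter_config.json",         # dropped (wrong extension)
    ]:
        (Path(tmp) / name).touch()

    found = adapter_weight_files_from_dir(Path(tmp), ".safetensors")
    print(sorted(os.path.basename(f) for f in found))
    # ['adapter_model.safetensors']
```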