diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index b790896d8..b41712f43 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -209,41 +209,20 @@ class FlashLlamaAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) batch_size = query.size(0) - if not torch.all(lora_indices, -1): - lora_mask = lora_indices[lora_indices != -1] + query_adapted = ( + torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0]) + .squeeze(0) + .view(batch_size, self.num_heads, self.head_size) + ) - q_pre_multiplied_batch = torch.ones( - (batch_size, 4096, 4096), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) + value_adapted = ( + torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1]) + .squeeze(0) + .view(batch_size, self.num_key_value_heads, self.head_size) + ) - q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[ - lora_mask, 0 - ] - - v_pre_multiplied_batch = torch.ones( - (batch_size, 4096, 4096), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[ - lora_mask, 1 - ] - - query_adapted = ( - torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch) - .squeeze(1) - .view(batch_size, self.num_heads, self.head_size) - ) - value_adapted = ( - torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch) - .squeeze(1) - .view(batch_size, self.num_key_value_heads, self.head_size) - ) - query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask] - kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask] + query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask] + kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask] self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)