Mohit Sharma 2024-04-23 15:05:39 +00:00
parent fbc5a6a120
commit e72897004a
2 changed files with 10 additions and 3 deletions

View File

@@ -834,7 +834,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         # TODO: is this correct?
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        input_lengths = (
+            torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        )
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
             .repeat(bs)
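These lines appear to build synthetic inputs (note the TODO above). A minimal sketch of the shapes involved, using made-up values for bs, max_s and max_bt, a CPU device in place of self.device, and an assumed seqlen = bs * max_s; none of these numbers come from the commit:

import torch

# Illustrative values only; the real code derives these from the batch and model.
bs, max_s, max_bt = 2, 8, 4
seqlen = bs * max_s  # assumption: one slot per token in the synthetic batch
device = torch.device("cpu")  # the real code uses self.device

slots = torch.arange(seqlen, dtype=torch.int64, device=device)  # shape (16,)
input_lengths = (
    torch.ones(seqlen, dtype=torch.int32, device=device) * max_s  # every entry equals max_s
)
block_tables = (
    torch.arange(max_bt, dtype=torch.int32, device=device)
    .repeat(bs)  # shape (bs * max_bt,); the diff hunk is cut off after this call
)
print(slots.shape, input_lengths.shape, block_tables.shape)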

View File

@@ -85,6 +85,7 @@ except ImportError as e:
     logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -93,7 +94,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     total_tokens, num_key_value_heads, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
+    hidden_states = hidden_states[:, :, None, :].expand(
+        total_tokens, num_key_value_heads, n_rep, head_dim
+    )
     return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
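As the docstring above says, this expand-and-reshape is equivalent to torch.repeat_interleave on the heads dimension. A small self-contained check, with arbitrary tensor sizes chosen only for illustration (the function body is reproduced so the snippet runs on its own):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the function in the diff, copied here so the check is runnable.
    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :].expand(
        total_tokens, num_key_value_heads, n_rep, head_dim
    )
    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)

x = torch.randn(5, 2, 4)  # (total_tokens, num_key_value_heads, head_dim)
out = repeat_kv(x, n_rep=3)
assert out.shape == (5, 6, 4)
assert torch.equal(out, torch.repeat_interleave(x, repeats=3, dim=1))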
@@ -229,4 +232,6 @@ def attention(
         )
         return output
     else:
-        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
+        raise NotImplementedError(
+            f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
+        )
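The flags interpolated into this message are module-level booleans set at import time. A hedged sketch of how a caller might check them before invoking attention; the import path is an assumption, since the diff does not show the file name:

# Hypothetical import path; the commit does not name the module.
from text_generation_server.utils.flash_attn import (
    HAS_FLASH_ATTN,
    HAS_FLASH_ATTN_V2_CUDA,
    HAS_FLASH_ATTN_V2_ROCM,
)

if not (HAS_FLASH_ATTN or HAS_FLASH_ATTN_V2_CUDA or HAS_FLASH_ATTN_V2_ROCM):
    # No backend available: calling attention(...) would raise NotImplementedError.
    raise RuntimeError("No flash attention backend is available on this system")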