Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-26 04:22:08 +00:00
black

parent fbc5a6a120
commit e72897004a
@@ -834,7 +834,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
 
         # TODO: is this correct?
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        input_lengths = (
+            torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        )
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
             .repeat(bs)
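This first hunk is pure black re-wrapping inside what is presumably a warmup/graph-capture path of `FlashCausalLM`: each dummy token gets one KV-cache slot, every request is treated as `max_s` tokens long, and one block table is tiled across the dummy batch. A minimal standalone sketch of that tensor setup, assuming concrete values for `seqlen`, `max_s`, `bs`, and `max_bt` and a plain `torch` device (none of these values come from the repository, and the final reshape is an assumption about how the tiled block table is used):

```python
import torch

# Hypothetical standalone reproduction of the warmup-tensor setup in the hunk.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seqlen = 16   # number of dummy tokens in the warmup batch (assumed value)
max_s = 1024  # maximum sequence length to warm up for (assumed value)
bs = 4        # dummy batch size (assumed value)
max_bt = 8    # maximum number of KV-cache blocks per request (assumed value)

# One KV-cache slot per dummy token.
slots = torch.arange(seqlen, dtype=torch.int64, device=device)

# Every dummy token reports the worst-case sequence length so kernels are
# exercised for the largest shapes they will see.
input_lengths = (
    torch.ones(seqlen, dtype=torch.int32, device=device) * max_s
)

# The same max_bt block indices are tiled across the bs dummy requests.
# The hunk ends at .repeat(bs); reshaping to (bs, max_bt) is an assumption.
block_tables = (
    torch.arange(max_bt, dtype=torch.int32, device=device)
    .repeat(bs)
    .reshape((bs, max_bt))
)

print(slots.shape, input_lengths.shape, block_tables.shape)
```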
@@ -85,6 +85,7 @@ except ImportError as e:
     logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
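This hunk only adds the second blank line black requires before a top-level `def`, but its context shows the optional-dependency guard: if importing the Flash Attention V2 kernels fails, the module logs a warning and falls back to the V1 flag. A rough sketch of that pattern, with Python's standard `logging` in place of the project's logger and a hypothetical module name standing in for the real compiled extension:

```python
import logging

logger = logging.getLogger(__name__)

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

try:
    # Hypothetical module name; the real file imports its compiled V2 kernels here.
    import flash_attn_2_cuda  # noqa: F401

    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    # V2 is unavailable, so mark the V1 code path as the one to use.
    HAS_FLASH_ATTN = True
```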
@@ -93,7 +94,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     total_tokens, num_key_value_heads, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
+    hidden_states = hidden_states[:, :, None, :].expand(
+        total_tokens, num_key_value_heads, n_rep, head_dim
+    )
     return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
 
 
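Pieced together from the two hunks above, `repeat_kv` expands grouped key/value heads so each query head gets its own copy, matching `torch.repeat_interleave(x, dim=1, repeats=n_rep)` without materializing the intermediate repeat. A self-contained sketch reconstructed from the diff (the surrounding file is not shown here, so only the function body is grounded in it), followed by a small usage check:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): the hidden
    states go from (total_tokens, num_key_value_heads, head_dim) to
    (total_tokens, num_key_value_heads * n_rep, head_dim).
    """
    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, broadcast it to n_rep copies, then fold it into
    # the head dimension.
    hidden_states = hidden_states[:, :, None, :].expand(
        total_tokens, num_key_value_heads, n_rep, head_dim
    )
    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)


# Example: 10 tokens, 2 KV heads repeated 4 times -> 8 heads, identical to
# torch.repeat_interleave along dim=1.
kv = torch.randn(10, 2, 64)
out = repeat_kv(kv, n_rep=4)
assert out.shape == (10, 8, 64)
assert torch.equal(out, torch.repeat_interleave(kv, repeats=4, dim=1))
```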
@@ -229,4 +232,6 @@ def attention(
         )
         return output
     else:
-        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
+        raise NotImplementedError(
+            f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
+        )
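The last hunk only re-wraps the error message, but it shows the dispatch pattern at the tail of `attention`: when none of the CUDA/ROCm flash-attention flags are set, the call fails loudly and reports the flag values. A stripped-down sketch of that fallback; the flag names come from the diff, while the trimmed signature and the happy-path stub are purely illustrative:

```python
# Illustrative feature flags; in the real module they are set at import time.
IS_CUDA_SYSTEM = False
IS_ROCM_SYSTEM = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False


def attention(*args, **kwargs):
    # Hypothetical trimmed-down dispatcher: the real function selects a CUDA
    # or ROCm flash-attention kernel; only the failure branch from the diff
    # is reproduced here.
    if HAS_FLASH_ATTN_V2_CUDA or HAS_FLASH_ATTN_V2_ROCM:
        pass  # call the matching flash-attention kernel and return its output
    else:
        raise NotImplementedError(
            f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, "
            f"IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, "
            f"HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, "
            f"HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
        )
```

Including every relevant flag in the message makes the failure self-diagnosing: the traceback alone tells you which platform was detected and which kernels were found missing.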