mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
black
This commit is contained in:
parent
fbc5a6a120
commit
e72897004a
@ -834,7 +834,9 @@ class FlashCausalLM(Model):
|
|||||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
# TODO: is this correct?
|
# TODO: is this correct?
|
||||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
|
input_lengths = (
|
||||||
|
torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
|
||||||
|
)
|
||||||
block_tables = (
|
block_tables = (
|
||||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||||
.repeat(bs)
|
.repeat(bs)
|
||||||
|
@ -85,6 +85,7 @@ except ImportError as e:
|
|||||||
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
||||||
HAS_FLASH_ATTN = True
|
HAS_FLASH_ATTN = True
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
@ -93,7 +94,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||||||
total_tokens, num_key_value_heads, head_dim = hidden_states.shape
|
total_tokens, num_key_value_heads, head_dim = hidden_states.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
|
hidden_states = hidden_states[:, :, None, :].expand(
|
||||||
|
total_tokens, num_key_value_heads, n_rep, head_dim
|
||||||
|
)
|
||||||
return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
|
return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
|
||||||
|
|
||||||
|
|
||||||
@ -229,4 +232,6 @@ def attention(
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
|
raise NotImplementedError(
|
||||||
|
f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user