add other models

Commit c509e4e79d (parent 3fc87f93bd)
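The hunks below make the same interface change across the flash llama, neox, RW and santacoder models: the import goes back to flash_attn_cuda, the cumulative-sequence-length tensors (cu_seqlens, cu_seqlens_q) are replaced by explicit start_seq/end_seq and start_seq_q/end_seq_q bounds plus a prefill flag, query and key/value are separated with qkv.split([1, 2], dim=1), and the past is a single pre-allocated tensor written through flat past_present_indices. A minimal, CPU-runnable sketch of that cache layout follows; the sizes, shapes and the two-sequence example are illustrative assumptions, not values taken from the repository.

# Sketch only (not part of the diff): how the flat, pre-allocated KV cache
# introduced by this commit is indexed. All sizes below are made up.
import torch

num_heads, head_size = 4, 8
pre_allocate_past_size = 32  # total token slots reserved for one layer

# One layer's past: [token_slot, 2 (key/value), num_heads, head_size]
layer_past = torch.empty(pre_allocate_past_size, 2, num_heads, head_size)

# Prefill of two sequences, lengths 3 and 2, packed into 5 input tokens
input_length = 5
qkv = torch.randn(input_length, 3, num_heads, head_size)
query, kv = qkv.split([1, 2], dim=1)              # as in the new model code
query = query.view(-1, num_heads, head_size)

# Each input token is assigned a slot in the flat cache
past_present_indices = torch.arange(0, input_length, dtype=torch.int64)
prefill = True
if prefill:
    layer_past[past_present_indices] = kv          # copy current K/V into the cache

# start_seq/end_seq bound each sequence's slots, replacing the old cu_seqlens
start_seq = torch.tensor([0, 3], dtype=torch.int32)
end_seq = torch.tensor([3, 5], dtype=torch.int32)
print(layer_past[start_seq[0] : end_seq[0]].shape)  # -> torch.Size([3, 2, 4, 8])

During decode, per the diff, each new token gets a single-width query range (end_seq_q = start_seq_q + 1) and its key/value are written through the same past_present_indices before the attention kernel is called.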
@@ -1,9 +1,9 @@
-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := c5b2a9b7baba2d3059888dbeb03a3cea7aba6e1d
 
 flash-attention:
 	# Clone flash attention
 	pip install packaging
-	git clone https://github.com/HazyResearch/flash-attention.git
+	git clone https://github.com/OlivierDehaene/flash-attention.git
 
 build-flash-attention: flash-attention
 	cd flash-attention && git fetch && git checkout $(flash_att_commit)
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 import dropout_layer_norm
 
 from text_generation_server.utils.layers import (
@@ -128,34 +128,42 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
+        query, kv = qkv.split([1, 2], dim=1)
+        query = query.view(-1, self.num_heads, self.head_size)
+
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+            flash_attn_cuda.fwd(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -168,20 +176,21 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -258,11 +267,14 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -271,11 +283,14 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
 
         # faster post attention rms norm
@@ -322,35 +337,36 @@ class FlashLlamaModel(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)
 
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
            # Create past tensor
             past_key_values = hidden_states.new_empty(
                 (
+                    pre_allocate_past_size,
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -360,23 +376,19 @@ class FlashLlamaModel(torch.nn.Module):
 
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -399,9 +411,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,9 +424,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         hidden_states, present = self.model(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -113,34 +113,42 @@ class FlashNeoxAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
+        query, kv = qkv.split([1, 2], dim=1)
+        query = query.view(-1, self.num_heads, self.head_size)
+
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+            flash_attn_cuda.fwd(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -153,20 +161,21 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -240,11 +249,14 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -253,11 +265,14 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
 
             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -275,11 +290,14 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
 
             hidden_states, residual = self.post_attention_layernorm(
@@ -328,9 +346,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -338,25 +359,23 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
             past_key_values = hidden_states.new_empty(
                 (
+                    pre_allocate_past_size,
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -366,23 +385,19 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )
 
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@@ -403,9 +418,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -413,9 +431,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -42,25 +42,25 @@ class RWConfig(PretrainedConfig):
     }
 
     def __init__(
         self,
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
         n_layer=2,
         n_head=8,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
         bos_token_id=1,
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
         n_head_kv=None,
         multi_query=False,
         alibi=False,
         bias=False,
         parallel_attn=False,
         **kwargs,
     ):
         if alibi:
             raise NotImplementedError(
@@ -126,18 +126,18 @@ class FlashRWAttention(torch.nn.Module):
         )
 
     def forward(
         self,
         hidden_states,
         cos,
         sin,
         start_seq,
         end_seq,
         start_seq_q,
         end_seq_q,
         max_s,
         layer_past,
         past_present_indices,
         prefill,
     ):
         qkv = self.query_key_value(hidden_states)
 
@@ -165,7 +165,7 @@ class FlashRWAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -194,7 +194,7 @@ class FlashRWAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -264,15 +264,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         )
 
     def forward(
         self,
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
-        max_s,
-        layer_past,
-        past_present_indices,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        layer_past,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -287,12 +290,12 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
         # Prefill
-        if past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 kv.unsqueeze(2)
@@ -303,13 +306,15 @@ class FlashRWLargeAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -334,13 +339,15 @@ class FlashRWLargeAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -419,19 +426,19 @@ class FlashRWLayer(nn.Module):
         self.process_group = weights.process_group
 
     def forward(
         self,
         hidden_states,
         residual,
         cos,
         sin,
         start_seq,
         end_seq,
         start_seq_q,
         end_seq_q,
         max_s,
         layer_past,
         past_present_indices,
         prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -509,16 +516,19 @@ class FlashRWLargeLayer(nn.Module):
         self.process_group = weights.process_group
 
     def forward(
         self,
         hidden_states,
         residual,
         cos,
         sin,
-        cu_seqlens,
-        max_s,
-        layer_past,
-        past_present_indices,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        layer_past,
+        past_present_indices,
+        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)
@@ -528,11 +538,14 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
             past_present_indices,
-            cu_seqlens_q,
+            prefill,
         )
 
         # MLP.
@@ -570,7 +583,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 self.h[0].self_attention.head_size,
             )
         elif config.model_type == "RefinedWeb":
-            raise NotImplementedError
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -596,17 +608,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.head_size = self.h[0].self_attention.head_size
 
     def forward(
         self,
         input_ids,
         position_ids,
         start_seq,
         end_seq,
         start_seq_q,
         end_seq_q,
         max_s,
         past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.word_embeddings(input_ids)
 
@@ -617,7 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
             prefill = True
 
             # Create past tensor
-            past_key_values = hidden_states.new_zeros(
+            past_key_values = hidden_states.new_empty(
                 (
                     pre_allocate_past_size,
                     len(self.h),
@@ -646,7 +658,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                past_key_values[:, i],
+                torch.select(past_key_values, dim=1, index=i),
                 past_present_indices,
                 prefill,
             )
@@ -667,18 +679,18 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         )
 
     def forward(
         self,
         input_ids,
         position_ids,
         start_seq,
         end_seq,
         start_seq_q,
         end_seq_q,
         max_s,
         past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -179,7 +179,7 @@ class FlashMQAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -208,7 +208,7 @@ class FlashMQAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -373,13 +373,7 @@ class FlashSantacoderModel(nn.Module):
 
             # Create past tensor
             past_key_values = hidden_states.new_zeros(
-                (
-                    pre_allocate_past_size,
-                    len(self.h),
-                    2,
-                    1,
-                    self.head_size
-                )
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
             )
         # Decode
         else:
@@ -415,17 +409,17 @@ class FlashSantacoderForCausalLM(nn.Module):
 
     def forward(
         self,
         input_ids,
         position_ids,
         start_seq,
         end_seq,
         start_seq_q,
         end_seq_q,
         max_s,
         past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
@@ -82,11 +82,11 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         batch_inputs = []
         max_truncation = 0
@@ -184,7 +184,11 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
             prefill_out_cumulative_length += 1
 
-            request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
+            request_past_present_indices = torch.arange(
+                cumulative_max_length,
+                cumulative_max_length + input_length,
+                dtype=torch.int64,
+            )
             past_present_indices.append(request_past_present_indices)
 
             # Update
@@ -217,8 +221,12 @@ class FlashCausalLMBatch(Batch):
 
             past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
 
-            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
-            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
+            start_seq_prefill = torch.tensor(
+                start_seq_prefill, device=device, dtype=torch.int32
+            )
+            end_seq_prefill = torch.tensor(
+                end_seq_prefill, device=device, dtype=torch.int32
+            )
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
@@ -230,7 +238,9 @@ class FlashCausalLMBatch(Batch):
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
+        past_present_indices = torch.tensor(
+            past_present_indices, device=device, dtype=torch.int64
+        )
 
         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -294,7 +304,9 @@ class FlashCausalLMBatch(Batch):
         indices = []
 
         # past indices to keep
-        past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
+        past_indices = torch.zeros(
+            self.past_key_values.shape[0], dtype=torch.bool, device=device
+        )
 
         # Create on CPU to only move to GPU once instead of at every copy
         start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@@ -332,14 +344,18 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
-            remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
 
             # Copy to tensor (CPU)
             start_seq[i] = cumulative_max_length
             end_seq[i] = cumulative_max_length + request_input_length
 
             # Set slice
-            past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
+            past_indices[
+                self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
+            ] = True
 
             cumulative_max_length += request_input_length + remaining_tokens - 1
 
@@ -480,7 +496,7 @@ class FlashCausalLMBatch(Batch):
             end_index = cumulative_batch_size + len(batch)
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
 
             cumulative_batch_size += len(batch)
@@ -523,12 +539,12 @@ class FlashCausalLMBatch(Batch):
 
 class FlashCausalLM(Model):
     def __init__(
         self,
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -569,18 +585,18 @@ class FlashCausalLM(Model):
         )
 
     def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         start_seq: torch.Tensor,
         end_seq: torch.Tensor,
         start_seq_q: Optional[torch.Tensor],
         end_seq_q: Optional[torch.Tensor],
         max_s: int,
         past_present_indices: torch.Tensor,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -599,7 +615,7 @@ class FlashCausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -647,7 +663,9 @@ class FlashCausalLM(Model):
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
             # Create batch.start_seq_q and batch.end_seq_q for decode
-            batch.start_seq_q = torch.arange(0, len(batch), device=self.device, dtype=torch.int32)
+            batch.start_seq_q = torch.arange(
+                0, len(batch), device=self.device, dtype=torch.int32
+            )
             batch.end_seq_q = batch.start_seq_q + 1
             next_position_ids = batch.position_ids.new_empty(len(batch))
             # We do not need start_seq_prefill and end_seq_prefill anymore
@@ -783,7 +801,7 @@ class FlashCausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:]
+                    all_input_ids[-stopping_criteria.current_tokens :]
                 )
                 generated_text = GeneratedText(
                     output_text,