mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-20 06:40:19 +00:00)
parent 343437c7b5
commit db4cb5e4ed
@@ -25,6 +25,7 @@ from torch.nn import functional as F
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional
 
 # Flash attention imports
 import rotary_emb
@@ -554,7 +555,8 @@ class FlashLlamaModel(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)
 
@@ -564,7 +566,9 @@ class FlashLlamaModel(torch.nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
@@ -572,6 +576,7 @@ class FlashLlamaModel(torch.nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -579,6 +584,7 @@ class FlashLlamaModel(torch.nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -588,6 +594,13 @@ class FlashLlamaModel(torch.nn.Module):
 
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -595,7 +608,7 @@ class FlashLlamaModel(torch.nn.Module):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -638,10 +651,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.model(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
        )
         logits = self.lm_head(hidden_states)
 
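Note: taken together, the hunks above (the Llama flash-attention modeling code) add a pre_allocate_past_size argument so that prefill can allocate the key/value cache once at its final size; each decoder layer then slices off the rows that have not been filled yet, and later decode steps write into the same buffer instead of growing and copying it. A minimal, self-contained sketch of that pattern, with made-up sizes that do not come from any real config:

import torch

num_layers, num_heads, head_size = 2, 4, 8
prompt_tokens, max_new_tokens = 5, 3

hidden_states = torch.zeros(prompt_tokens, num_heads * head_size)

# Prefill: size the cache for the whole generation instead of just the prompt.
pre_allocate_past_size = prompt_tokens + max_new_tokens
past_key_values = hidden_states.new_empty(
    (
        num_layers,
        len(hidden_states) if pre_allocate_past_size is None else pre_allocate_past_size,
        2,  # key and value
        num_heads,
        head_size,
    )
)

# During prefill only the first len(hidden_states) rows are valid, so each
# layer gets a sliced view of its cache; during decode the slice is skipped.
slice_past_index = len(hidden_states)
for i in range(num_layers):
    layer_past = (
        past_key_values[i]
        if slice_past_index is None
        else past_key_values[i, :slice_past_index]
    )
    assert layer_past.shape == (prompt_tokens, 2, num_heads, head_size)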
@@ -27,6 +27,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
+from typing import Optional
 
 # Flash attention imports
 import rotary_emb
@@ -618,6 +619,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cu_seqlens,
         max_s,
         past_key_values=None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_in(input_ids)
 
@@ -627,7 +629,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
@@ -635,6 +639,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -642,6 +647,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -651,6 +657,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -658,7 +671,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -714,10 +727,16 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.gpt_neox(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
        )
         logits = self.embed_out(hidden_states)
 
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
@@ -484,7 +485,8 @@ class FlashSantacoderModel(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
         if self.tp_embeddings:
@@ -496,7 +498,9 @@ class FlashSantacoderModel(nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.h),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     1,
                     self.head_size,
@@ -504,6 +508,7 @@ class FlashSantacoderModel(nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -511,15 +516,23 @@ class FlashSantacoderModel(nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         residual = None
         for i, layer in enumerate(self.h):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -554,10 +567,16 @@ class FlashSantacoderForCausalLM(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
    ):
         hidden_states, present = self.transformer(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
        )
         logits = self.lm_head(hidden_states)
 
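Note: the Santacoder variant above follows the same pre-allocation pattern; the literal 1 in the cache shape (where the Llama and NeoX versions use self.num_heads) reflects Santacoder's multi-query attention, which stores a single shared key/value head. A rough size comparison with purely illustrative dimensions (not taken from any real config):

# Back-of-the-envelope KV-cache size: multi-head vs multi-query attention.
num_layers, num_heads, head_size, tokens, bytes_per_elem = 24, 16, 64, 1024, 2

mha_cache = num_layers * tokens * 2 * num_heads * head_size * bytes_per_elem
mqa_cache = num_layers * tokens * 2 * 1 * head_size * bytes_per_elem
print(f"{mha_cache / 2**20:.0f} MiB vs {mqa_cache / 2**20:.0f} MiB")  # 96 MiB vs 6 MiB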
@@ -142,6 +142,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            past_pad=None,
         )
 
     @tracer.start_as_current_span("filter")
@@ -188,8 +189,10 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
             if not single_request:
+                # True index for past
                 past_key_values.append(self.past_key_values[2 * idx])
-                past_key_values.append(self.past_key_values[1])
+                # Add one padding
+                past_key_values.append(self.past_pad)
 
             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -207,7 +210,17 @@ class FlashCausalLMBatch(Batch):
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
                 self.past_key_values[0],
-                (0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
+                (
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    stopping_criterias[0].max_new_tokens
+                    - stopping_criterias[0].current_tokens,
+                ),
             )
 
         return FlashCausalLMBatch(
@@ -270,10 +283,16 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
+
             if len(batch) != 1:
                 past_key_values.extend(batch.past_key_values)
             else:
-                past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
+                # past was pre-allocated for this batch
+                # We need to slice to remove the padding
+                past_key_values.append(
+                    batch.past_key_values[:, : batch.input_lengths[0]]
+                )
+                # Add one padding
                 past_key_values.append(batch.past_pad)
 
             all_input_ids.extend(batch.all_input_ids)
@@ -366,6 +385,7 @@ class FlashCausalLM(Model):
         cu_seqlens: torch.Tensor,
         max_s: int,
         past_key_values: Optional = None,
+        pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -374,6 +394,7 @@ class FlashCausalLM(Model):
             cu_seqlens=cu_seqlens,
             max_s=max_s,
             past_key_values=past_key_values,
+            pre_allocate_past_size=pre_allocate_past_size,
         )
 
     @tracer.start_as_current_span("generate_token")
@@ -382,7 +403,9 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
-            # No need to slice this down
+            input_ids = batch.input_ids[0].view(-1)
+            # Slice to remove extra padding
+            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
@@ -393,6 +416,16 @@ class FlashCausalLM(Model):
                 else None
             )
 
+        # if prefill and bs == 1
+        if past_key_values is None and len(batch) == 1:
+            # Ask to pre-allocate kv to its max size
+            # == number of tokens + max_new_tokens
+            pre_allocate_past_size = (
+                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
+            )
+        else:
+            pre_allocate_past_size = None
+
         # Concatenate when prefill, torch.tensor when decode
         position_ids = (
             torch.tensor(batch.position_ids, device=self.device)
@@ -409,21 +442,28 @@ class FlashCausalLM(Model):
             cu_seqlens,
             batch.max_seqlen,
             past_key_values,
+            pre_allocate_past_size,
         )
 
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
             # Initialize past padding tensor
             if self.past_pad is None:
-                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+                self.past_pad = present.new_zeros(
+                    present.shape[0], 1, *present.shape[2:]
+                )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                    present,
+                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
                 )
             else:
+                # Add padding after each sequence
+                # This will have the correct shape after the final past_key_values concatenation before the model
+                # forward
                 batch.past_key_values = [None, self.past_pad] * len(batch)
 
         # Cumulative length
@@ -555,6 +595,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
             if len(batch) != 1:
+                # Add each sequence before its padding
                 batch.past_key_values[i * 2] = present[:, start_index:end_index]
             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
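Note: on the serving side, the hunks above wire pre_allocate_past_size through FlashCausalLM: for a prefill with batch size 1 the model is asked to allocate input_length + max_new_tokens cache rows, and the returned present is padded (or sliced back down later in filter/concatenate) along the token dimension. The 8-element tuple given to torch.nn.functional.pad is read in (left, right) pairs starting from the last dimension, so the trailing max_new_tokens pads dim 1, the token dimension. A small sketch with toy shapes (the names mirror the diff, the numbers are invented):

import torch
import torch.nn.functional as F

num_layers, num_heads, head_size = 2, 4, 8
input_length, max_new_tokens = 5, 3

# Prefill with bs == 1: ask the model to pre-allocate the full cache.
pre_allocate_past_size = input_length + max_new_tokens

# `present` as returned for the prompt tokens only.
present = torch.zeros(num_layers, input_length, 2, num_heads, head_size)

# The only non-zero pad value is the last one, which applies to dim 1.
padded = F.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
assert padded.shape == (num_layers, input_length + max_new_tokens, 2, num_heads, head_size)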
@@ -29,6 +29,7 @@ tracer = trace.get_tracer(__name__)
 
 class FlashLlama(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -146,6 +147,7 @@ class FlashLlamaSharded(FlashLlama):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():

@@ -33,6 +33,7 @@ class FlashNeoXSharded(FlashNeoX):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():

@@ -28,6 +28,7 @@ tracer = trace.get_tracer(__name__)
 
 class FlashSantacoder(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -172,6 +173,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
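Note: each model class above now sets self.past_pad = None before the rest of __init__; the actual one-token padding tensor is created lazily in generate_token, once the dtype, device, and head shape of present are known. A toy illustration of that lazy initialization, using an invented present tensor:

import torch

past_pad = None  # set in each model's __init__

present = torch.zeros(2, 5, 2, 4, 8)  # (layers, tokens, 2, heads, head_size)
if past_pad is None:
    # One padding "token" per layer, matching dtype and device of `present`.
    past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
assert past_pad.shape == (2, 1, 2, 4, 8)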
|
Loading…
Reference in New Issue
Block a user