Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Much better code with correct RMS + rotary.
Next:
- precompute cos/sin
- fix KV layout
- refactor everything for flash (long).

This commit is contained in:
parent 21c15d576d
commit fc02d99e57
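The "precompute cos/sin" item in the TODO list refers to building the rotary cos/sin tables once up front instead of recomputing them inside every attention call. A minimal sketch of what that usually looks like, assuming a head dimension and a maximum number of positions (names and shapes are illustrative, not this repo's API):

```python
import torch

def build_rotary_tables(dim: int, max_positions: int, base: float = 10000.0, dtype=torch.float32):
    # Inverse frequencies for each channel pair: base^(-2i/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # One angle per (position, frequency) pair
    t = torch.arange(max_positions, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # [max_positions, dim // 2]
    return freqs.cos().to(dtype), freqs.sin().to(dtype)

# Precompute once, then only index with the position ids of the current batch
cos_table, sin_table = build_rotary_tables(dim=64, max_positions=2048)
position_ids = torch.tensor([0, 1, 2, 3])
cos, sin = cos_table[position_ids], sin_table[position_ids]  # [seq, dim // 2]
```

At runtime the attention layer then only gathers rows from these tables rather than recomputing the trigonometry per layer and per step.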
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass
 from transformers.modeling_utils import PretrainedConfig
 from transformers.utils import (
     add_start_docstrings,
@@ -47,6 +47,15 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     FastLinear,
 )
+import dropout_layer_norm
+
+@dataclass
+class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
+    image_hidden_states: Optional[torch.FloatTensor] = None
+
+@dataclass
+class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
+    image_hidden_states: Optional[torch.FloatTensor] = None


 # logger = logging.get_logger(__name__)
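The two new `*WithPastImage` outputs simply extend the stock Hugging Face output classes with one extra field so the vision features can travel alongside the language-model outputs. A hedged, self-contained illustration of the same pattern (using the standard `dataclasses` decorator, whereas the hunk above pulls `dataclass` in via `transformers.modeling_outputs`):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
    # Extra field carrying the (possibly resampled) vision features
    image_hidden_states: Optional[torch.FloatTensor] = None

out = BaseModelOutputWithPastImage(
    last_hidden_state=torch.zeros(1, 4, 8),
    image_hidden_states=torch.zeros(1, 64, 8),
)
print(out.image_hidden_states.shape)  # torch.Size([1, 64, 8])
```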
@@ -279,10 +288,30 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]


 # this was adapted from LlamaRMSNorm
+# class IdeficsRMSNorm(nn.Module):
+#     def __init__(self, prefix, weights, eps=1e-6):
+#         """
+#         IdeficsRMSNorm is equivalent to T5LayerNorm
+#         """
+#         super().__init__()
+#
+#         weight = weights.get_tensor(f"{prefix}.weight")
+#         self.weight = nn.Parameter(weight)
+#         self.variance_epsilon = eps
+#
+#     def forward(self, hidden_states):
+#         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+#         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+#
+#         # convert into half-precision if necessary
+#         if self.weight.dtype in [torch.float16, torch.bfloat16]:
+#             hidden_states = hidden_states.to(self.weight.dtype)
+#
+#         return self.weight * hidden_states
 class IdeficsRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
-        IdeficsRMSNorm is equivalent to T5LayerNorm
+        LlamaRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()

@@ -290,15 +319,56 @@ class IdeficsRMSNorm(nn.Module):
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 8192:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(
+                variance + self.variance_epsilon
+            )
+
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+
+            return self.weight * hidden_states
+        else:
+            # faster post attention rms norm
+            unwrap = False
+            if len(hidden_states.shape) > 2:
+                unwrap = True
+                shape = hidden_states.shape
+                hidden_states = hidden_states.reshape(-1, shape[-1])
+
+            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                None,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.variance_epsilon,
+                1.0,
+                0,
+                None,
+                False,
+                True,  # Activate RMSNorm
+            )
+            if res is None:
+                res = hidden_states
+
+            if unwrap:
+                normed_hidden_states = normed_hidden_states.view(*shape)
+
+            return normed_hidden_states


 # this was adapted from LlamaRotaryEmbedding
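In the rewritten `IdeficsRMSNorm.forward`, hidden sizes above 8192 keep the plain PyTorch path, while everything else goes through the fused `dropout_layer_norm.dropout_add_ln_fwd` kernel with dropout probability 0.0 and the RMSNorm flag enabled. A pure-PyTorch reference of the math both paths are expected to compute (residual add followed by RMS normalization); a sketch for checking outputs only, not the fused kernel:

```python
from typing import Optional

import torch

def rms_norm_ref(hidden_states: torch.Tensor,
                 weight: torch.Tensor,
                 residual: Optional[torch.Tensor] = None,
                 eps: float = 1e-6) -> torch.Tensor:
    # The optional residual add happens before normalization, as in the fused path
    if residual is not None:
        hidden_states = hidden_states + residual
    dtype = hidden_states.dtype
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + eps)
    # Cast back to the input dtype and scale by the learned weight
    return weight * hs.to(dtype)

x = torch.randn(2, 5, 128, dtype=torch.float16)
w = torch.ones(128, dtype=torch.float16)
y = rms_norm_ref(x, w)
```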
@@ -341,14 +411,14 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+#     gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+#     gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+#     cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+#     sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+#     q_embed = (q * cos) + (rotate_half(q) * sin)
+#     k_embed = (k * cos) + (rotate_half(k) * sin)
+#     return q_embed, k_embed


 # this was adapted from LlamaMLP
@@ -438,11 +508,9 @@ class IdeficsAttention(nn.Module):
         self.o_proj = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
         )
-        # self.rotary_emb = PositionRotaryEmbedding.load(
-        #     prefix=f"{prefix}.rotary_emb", weights=weights
-        # )
-        self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device)  # TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config, dim=self.head_dim, base=10000.0, device=weights.device
+        )

         self.qk_layer_norms = qk_layer_norms
         if self.qk_layer_norms:
             self.q_layer_norm = IdeficsRMSNorm(
@@ -470,11 +538,29 @@ class IdeficsAttention(nn.Module):

         bsz, q_len, _ = hidden_states.size()

-        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)  # .transpose(1, 2)
         if not is_cross_attention:
-            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)  # .transpose(1, 2)
+            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)  # .transpose(1, 2)
+            kv_seq_len = q_len
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+            max_s = max(kv_seq_len, q_len)
+            cos, sin = self.rotary_emb.get_cos_sin(
+                position_ids.view(-1), max_s, hidden_states.dtype
+            )
+
+            shape = query_states.shape
+            query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape)
+
+            shape = key_states.shape
+            key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape)
+
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
         else:
+            query_states = query_states.transpose(1, 2)
             _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
             key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = (
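With this change the query/key projections are rotated through `PositionRotaryEmbedding`, with cos/sin gathered once per call via `get_cos_sin`, instead of the `apply_rotary_pos_emb` helper that an earlier hunk comments out. The underlying math is the usual half-rotation of each channel pair; a standalone sketch of what applying precomputed tables amounts to (illustrative shapes, not the repo's fused implementation):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [seq, heads, head_dim]; cos/sin: [seq, 1, head_dim], broadcast over heads
    return (x * cos) + (rotate_half(x) * sin)

seq, heads, head_dim = 7, 4, 64
x = torch.randn(seq, heads, head_dim)
inv_freq = 1.0 / 10000.0 ** (torch.arange(0, head_dim, 2) / head_dim)
angles = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)
cos = torch.cat((angles.cos(), angles.cos()), dim=-1).unsqueeze(1)  # duplicate for both halves
sin = torch.cat((angles.sin(), angles.sin()), dim=-1).unsqueeze(1)
q_rot = apply_rotary(x, cos, sin)
```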
@@ -484,14 +570,6 @@ class IdeficsAttention(nn.Module):
             kv_seq_len = key_states.shape[-2]
             if past_key_value is not None:
                 kv_seq_len += past_key_value[0].shape[-2]
-        if not is_cross_attention:
-            cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-            # cos, sin = self.rotary_emb.get_cos_sin(
-            #     position_ids=torch.arange(),
-            #     max_s=max(kv_seq_len, q_len),
-            #     dtype=hidden_states.dtype,
-            # )
         # [bsz, nh, t, hd]

         if past_key_value is not None:
@@ -951,13 +1029,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
+        image_hidden_states: Optional[torch.FloatTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPastImage]:
         device = input_ids.device if input_ids is not None else inputs_embeds.device

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -999,30 +1078,37 @@ class IdeficsModel(IdeficsPreTrainedModel):
             position_ids = position_ids.view(-1, seq_length).long()

         no_images = False
-        if pixel_values is None and image_embeddings is None:
-            raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
+        if image_hidden_states is None:
+            if pixel_values is None and image_embeddings is None:
+                raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
+
             elif pixel_values is not None and image_embeddings is not None:
                 raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")

             elif pixel_values is not None:
                 no_images = len(torch.nonzero(pixel_values)) == 0
                 pixel_values = pixel_values.to(dtype=self.dtype, device=device)  # fp16 compatibility
                 batch_size, num_images = pixel_values.shape[:2]
                 pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])

                 # Get sequence from the vision encoder
                 image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state

             elif image_embeddings is not None:
                 batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
                 image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
                 image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)

-        if self.config.use_resampler:
-            image_hidden_states = self.perceiver_resampler(image_hidden_states)
-        image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
-        image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
+            if self.config.use_resampler:
+                image_hidden_states = self.perceiver_resampler(image_hidden_states)
+            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+            image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
+        else:
+            no_images = True
+            num_images = pixel_values.shape[1]
+            image_seq_len = image_hidden_states.shape[1] // num_images
+
         # # Hack to use the model in full language modeling mode
         # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
         # Make image_attention_mask compatible with hidden states
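The important change in `IdeficsModel.forward` is the new `image_hidden_states is None` guard: the vision encoder and the optional perceiver resampler only run when no precomputed features are passed in, so a caller can encode the images once at prefill and hand the cached features back on every decode step. A schematic of that control flow (a sketch of the idea, not the exact code above):

```python
def encode_images_once(model, pixel_values, cached_image_hidden_states=None):
    """Return vision features, reusing a cache when one is provided (sketch)."""
    if cached_image_hidden_states is not None:
        # Decode steps: skip the vision tower entirely
        return cached_image_hidden_states
    # Prefill: run the (expensive) vision encoder, then the resampler if configured
    image_hidden_states = model.vision_model(pixel_values=pixel_values).last_hidden_state
    if model.config.use_resampler:
        image_hidden_states = model.perceiver_resampler(image_hidden_states)
    return image_hidden_states
```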
@@ -1030,15 +1116,19 @@ class IdeficsModel(IdeficsPreTrainedModel):
             image_attention_mask = image_attention_mask.unsqueeze(-1)
             image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
             image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
-
-        if image_hidden_states is not None:
-            image_batch_size, image_sequence_length, _ = image_hidden_states.size()
-            image_hidden_shape = (image_batch_size, image_sequence_length)
-            if image_attention_mask is None:
-                image_attention_mask = torch.ones(image_hidden_shape, device=device)
-            image_attention_mask = self.invert_attention_mask(image_attention_mask)
-        else:
-            image_attention_mask = None
+        image_batch_size, image_sequence_length, _ = image_hidden_states.size()
+        image_hidden_shape = (image_batch_size, image_sequence_length)
+        if image_attention_mask is None:
+            image_attention_mask = torch.ones(image_hidden_shape, device=device)
+        image_attention_mask = self.invert_attention_mask(image_attention_mask)
+
+        # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
+        #     raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
+
+        # if image_hidden_states is not None:
+        # else:
+        #     image_attention_mask = None

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1170,11 +1260,12 @@ class IdeficsModel(IdeficsPreTrainedModel):
         next_cache = next_decoder_cache if use_cache else None
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+        return BaseModelOutputWithPastImage(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            image_hidden_states=image_hidden_states,
         )

@@ -1204,13 +1295,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
+        image_hidden_states: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, CausalLMOutputWithPastImage]:
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1244,11 +1336,6 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {input_ids.size()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {attention_mask.size()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {position_ids.size()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {pixel_values.size()=} {pixel_values.sum()=}")
-        from loguru import logger; logger.info(f"forward in idefics_modeling.py - {image_attention_mask.size()=} {image_attention_mask.sum()=}")
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1257,6 +1344,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
             inputs_embeds=inputs_embeds,
             pixel_values=pixel_values,
             image_embeddings=image_embeddings,
+            image_hidden_states=image_hidden_states,
             image_attention_mask=image_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -1268,29 +1356,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         logits = self.lm_head(hidden_states)

         loss = None
-        # if labels is not None:
-        #     # Shift so that tokens < n predict n
-        #     if attention_mask is not None:
-        #         shift_attention_mask = attention_mask[..., 1:]
-        #         shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
-        #         shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
-        #     else:
-        #         shift_logits = logits[..., :-1, :].contiguous()
-        #         shift_labels = labels[..., 1:].contiguous()
-        #     # Flatten the tokens
-        #     loss_fct = CrossEntropyLoss()
-        #     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

-        # if not return_dict:
-        #     output = (logits,) + outputs[1:]
-        #     return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
+        return CausalLMOutputWithPastImage(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states
         )

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@@ -82,31 +82,3 @@ class IDEFICSSharded(IdeficsCausalLM):
             rank=rank,
             world_size=world_size,
         )
-
-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        position_ids,
-        pixel_values: Optional = None,
-        image_attention_mask: Optional = None,
-        past_key_values: Optional = None,
-    ) -> Tuple[
-        torch.Tensor,
-        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
-    ]:
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            pixel_values=pixel_values,
-            image_attention_mask=image_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-        return (
-            outputs.logits,
-            outputs.past_key_values,
-        )
@@ -64,6 +64,7 @@ class IdeficsCausalLMBatch(Batch):
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
     pixel_values: Optional[torch.Tensor]
+    image_hidden_states: Optional[torch.Tensor]
    image_attention_mask: Optional[torch.Tensor]
     past_key_values: Optional[List[Tuple]]

@@ -118,7 +119,7 @@ class IdeficsCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
+            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@@ -135,7 +136,7 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
+            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
             if isinstance(inp, str):
                 prompts.append([inp])
             elif isinstance(inp, list):
@@ -168,7 +169,7 @@ class IdeficsCausalLMBatch(Batch):
             max_length=max_truncation,
             add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
-        from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
+        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
         # from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
         # {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
         for _ in pb.requests:
@@ -181,6 +182,7 @@ class IdeficsCausalLMBatch(Batch):

         input_ids = tokenized_inputs["input_ids"]
         pixel_values = tokenized_inputs["pixel_values"]
+        image_hidden_states = None
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
             (pb.size, max_input_length + padding_right_offset)
@@ -192,7 +194,7 @@ class IdeficsCausalLMBatch(Batch):
             (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
         )
         # image_attention_mask = tokenized_inputs["image_attention_mask"]
-        from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
+        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")


         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
@@ -209,6 +211,7 @@ class IdeficsCausalLMBatch(Batch):
             attention_mask=attention_mask,
             position_ids=position_ids,
             pixel_values=pixel_values,
+            image_hidden_states=image_hidden_states,
             image_attention_mask=image_attention_mask,
             past_key_values=None,
             all_input_ids=list(all_input_ids),
@@ -224,7 +227,7 @@ class IdeficsCausalLMBatch(Batch):

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
-        from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
+        # from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -336,7 +339,7 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
-        from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
+        # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
@@ -433,8 +436,8 @@ class IdeficsCausalLMBatch(Batch):
                 :,
                 batch_left_offset : -batch.padding_right_offset,
             ]
-            from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
-            from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
+            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
+            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
             image_attention_mask[
                 start_index:end_index,
                 left_offset:-padding_right_offset,
@@ -648,6 +651,7 @@ class IdeficsCausalLM(Model):
         attention_mask,
         position_ids,
         pixel_values,
+        image_hidden_states,
         image_attention_mask,
         past_key_values: Optional = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -656,6 +660,7 @@ class IdeficsCausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": pixel_values,
+            "image_hidden_states": image_hidden_states,
             "image_attention_mask": image_attention_mask,
             "past_key_values": past_key_values,
             "use_cache": True,
@@ -665,13 +670,13 @@ class IdeficsCausalLM(Model):
             kwargs["position_ids"] = position_ids

         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states

     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: IdeficsCausalLMBatch
     ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
-        from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
+        # from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:
@@ -683,19 +688,21 @@ class IdeficsCausalLM(Model):
             image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1)  # TODO: verify that index. i have a doubt whether there is +1 hanging around
         else:
             image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
+        # image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")

-        logits, past = self.forward(
+        logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
             pixel_values=batch.pixel_values,
+            image_hidden_states=batch.image_hidden_states,
             image_attention_mask=image_attention_mask,
             past_key_values=batch.past_key_values,
         )
@@ -806,7 +813,7 @@ class IdeficsCausalLM(Model):

             # Update values
             batch.input_ids[i, 0] = next_token_id
-            from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
+            # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset
@@ -819,7 +826,7 @@ class IdeficsCausalLM(Model):

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")

         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
@@ -832,6 +839,7 @@ class IdeficsCausalLM(Model):

         # Update past key values
         batch.past_key_values = past
+        batch.image_hidden_states = image_hidden_states

-        from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
+        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
         return generations, batch
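On the serving side the cached vision features now round-trip through the batch: `forward` returns `outputs.image_hidden_states`, `generate_token` stores it on `batch.image_hidden_states`, and the next call feeds it back in, so the image pixels only have to go through the vision tower on the first step. A minimal sketch of that loop outside the real batching machinery (the `model` call signature here is hypothetical):

```python
def generate(model, input_ids, pixel_values, image_attention_mask, max_new_tokens=8):
    past_key_values = None
    image_hidden_states = None  # filled in after the first (prefill) step
    for _ in range(max_new_tokens):
        logits, past_key_values, image_hidden_states = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_hidden_states=image_hidden_states,  # reused from the second step onwards
            image_attention_mask=image_attention_mask,
            past_key_values=past_key_values,
        )
        input_ids = logits[:, -1:].argmax(dim=-1)  # greedy next token, fed back alone
    return input_ids
```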
@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

+PROFILE = False

 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
@@ -27,6 +28,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)

+        if PROFILE:
+            self.prof = torch.profiler.profile(
+                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
+                on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
+                record_shapes=True,
+                with_stack=True
+            )
+            self.prof.start()
+
     async def Info(self, request, context):
         return self.model.info

@@ -80,6 +90,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

         generations, next_batch = self.model.generate_token(batch)
+        if PROFILE:
+            self.prof.step()
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
@@ -107,7 +119,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = batches[0]

         generations, next_batch = self.model.generate_token(batch)
+        if PROFILE:
+            self.prof.step()
         self.cache.set(next_batch)
+        if next_batch is None:
+            if PROFILE:
+                self.prof.stop()

         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
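The profiling hook added to the server is gated behind the module-level `PROFILE` flag and relies on standard `torch.profiler` scheduling: with `wait=1, warmup=1, active=3, repeat=2` each cycle skips one step, warms up for one, records three, and the cycle repeats twice, while `tensorboard_trace_handler` writes traces under `./log/idefics` for inspection in TensorBoard (with the PyTorch profiler plugin installed). A condensed sketch of the same pattern outside the gRPC service, with a dummy workload standing in for `generate_token`:

```python
import torch

prof = torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log/idefics"),
    record_shapes=True,
    with_stack=True,
)
prof.start()
for step in range(12):
    # Stand-in for one generate_token call
    torch.matmul(torch.randn(256, 256), torch.randn(256, 256))
    prof.step()  # advances the wait/warmup/active schedule
prof.stop()
```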