mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Much better code with correct RMS + rotary:
Next: - precompute cos/sin - Fix KV layout - Refactor everything for flash (long).
This commit is contained in:
parent
21c15d576d
commit
fc02d99e57
@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass
|
||||
from transformers.modeling_utils import PretrainedConfig
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
@ -47,6 +47,15 @@ from text_generation_server.utils.layers import (
|
||||
PositionRotaryEmbedding,
|
||||
FastLinear,
|
||||
)
|
||||
import dropout_layer_norm
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
@dataclass
|
||||
class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
# logger = logging.get_logger(__name__)
|
||||
@ -279,10 +288,30 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
|
||||
# this was adapted from LlamaRMSNorm
|
||||
# class IdeficsRMSNorm(nn.Module):
|
||||
# def __init__(self, prefix, weights, eps=1e-6):
|
||||
# """
|
||||
# IdeficsRMSNorm is equivalent to T5LayerNorm
|
||||
# """
|
||||
# super().__init__()
|
||||
#
|
||||
# weight = weights.get_tensor(f"{prefix}.weight")
|
||||
# self.weight = nn.Parameter(weight)
|
||||
# self.variance_epsilon = eps
|
||||
#
|
||||
# def forward(self, hidden_states):
|
||||
# variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
#
|
||||
# # convert into half-precision if necessary
|
||||
# if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
# hidden_states = hidden_states.to(self.weight.dtype)
|
||||
#
|
||||
# return self.weight * hidden_states
|
||||
class IdeficsRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
IdeficsRMSNorm is equivalent to T5LayerNorm
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -290,15 +319,56 @@ class IdeficsRMSNorm(nn.Module):
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
return self.weight * hidden_states
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
else:
|
||||
# faster post attention rms norm
|
||||
unwrap = False
|
||||
if len(hidden_states.shape) > 2:
|
||||
unwrap = True
|
||||
shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
||||
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
if unwrap:
|
||||
normed_hidden_states = normed_hidden_states.view(*shape)
|
||||
|
||||
|
||||
return normed_hidden_states
|
||||
|
||||
|
||||
# this was adapted from LlamaRotaryEmbedding
|
||||
@ -341,14 +411,14 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||
# gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||
# cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
# sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
# q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
# k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
# return q_embed, k_embed
|
||||
|
||||
|
||||
# this was adapted from LlamaMLP
|
||||
@ -438,11 +508,9 @@ class IdeficsAttention(nn.Module):
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
|
||||
)
|
||||
# self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
# prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device) #TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_dim, base=10000.0, device=weights.device
|
||||
)
|
||||
self.qk_layer_norms = qk_layer_norms
|
||||
if self.qk_layer_norms:
|
||||
self.q_layer_norm = IdeficsRMSNorm(
|
||||
@ -470,11 +538,29 @@ class IdeficsAttention(nn.Module):
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
|
||||
if not is_cross_attention:
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
|
||||
kv_seq_len = q_len
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
max_s = max(kv_seq_len, q_len)
|
||||
cos, sin = self.rotary_emb.get_cos_sin(
|
||||
position_ids.view(-1), max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
shape = query_states.shape
|
||||
query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape)
|
||||
|
||||
shape = key_states.shape
|
||||
key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
_, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
|
||||
key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = (
|
||||
@ -484,14 +570,6 @@ class IdeficsAttention(nn.Module):
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
if not is_cross_attention:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# cos, sin = self.rotary_emb.get_cos_sin(
|
||||
# position_ids=torch.arange(),
|
||||
# max_s=max(kv_seq_len, q_len),
|
||||
# dtype=hidden_states.dtype,
|
||||
# )
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
@ -951,13 +1029,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
image_embeddings: Optional[torch.FloatTensor] = None,
|
||||
image_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> Union[Tuple, BaseModelOutputWithPastImage]:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -999,30 +1078,37 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
no_images = False
|
||||
if pixel_values is None and image_embeddings is None:
|
||||
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
|
||||
|
||||
elif pixel_values is not None and image_embeddings is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
||||
if image_hidden_states is None:
|
||||
if pixel_values is None and image_embeddings is None:
|
||||
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
|
||||
|
||||
elif pixel_values is not None:
|
||||
no_images = len(torch.nonzero(pixel_values)) == 0
|
||||
pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
|
||||
batch_size, num_images = pixel_values.shape[:2]
|
||||
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
elif pixel_values is not None and image_embeddings is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
||||
elif pixel_values is not None:
|
||||
no_images = len(torch.nonzero(pixel_values)) == 0
|
||||
pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
|
||||
batch_size, num_images = pixel_values.shape[:2]
|
||||
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
|
||||
elif image_embeddings is not None:
|
||||
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
||||
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
||||
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
||||
|
||||
elif image_embeddings is not None:
|
||||
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
||||
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
||||
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
||||
|
||||
if self.config.use_resampler:
|
||||
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
||||
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
||||
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
||||
else:
|
||||
no_images = True
|
||||
num_images = pixel_values.shape[1]
|
||||
image_seq_len = image_hidden_states.shape[1] // num_images
|
||||
|
||||
if self.config.use_resampler:
|
||||
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
||||
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
||||
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
||||
# # Hack to use the model in full language modeling mode
|
||||
# image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
|
||||
# Make image_attention_mask compatible with hidden states
|
||||
@ -1030,15 +1116,19 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1)
|
||||
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
|
||||
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
|
||||
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
||||
image_hidden_shape = (image_batch_size, image_sequence_length)
|
||||
if image_attention_mask is None:
|
||||
image_attention_mask = torch.ones(image_hidden_shape, device=device)
|
||||
image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
||||
|
||||
if image_hidden_states is not None:
|
||||
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
||||
image_hidden_shape = (image_batch_size, image_sequence_length)
|
||||
if image_attention_mask is None:
|
||||
image_attention_mask = torch.ones(image_hidden_shape, device=device)
|
||||
image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
||||
else:
|
||||
image_attention_mask = None
|
||||
# if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
|
||||
# raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
|
||||
|
||||
|
||||
# if image_hidden_states is not None:
|
||||
# else:
|
||||
# image_attention_mask = None
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
@ -1170,11 +1260,12 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
return BaseModelOutputWithPastImage(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
image_hidden_states=image_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
@ -1204,13 +1295,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
image_embeddings: Optional[torch.FloatTensor] = None,
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
image_attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> Union[Tuple, CausalLMOutputWithPastImage]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1244,11 +1336,6 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {input_ids.size()=}")
|
||||
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {attention_mask.size()=}")
|
||||
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {position_ids.size()=}")
|
||||
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {pixel_values.size()=} {pixel_values.sum()=}")
|
||||
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {image_attention_mask.size()=} {image_attention_mask.sum()=}")
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
@ -1257,6 +1344,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_embeddings=image_embeddings,
|
||||
image_hidden_states=image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
@ -1268,29 +1356,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
# if labels is not None:
|
||||
# # Shift so that tokens < n predict n
|
||||
# if attention_mask is not None:
|
||||
# shift_attention_mask = attention_mask[..., 1:]
|
||||
# shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
||||
# shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
||||
# else:
|
||||
# shift_logits = logits[..., :-1, :].contiguous()
|
||||
# shift_labels = labels[..., 1:].contiguous()
|
||||
# # Flatten the tokens
|
||||
# loss_fct = CrossEntropyLoss()
|
||||
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
# if not return_dict:
|
||||
# output = (logits,) + outputs[1:]
|
||||
# return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
return CausalLMOutputWithPastImage(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
|
@ -82,31 +82,3 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pixel_values: Optional = None,
|
||||
image_attention_mask: Optional = None,
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
||||
]:
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
return (
|
||||
outputs.logits,
|
||||
outputs.past_key_values,
|
||||
)
|
||||
|
@ -64,6 +64,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
attention_mask: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
pixel_values: Optional[torch.Tensor]
|
||||
image_hidden_states: Optional[torch.Tensor]
|
||||
image_attention_mask: Optional[torch.Tensor]
|
||||
past_key_values: Optional[List[Tuple]]
|
||||
|
||||
@ -118,7 +119,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
|
||||
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
|
||||
requests_idx_mapping[r.id] = i
|
||||
inputs.append(r.inputs)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
@ -135,7 +136,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
|
||||
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
|
||||
if isinstance(inp, str):
|
||||
prompts.append([inp])
|
||||
elif isinstance(inp, list):
|
||||
@ -168,7 +169,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
max_length=max_truncation,
|
||||
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
).to(device)
|
||||
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
|
||||
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
|
||||
# from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
|
||||
# {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
|
||||
for _ in pb.requests:
|
||||
@ -181,6 +182,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
pixel_values = tokenized_inputs["pixel_values"]
|
||||
image_hidden_states = None
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
(pb.size, max_input_length + padding_right_offset)
|
||||
@ -192,7 +194,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
(pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
|
||||
)
|
||||
# image_attention_mask = tokenized_inputs["image_attention_mask"]
|
||||
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
|
||||
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
|
||||
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
@ -209,6 +211,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hidden_states=image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=None,
|
||||
all_input_ids=list(all_input_ids),
|
||||
@ -224,7 +227,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
|
||||
from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
|
||||
# from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
|
||||
# It deletes requests from the batch. For instance when client lost connection
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
@ -336,7 +339,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
|
||||
from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
|
||||
# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
|
||||
# It adds new requests to the batch
|
||||
# Used for padding
|
||||
total_batch_size = 0
|
||||
@ -433,8 +436,8 @@ class IdeficsCausalLMBatch(Batch):
|
||||
:,
|
||||
batch_left_offset : -batch.padding_right_offset,
|
||||
]
|
||||
from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
|
||||
from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
|
||||
# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
|
||||
# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
|
||||
image_attention_mask[
|
||||
start_index:end_index,
|
||||
left_offset:-padding_right_offset,
|
||||
@ -648,6 +651,7 @@ class IdeficsCausalLM(Model):
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pixel_values,
|
||||
image_hidden_states,
|
||||
image_attention_mask,
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
@ -656,6 +660,7 @@ class IdeficsCausalLM(Model):
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"image_hidden_states": image_hidden_states,
|
||||
"image_attention_mask": image_attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": True,
|
||||
@ -665,13 +670,13 @@ class IdeficsCausalLM(Model):
|
||||
kwargs["position_ids"] = position_ids
|
||||
|
||||
outputs = self.model.forward(**kwargs)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
self, batch: IdeficsCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
|
||||
from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
|
||||
# from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
@ -683,19 +688,21 @@ class IdeficsCausalLM(Model):
|
||||
image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
|
||||
# image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
|
||||
|
||||
logits, past = self.forward(
|
||||
logits, past, image_hidden_states = self.forward(
|
||||
input_ids=batch.input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=batch.position_ids,
|
||||
pixel_values=batch.pixel_values,
|
||||
image_hidden_states=batch.image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=batch.past_key_values,
|
||||
)
|
||||
@ -806,7 +813,7 @@ class IdeficsCausalLM(Model):
|
||||
|
||||
# Update values
|
||||
batch.input_ids[i, 0] = next_token_id
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
@ -819,7 +826,7 @@ class IdeficsCausalLM(Model):
|
||||
|
||||
# Slice unused values from prefill
|
||||
batch.input_ids = batch.input_ids[:, :1]
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
|
||||
|
||||
# Update attention_mask as we added a new token to input_ids
|
||||
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||
@ -832,6 +839,7 @@ class IdeficsCausalLM(Model):
|
||||
|
||||
# Update past key values
|
||||
batch.past_key_values = past
|
||||
batch.image_hidden_states = image_hidden_states
|
||||
|
||||
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
|
||||
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
|
||||
return generations, batch
|
||||
|
@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
|
||||
PROFILE = False
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
@ -27,6 +28,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
# Force inference mode for the lifetime of TextGenerationService
|
||||
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
|
||||
|
||||
if PROFILE:
|
||||
self.prof = torch.profiler.profile(
|
||||
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
|
||||
record_shapes=True,
|
||||
with_stack=True
|
||||
)
|
||||
self.prof.start()
|
||||
|
||||
async def Info(self, request, context):
|
||||
return self.model.info
|
||||
|
||||
@ -80,6 +90,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
)
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
if PROFILE:
|
||||
self.prof.step()
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.PrefillResponse(
|
||||
@ -107,7 +119,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
batch = batches[0]
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
if PROFILE:
|
||||
self.prof.step()
|
||||
self.cache.set(next_batch)
|
||||
if next_batch is None:
|
||||
if PROFILE:
|
||||
self.prof.stop()
|
||||
|
||||
return generate_pb2.DecodeResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
|
Loading…
Reference in New Issue
Block a user