Much better code with correct RMS + rotary:

Next:
- precompute cos/sin
- Fix KV layout
- Refactor everything for flash (long).
This commit is contained in:
Nicolas Patry 2023-08-15 11:10:52 +00:00
parent 21c15d576d
commit fc02d99e57
4 changed files with 204 additions and 134 deletions

View File

@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass
from transformers.modeling_utils import PretrainedConfig
from transformers.utils import (
add_start_docstrings,
@ -47,6 +47,15 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
FastLinear,
)
import dropout_layer_norm
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None
# logger = logging.get_logger(__name__)
@ -279,10 +288,30 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# this was adapted from LlamaRMSNorm
# class IdeficsRMSNorm(nn.Module):
# def __init__(self, prefix, weights, eps=1e-6):
# """
# IdeficsRMSNorm is equivalent to T5LayerNorm
# """
# super().__init__()
#
# weight = weights.get_tensor(f"{prefix}.weight")
# self.weight = nn.Parameter(weight)
# self.variance_epsilon = eps
#
# def forward(self, hidden_states):
# variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
#
# # convert into half-precision if necessary
# if self.weight.dtype in [torch.float16, torch.bfloat16]:
# hidden_states = hidden_states.to(self.weight.dtype)
#
# return self.weight * hidden_states
class IdeficsRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
IdeficsRMSNorm is equivalent to T5LayerNorm
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
@ -290,15 +319,56 @@ class IdeficsRMSNorm(nn.Module):
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
return self.weight * hidden_states
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
else:
# faster post attention rms norm
unwrap = False
if len(hidden_states.shape) > 2:
unwrap = True
shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, shape[-1])
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
if unwrap:
normed_hidden_states = normed_hidden_states.view(*shape)
return normed_hidden_states
# this was adapted from LlamaRotaryEmbedding
@ -341,14 +411,14 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
# gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
# cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
# sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
# q_embed = (q * cos) + (rotate_half(q) * sin)
# k_embed = (k * cos) + (rotate_half(k) * sin)
# return q_embed, k_embed
# this was adapted from LlamaMLP
@ -438,11 +508,9 @@ class IdeficsAttention(nn.Module):
self.o_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
)
# self.rotary_emb = PositionRotaryEmbedding.load(
# prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device) #TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash
self.rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=self.head_dim, base=10000.0, device=weights.device
)
self.qk_layer_norms = qk_layer_norms
if self.qk_layer_norms:
self.q_layer_norm = IdeficsRMSNorm(
@ -470,11 +538,29 @@ class IdeficsAttention(nn.Module):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
if not is_cross_attention:
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
max_s = max(kv_seq_len, q_len)
cos, sin = self.rotary_emb.get_cos_sin(
position_ids.view(-1), max_s, hidden_states.dtype
)
shape = query_states.shape
query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape)
shape = key_states.shape
key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
else:
query_states = query_states.transpose(1, 2)
_, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = (
@ -484,14 +570,6 @@ class IdeficsAttention(nn.Module):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if not is_cross_attention:
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# cos, sin = self.rotary_emb.get_cos_sin(
# position_ids=torch.arange(),
# max_s=max(kv_seq_len, q_len),
# dtype=hidden_states.dtype,
# )
# [bsz, nh, t, hd]
if past_key_value is not None:
@ -951,13 +1029,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> Union[Tuple, BaseModelOutputWithPastImage]:
device = input_ids.device if input_ids is not None else inputs_embeds.device
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -999,30 +1078,37 @@ class IdeficsModel(IdeficsPreTrainedModel):
position_ids = position_ids.view(-1, seq_length).long()
no_images = False
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
elif pixel_values is not None and image_embeddings is not None:
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
if image_hidden_states is None:
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
elif pixel_values is not None:
no_images = len(torch.nonzero(pixel_values)) == 0
pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
batch_size, num_images = pixel_values.shape[:2]
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
elif pixel_values is not None and image_embeddings is not None:
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
elif pixel_values is not None:
no_images = len(torch.nonzero(pixel_values)) == 0
pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
batch_size, num_images = pixel_values.shape[:2]
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
elif image_embeddings is not None:
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
elif image_embeddings is not None:
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
if self.config.use_resampler:
image_hidden_states = self.perceiver_resampler(image_hidden_states)
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
else:
no_images = True
num_images = pixel_values.shape[1]
image_seq_len = image_hidden_states.shape[1] // num_images
if self.config.use_resampler:
image_hidden_states = self.perceiver_resampler(image_hidden_states)
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
# # Hack to use the model in full language modeling mode
# image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
# Make image_attention_mask compatible with hidden states
@ -1030,15 +1116,19 @@ class IdeficsModel(IdeficsPreTrainedModel):
image_attention_mask = image_attention_mask.unsqueeze(-1)
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
image_hidden_shape = (image_batch_size, image_sequence_length)
if image_attention_mask is None:
image_attention_mask = torch.ones(image_hidden_shape, device=device)
image_attention_mask = self.invert_attention_mask(image_attention_mask)
if image_hidden_states is not None:
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
image_hidden_shape = (image_batch_size, image_sequence_length)
if image_attention_mask is None:
image_attention_mask = torch.ones(image_hidden_shape, device=device)
image_attention_mask = self.invert_attention_mask(image_attention_mask)
else:
image_attention_mask = None
# if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
# raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
# if image_hidden_states is not None:
# else:
# image_attention_mask = None
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@ -1170,11 +1260,12 @@ class IdeficsModel(IdeficsPreTrainedModel):
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
return BaseModelOutputWithPastImage(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
image_hidden_states=image_hidden_states,
)
@ -1204,13 +1295,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> Union[Tuple, CausalLMOutputWithPastImage]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1244,11 +1336,6 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {input_ids.size()=}")
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {attention_mask.size()=}")
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {position_ids.size()=}")
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {pixel_values.size()=} {pixel_values.sum()=}")
from loguru import logger; logger.info(f"forward in idefics_modeling.py - {image_attention_mask.size()=} {image_attention_mask.sum()=}")
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -1257,6 +1344,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_embeddings=image_embeddings,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
@ -1268,29 +1356,14 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
logits = self.lm_head(hidden_states)
loss = None
# if labels is not None:
# # Shift so that tokens < n predict n
# if attention_mask is not None:
# shift_attention_mask = attention_mask[..., 1:]
# shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
# shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
# else:
# shift_logits = logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# # Flatten the tokens
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# if not return_dict:
# output = (logits,) + outputs[1:]
# return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
return CausalLMOutputWithPastImage(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states
)
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):

View File

@ -82,31 +82,3 @@ class IDEFICSSharded(IdeficsCausalLM):
rank=rank,
world_size=world_size,
)
def forward(
self,
input_ids,
attention_mask,
position_ids,
pixel_values: Optional = None,
image_attention_mask: Optional = None,
past_key_values: Optional = None,
) -> Tuple[
torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return (
outputs.logits,
outputs.past_key_values,
)

View File

@ -64,6 +64,7 @@ class IdeficsCausalLMBatch(Batch):
attention_mask: torch.Tensor
position_ids: torch.Tensor
pixel_values: Optional[torch.Tensor]
image_hidden_states: Optional[torch.Tensor]
image_attention_mask: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
@ -118,7 +119,7 @@ class IdeficsCausalLMBatch(Batch):
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@ -135,7 +136,7 @@ class IdeficsCausalLMBatch(Batch):
prompts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
if isinstance(inp, str):
prompts.append([inp])
elif isinstance(inp, list):
@ -168,7 +169,7 @@ class IdeficsCausalLMBatch(Batch):
max_length=max_truncation,
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
# from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
# {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
for _ in pb.requests:
@ -181,6 +182,7 @@ class IdeficsCausalLMBatch(Batch):
input_ids = tokenized_inputs["input_ids"]
pixel_values = tokenized_inputs["pixel_values"]
image_hidden_states = None
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
@ -192,7 +194,7 @@ class IdeficsCausalLMBatch(Batch):
(pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
)
# image_attention_mask = tokenized_inputs["image_attention_mask"]
from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
@ -209,6 +211,7 @@ class IdeficsCausalLMBatch(Batch):
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=None,
all_input_ids=list(all_input_ids),
@ -224,7 +227,7 @@ class IdeficsCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
# from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
# It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
@ -336,7 +339,7 @@ class IdeficsCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
# It adds new requests to the batch
# Used for padding
total_batch_size = 0
@ -433,8 +436,8 @@ class IdeficsCausalLMBatch(Batch):
:,
batch_left_offset : -batch.padding_right_offset,
]
from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
@ -648,6 +651,7 @@ class IdeficsCausalLM(Model):
attention_mask,
position_ids,
pixel_values,
image_hidden_states,
image_attention_mask,
past_key_values: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@ -656,6 +660,7 @@ class IdeficsCausalLM(Model):
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_hidden_states": image_hidden_states,
"image_attention_mask": image_attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
@ -665,13 +670,13 @@ class IdeficsCausalLM(Model):
kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values
return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: IdeficsCausalLMBatch
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
# from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.input_ids.size(1) == 1:
@ -683,19 +688,21 @@ class IdeficsCausalLM(Model):
image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
else:
image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
# image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
logits, past = self.forward(
logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
attention_mask=attention_mask,
position_ids=batch.position_ids,
pixel_values=batch.pixel_values,
image_hidden_states=batch.image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values,
)
@ -806,7 +813,7 @@ class IdeficsCausalLM(Model):
# Update values
batch.input_ids[i, 0] = next_token_id
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
@ -819,7 +826,7 @@ class IdeficsCausalLM(Model):
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
@ -832,6 +839,7 @@ class IdeficsCausalLM(Model):
# Update past key values
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
return generations, batch

View File

@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
PROFILE = False
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
@ -27,6 +28,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# Force inference mode for the lifetime of TextGenerationService
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
if PROFILE:
self.prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
record_shapes=True,
with_stack=True
)
self.prof.start()
async def Info(self, request, context):
return self.model.info
@ -80,6 +90,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
generations, next_batch = self.model.generate_token(batch)
if PROFILE:
self.prof.step()
self.cache.set(next_batch)
return generate_pb2.PrefillResponse(
@ -107,7 +119,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = batches[0]
generations, next_batch = self.model.generate_token(batch)
if PROFILE:
self.prof.step()
self.cache.set(next_batch)
if next_batch is None:
if PROFILE:
self.prof.stop()
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],