Preprocessing.

Nicolas Patry 2024-09-18 17:59:13 +02:00
parent 907906466a
commit 39d2073e93
2 changed files with 61 additions and 33 deletions


@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch Mllama model."""
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
@@ -23,11 +23,6 @@ import math
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
-from text_generation_server.models.custom_modeling.vlm import (
-    load_text_model,
-)
-from text_generation_server.layers.attention import Seqlen
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -40,7 +35,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     FastLinear,
 )
-from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
 # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision


@@ -120,33 +120,67 @@ class IdeficsCausalLMBatch(Batch):
             )
         # TODO Check impact on idefics
-        prompts = []
-        for inp in inputs:
-            # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            prompt = []
-            for chunk in inp:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    prompt.append(chunk.text)
-                elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    prompt.append(image)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
-            prompts.append(prompt)
-        # The processor replaces the call to tokenizer, and
-        # a/ takes care of fetching images from the URL
-        # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
-        tokenized_inputs = processor(
-            prompts,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=max_truncation,
-            # TODO Check impact on idefics
-            # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
-        ).to(device)
+        if config.model_type == "idefics":
+            prompts = []
+            for inp in inputs:
+                # Each input is encoded into a list, where each element of this input list is either a string or a URL
+                prompt = []
+                for chunk in inp:
+                    chunk_type = chunk.WhichOneof("chunk")
+                    if chunk_type == "text":
+                        prompt.append(chunk.text)
+                    elif chunk_type == "image":
+                        image = Image.open(BytesIO(chunk.image.data))
+                        prompt.append(image)
+                    else:
+                        raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                prompts.append(prompt)
+            # The processor replaces the call to tokenizer, and
+            # a/ takes care of fetching images from the URL
+            # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
+            tokenized_inputs = processor(
+                prompts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=max_truncation,
+                # TODO Check impact on idefics
+                # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
+            ).to(device)
+        else:
+            images = []
+            texts = []
+            for inp in inputs:
+                # Each input is encoded into a list, where each element of this input list is either a string or a URL
+                curr_images = []
+                curr_text = ""
+                for chunk in inp:
+                    chunk_type = chunk.WhichOneof("chunk")
+                    if chunk_type == "text":
+                        curr_text += chunk.text
+                    elif chunk_type == "image":
+                        image = Image.open(BytesIO(chunk.image.data))
+                        curr_images.append(image)
+                        # TODO unsure about BOS
+                        curr_text += "<|image|><|begin_of_text|>"
+                    else:
+                        raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                images.append(curr_images)
+                texts.append(curr_text)
+            # The processor replaces the call to tokenizer, and
+            # a/ takes care of fetching images from the URL
+            # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
+            tokenized_inputs = processor(
+                images=images,
+                text=texts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=max_truncation,
+            ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
             prefix_offsets.append(
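
To see what this hunk does outside the gRPC server, here is a minimal standalone sketch of the two preprocessing branches. It assumes plain dicts ({"text": ...} / {"image": <raw bytes>}) in place of the protobuf oneof chunks, and a HuggingFace processor, config, max_truncation, and device supplied by the caller; the helper name preprocess is illustrative, not part of the commit.

# Sketch of the branching introduced above, assuming dict chunks instead of
# the gRPC oneof messages. Hypothetical helper, not actual TGI code.
from io import BytesIO

from PIL import Image


def preprocess(inputs, processor, config, max_truncation, device):
    if config.model_type == "idefics":
        # Idefics: the processor accepts interleaved lists of strings and
        # PIL images, one list per request.
        prompts = []
        for inp in inputs:
            prompt = []
            for chunk in inp:
                if "text" in chunk:
                    prompt.append(chunk["text"])
                elif "image" in chunk:
                    prompt.append(Image.open(BytesIO(chunk["image"])))
                else:
                    raise RuntimeError(f"Invalid chunk {chunk}")
            prompts.append(prompt)
        return processor(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
    # Every other model type: the processor takes the images and the text
    # separately, with a placeholder marking where each image sits.
    images = []
    texts = []
    for inp in inputs:
        curr_images = []
        curr_text = ""
        for chunk in inp:
            if "text" in chunk:
                curr_text += chunk["text"]
            elif "image" in chunk:
                curr_images.append(Image.open(BytesIO(chunk["image"])))
                # Same placeholder string the diff inserts; whether BOS
                # belongs here is an open TODO in the commit itself.
                curr_text += "<|image|><|begin_of_text|>"
            else:
                raise RuntimeError(f"Invalid chunk {chunk}")
        images.append(curr_images)
        texts.append(curr_text)
    return processor(
        images=images,
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_truncation,
    ).to(device)

Either way, both branches end in a single processor(...) call that returns input_ids, attention_mask, and pixel values already moved to the target device, so the rest of the batch construction below stays the same for both model families.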