Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 12:32:10 +00:00)
Preprocessing.
This commit is contained in:
parent 907906466a
commit 39d2073e93
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch Mllama model."""
 
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
@@ -23,11 +23,6 @@ import math
 
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
-from text_generation_server.models.custom_modeling.vlm import (
-    load_text_model,
-)
-from text_generation_server.layers.attention import Seqlen
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -40,7 +35,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     FastLinear,
 )
-from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
 
 
 # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
@@ -120,6 +120,8 @@ class IdeficsCausalLMBatch(Batch):
         )
 
         # TODO Check impact on idefics
+
+        if config.model_type == "idefics":
             prompts = []
             for inp in inputs:
                 # Each input is encoded into a list, where each element of this input list is either a string or a URL
@@ -147,6 +149,38 @@ class IdeficsCausalLMBatch(Batch):
                 # TODO Check impact on idefics
                 # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
             ).to(device)
+        else:
+            images = []
+            texts = []
+            for inp in inputs:
+                # Each input is encoded into a list, where each element of this input list is either a string or a URL
+                curr_images = []
+                curr_text = ""
+                for chunk in inp:
+                    chunk_type = chunk.WhichOneof("chunk")
+                    if chunk_type == "text":
+                        curr_text += chunk.text
+                    elif chunk_type == "image":
+                        image = Image.open(BytesIO(chunk.image.data))
+                        curr_images.append(image)
+                        # TODO unsure about BOS
+                        curr_text += "<|image|><|begin_of_text|>"
+                    else:
+                        raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                images.append(curr_images)
+                texts.append(curr_text)
+
+            # The processor replaces the call to tokenizer, and
+            # a/ takes care of fetching images from the URL
+            # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
+            tokenized_inputs = processor(
+                images=images,
+                text=texts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=max_truncation,
+            ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
             prefix_offsets.append(
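For readers unfamiliar with the new else branch above: each request arrives as a list of protobuf chunks, and chunk.WhichOneof("chunk") returns the name of whichever oneof member ("text" or "image") is populated. Below is a minimal standalone sketch of that preprocessing path; the dict-based chunks, the function name preprocess_batch, and its signature are illustrative stand-ins for TGI's protobuf types, while the "<|image|><|begin_of_text|>" marker and the processor call mirror the diff.

from io import BytesIO

from PIL import Image


def preprocess_batch(inputs, processor, max_truncation, device):
    # inputs: one list of chunks per request; each chunk here is either
    # {"text": str} or {"image": bytes}, standing in for the proto oneof.
    images = []
    texts = []
    for inp in inputs:
        curr_images = []
        curr_text = ""
        for chunk in inp:
            if "text" in chunk:
                curr_text += chunk["text"]
            elif "image" in chunk:
                # Decode the raw image bytes and splice an image marker into
                # the text at the position where the image occurred.
                curr_images.append(Image.open(BytesIO(chunk["image"])))
                curr_text += "<|image|><|begin_of_text|>"
            else:
                raise RuntimeError(f"Invalid chunk {chunk}")
        images.append(curr_images)
        texts.append(curr_text)
    # The multimodal processor replaces a bare tokenizer call: one invocation
    # builds input_ids, attention_mask and pixel_values for the whole batch.
    return processor(
        images=images,
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_truncation,
    ).to(device)

Note that images ends up as a list of lists (one inner list of PIL images per request), paired index-by-index with texts; that nested shape is what the diff hands to the processor, so a single padded batch can mix requests with different image counts.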