Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

optimizations

commit 6545cdde0d
parent 2f67c53075
@@ -190,11 +190,13 @@ def image_text_replacement_fixup(config, text: str) -> str:
         )
     return text


 def preprocess_text(config, text: str) -> str:
     if config.model_type == "paligemma":
         return "<bos>" + text + "\n"
     return text


 def preprocess_image(config, img):
     model_type = config.model_type

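A minimal usage sketch of preprocess_text as it appears in the hunk above, using SimpleNamespace as a hypothetical stand-in for the model config (in the server the config object comes from the model's transformers configuration):

    from types import SimpleNamespace

    def preprocess_text(config, text: str) -> str:
        if config.model_type == "paligemma":
            return "<bos>" + text + "\n"
        return text

    # PaliGemma prompts get an explicit <bos> prefix and a trailing newline.
    print(preprocess_text(SimpleNamespace(model_type="paligemma"), "describe the image"))
    # Other model types pass through unchanged.
    print(preprocess_text(SimpleNamespace(model_type="llava_next"), "describe the image"))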
@@ -207,6 +209,7 @@ def preprocess_image(config, img):

     return img


 def get_unpadded_features(
     original_height: int,
     original_width: int,
@@ -379,7 +382,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch_image_inputs: List[Optional[List[dict]]] = []
         batch_image_positions: List[Optional[List[ImagePositions]]] = []

-        for i, r in enumerate(requests):
+        for r in requests:
             text_parts = []
             image_inputs = []
             image_texts = []
@@ -457,16 +460,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 img_text = img_text.replace("\n\n", "")

             tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
+            length = len(tokens)

             pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
             index = img_start_token_pos[pos]

-            is_embed = torch.tensor(tokens) == config.image_token_index
-            num_placeholder_tokens = is_embed.sum().item()
+            assert (
+                input_ids[index : index + length] == tokens
+            ), "Image tokens not found in input_ids"

-            length = len(tokens)
+            num_placeholder_tokens = tokens.count(config.image_token_index)
             if num_placeholder_tokens == length:
                 is_embed = None
+            else:
+                is_embed = torch.as_tensor(tokens) == config.image_token_index

             pos = ImagePositions(
                 offset=index,
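The hunk above is the core of the optimization: length is computed once right after tokenization, the placeholder count uses list.count on the raw token ids instead of building a boolean tensor and reducing it, and the is_embed mask is only materialized when an image's text mixes placeholder and non-placeholder tokens (otherwise it stays None). The added assert also verifies that the tokenized image text really occupies input_ids[index : index + length]. A small self-contained sketch of the counting change, with made-up token ids and a hypothetical image_token_index value:

    import torch

    image_token_index = 32000          # hypothetical placeholder token id
    tokens = [32000, 32000, 32000]     # token ids produced for one image's text

    # Old approach: always build a boolean tensor, then reduce it.
    is_embed_old = torch.tensor(tokens) == image_token_index
    num_old = is_embed_old.sum().item()

    # New approach: count on the plain Python list and skip the mask
    # entirely when every token is a placeholder.
    num_new = tokens.count(image_token_index)
    if num_new == len(tokens):
        is_embed = None
    else:
        is_embed = torch.as_tensor(tokens) == image_token_index

    assert num_old == num_new and is_embed is None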