mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: improve image processing
This commit is contained in:
parent
dbe1666bc7
commit
c9573ddf28
@ -1269,11 +1269,12 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||
processor_kwargs={"size": {"longest_edge": 1456}},
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
@ -127,11 +127,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
# TODO: implement this in a more general way
|
||||
n_rows = image_input["rows"][0][image_id]
|
||||
n_cols = image_input["cols"][0][image_id]
|
||||
|
||||
# TODO: avoid using hardcoded values
|
||||
image_seq_len = 169 # default value
|
||||
# image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
|
||||
|
||||
image_seq_len = int(
|
||||
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
|
||||
/ (config.scale_factor**2)
|
||||
)
|
||||
image_str = get_image_prompt_string(
|
||||
n_rows,
|
||||
n_cols,
|
||||
@ -301,10 +300,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
if chunk_type == "text":
|
||||
full_text += chunk.text
|
||||
elif chunk_type == "image":
|
||||
replacement_text = image_text_replacement(
|
||||
full_text += image_text_replacement(
|
||||
processor, image_inputs, config, image_id
|
||||
)
|
||||
full_text += replacement_text
|
||||
image_id += 1
|
||||
|
||||
full_text = image_text_replacement_fixup(config, full_text)
|
||||
@ -379,10 +377,9 @@ class VlmCausalLM(FlashCausalLM):
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
# **processor_kwargs,
|
||||
**processor_kwargs,
|
||||
)
|
||||
self.batch_class = batch_class
|
||||
# import ipdb; ipdb.set_trace()
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
|
Loading…
Reference in New Issue
Block a user