fix: improve image processing

This commit is contained in:
drbh 2024-12-18 01:36:36 +00:00
parent dbe1666bc7
commit c9573ddf28
2 changed files with 8 additions and 10 deletions

View File

@@ -1269,11 +1269,12 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
processor_kwargs={"size": {"longest_edge": 1456}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

View File

@@ -127,11 +127,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
# TODO: avoid using hardcoded values
image_seq_len = 169 # default value
# image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
)
image_str = get_image_prompt_string(
n_rows,
n_cols,
@@ -301,10 +300,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
replacement_text = image_text_replacement(
full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
full_text += replacement_text
image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
@@ -379,10 +377,9 @@ class VlmCausalLM(FlashCausalLM):
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# **processor_kwargs,
**processor_kwargs,
)
self.batch_class = batch_class
# import ipdb; ipdb.set_trace()
super().__init__(
model_id=model_id,
revision=revision,