mirror of https://github.com/huggingface/text-generation-inference.git

fix: improve image processing

commit c9573ddf28
parent dbe1666bc7
@@ -1269,11 +1269,12 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 # XXX: Extremely important to cap resolution in order to limit
                 # VRAM usage.
-                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+                processor_kwargs={"size": {"longest_edge": 1456}},
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
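The `longest_edge` cap above is what bounds the per-image token count, and therefore VRAM. A minimal sketch of the arithmetic, assuming Idefics3-style preprocessing (crop size 364, patch size 14, and pixel-shuffle scale factor 2 are the published Idefics3 defaults, not values read from this diff): the resized image is tiled into fixed-size crops, each crop costs a fixed number of tokens, so the token count grows quadratically with resolution unless the longest edge is capped.

# Sketch: estimate image-token cost under Idefics3-style tiling.
# All constants are assumed defaults (image_size=364, patch_size=14,
# scale_factor=2), not values taken from this repository.
def estimated_image_tokens(
    longest_edge: int,
    image_size: int = 364,
    patch_size: int = 14,
    scale_factor: int = 2,
) -> int:
    # Tokens contributed by one crop after the pixel-shuffle downscale:
    # (364 // 14) ** 2 // 4 = 169.
    tokens_per_crop = (image_size // patch_size) ** 2 // scale_factor**2
    # Worst case: an n x n grid of crops plus one global thumbnail crop.
    n = -(-longest_edge // image_size)  # ceil division
    return (n * n + 1) * tokens_per_crop

print(estimated_image_tokens(1456))  # 4x4 grid + global crop: 17 * 169 = 2873
print(estimated_image_tokens(4096))  # without the cap: 145 * 169 = 24505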
|
@@ -127,11 +127,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         # TODO: implement this in a more general way
         n_rows = image_input["rows"][0][image_id]
         n_cols = image_input["cols"][0][image_id]
-
-        # TODO: avoid using hardcoded values
-        image_seq_len = 169  # default value
-        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
-
+        image_seq_len = int(
+            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
+            / (config.scale_factor**2)
+        )
         image_str = get_image_prompt_string(
             n_rows,
             n_cols,
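As a sanity check, the new config-driven formula reproduces the hardcoded 169 it replaces, assuming the stock Idefics3 config values (image_size=364, patch_size=14, scale_factor=2; the diff itself does not state them, so they are an assumption here):

# Assumed stock Idefics3 config values; other checkpoints may differ.
image_size, patch_size, scale_factor = 364, 14, 2
image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
assert image_seq_len == 169  # (364 // 14) ** 2 / 2 ** 2 = 676 / 4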
@@ -301,10 +300,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    replacement_text = image_text_replacement(
+                    full_text += image_text_replacement(
                         processor, image_inputs, config, image_id
                     )
-                    full_text += replacement_text
                     image_id += 1

             full_text = image_text_replacement_fixup(config, full_text)
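For readers skimming the hunk, here is a self-contained sketch of the loop's behavior (the tuple-based chunk list and the `_stub` helper are stand-ins for the repo's protobuf chunk objects and `image_text_replacement`, not its real API): text chunks pass through verbatim, and each image chunk is appended as a placeholder string while `image_id` advances.

# Stand-in chunk stream; the real code walks protobuf input chunks.
chunks = [("text", "Describe "), ("image", None), ("text", " briefly.")]

def image_text_replacement_stub(image_id: int) -> str:
    # Hypothetical placeholder; the real replacement is model-specific.
    return f"<image_{image_id}>"

full_text = ""
image_id = 0
for chunk_type, payload in chunks:
    if chunk_type == "text":
        full_text += payload
    elif chunk_type == "image":
        full_text += image_text_replacement_stub(image_id)
        image_id += 1

print(full_text)  # Describe <image_0> briefly.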
@@ -379,10 +377,9 @@ class VlmCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            # **processor_kwargs,
+            **processor_kwargs,
         )
         self.batch_class = batch_class
-        # import ipdb; ipdb.set_trace()
         super().__init__(
             model_id=model_id,
             revision=revision,
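With `**processor_kwargs` re-enabled, the `size` dict set in the first hunk reaches the processor at load time. A sketch of the effect, assuming the surrounding call is an `AutoProcessor.from_pretrained` (consistent with the `model_id`/`revision`/`trust_remote_code` context lines, though the call site itself is outside this hunk):

from transformers import AutoProcessor

# Extra kwargs override the checkpoint's preprocessing defaults; the
# checkpoint name here is an illustrative assumption.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/Idefics3-8B-Llama3",
    size={"longest_edge": 1456},  # cap resolution to bound VRAM per image
)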
|