Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

Commit 5ba141e5d9: "cleaning and fixes"
Parent commit: 8cda9ca2f7
@@ -1013,7 +1013,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
             position_ids = position_ids.view(-1, seq_length).long()
 
         no_images = False
 
         if image_hidden_states is None:
             if pixel_values is None and image_embeddings is None:
                 raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
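For context, the guard in this hunk enforces that at least one of the two image inputs is supplied before any image processing happens. A minimal standalone sketch of that validation pattern (toy function and dummy shape, not the TGI code):

import torch

def resolve_image_inputs(pixel_values=None, image_embeddings=None):
    # Mirrors the guard in the hunk: at least one image input must be given.
    if pixel_values is None and image_embeddings is None:
        raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
    return pixel_values if pixel_values is not None else image_embeddings

# A (batch, num_images, channels, height, width) dummy tensor passes the check.
resolve_image_inputs(pixel_values=torch.zeros(1, 1, 3, 224, 224))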
@@ -1040,7 +1040,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
             image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
             image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
         else:
-            no_images = True
+            no_images = False
             num_images = pixel_values.shape[1]
             image_seq_len = image_hidden_states.shape[1] // num_images
 
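The view() in this hunk flattens the per-image token axis so that all image tokens line up along a single sequence dimension per batch element. A toy shape check (sizes invented for illustration):

import torch

batch_size, num_images, image_seq_len, image_hidden_size = 2, 3, 4, 8
image_hidden_states = torch.randn(batch_size * num_images, image_seq_len, image_hidden_size)

# (batch * num_images, image_seq_len, hidden) -> (batch, num_images * image_seq_len, hidden)
flat = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
assert flat.shape == (2, 12, 8)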
@@ -24,35 +24,6 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam
 tracer = trace.get_tracer(__name__)
 
 
-# UTILS
-def base64_to_pil(encoded_image):
-    decoded_image = base64.b64decode(encoded_image)
-    pil_image = Image.open(BytesIO(decoded_image))
-    return pil_image
-
-def im_markdown_to_pil(im_markdown_str):
-    pattern = r'<img src="data:image/png;base64,([^"]+)" />'
-    match = re.search(pattern, im_markdown_str)
-    img_b64_str = match.group(1)
-    return base64_to_pil(img_b64_str)
-
-def split_str_on_im_markdown(string_with_potential_im_markdown):
-    """
-    Extract from a string (typically the user prompt string) the potentional images saved as a base64 representation
-    inside a markdown.
-    """
-    pattern = r'<img src="data:image/png;base64,([^"]+)" />'
-    parts = re.split(pattern, string_with_potential_im_markdown)
-    result = []
-    for i, part in enumerate(parts):
-        if i % 2 == 0:
-            result.append(part)
-        else:
-            img_tag = f'<img src="data:image/png;base64,{part.strip()}" />'
-            result.append(img_tag)
-    return result
-
-
 @dataclass
 class IdeficsCausalLMBatch(Batch):
     batch_id: int
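The deleted helpers decoded inline <img src="data:image/png;base64,..."> markdown back into PIL images. A toy round-trip showing what base64_to_pil did (assumes Pillow is installed; the 1x1 red image is made up):

import base64
from io import BytesIO
from PIL import Image

# Encode a tiny image to base64, then decode it the way base64_to_pil did.
buf = BytesIO()
Image.new("RGB", (1, 1), "red").save(buf, format="PNG")
encoded_image = base64.b64encode(buf.getvalue()).decode()

decoded_image = base64.b64decode(encoded_image)
pil_image = Image.open(BytesIO(decoded_image))
assert pil_image.size == (1, 1)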
@@ -126,8 +97,8 @@ class IdeficsCausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_truncation = max(max_truncation, r.truncate) #TODO: understand that
-            max_decode_tokens += stopping_criteria.max_new_tokens # TODO: I think it is just the maximum of tokens to generate in the WHOLE batch
+            max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
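The two de-TODO'd lines aggregate per-request limits into batch-wide values: max_decode_tokens accumulates the decode budget across the whole batch, while padding_right_offset keeps the single largest max_new_tokens. A sketch with made-up numbers:

# Per-request max_new_tokens for a toy batch of three requests.
max_new_tokens_per_request = [16, 64, 32]

max_decode_tokens = 0
padding_right_offset = 0
for max_new_tokens in max_new_tokens_per_request:
    max_decode_tokens += max_new_tokens  # total decode budget: 112
    padding_right_offset = max(padding_right_offset, max_new_tokens)  # widest request: 64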
@@ -145,16 +116,6 @@ class IdeficsCausalLMBatch(Batch):
                 )
             else:
                 raise ValueError("Unsupported type of input")
-            # I initially wanted to send the images in string base64 but they are too big to send in a consistent way...
-            # So resorting to uploading the image to a server and pulling them back
-            # splitted_inp = split_str_on_im_markdown(inp)
-            # prompts.append(
-            #     [
-            #        im_markdown_to_pil(s) if s.startswith('<img src="data:image/png;base64,') else s
-            #        for s in splitted_inp
-            #        if s != ""
-            #     ]
-            # )
 
         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
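As the surviving comment says, the processor subsumes the tokenizer call and fetches images itself, which is why the commented-out base64 shuttling above could be dropped entirely. A hedged usage sketch (model id, prompt, and exact call signature are illustrative; processor kwargs vary across transformers versions):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics-9b")

# Prompts interleave text with image URLs; the processor downloads the image
# and tokenizes the text in a single call.
prompts = [["User:", "https://example.com/cat.png", "Describe the image."]]
inputs = processor(prompts, return_tensors="pt")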
@@ -288,6 +249,10 @@ class IdeficsCausalLMBatch(Batch):
             + new_padding_right_offset,
             :
         ]
+        if self.image_hidden_states is None:
+            image_hidden_states = None
+        else:
+            image_hidden_states = self.image_hidden_states[keep_indices]
 
         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
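The added branch mirrors how the other batch tensors are filtered when requests finish: indexing with keep_indices selects the surviving batch rows, and the None case covers the state before image hidden states have been computed. A toy illustration (shapes invented):

import torch

image_hidden_states = torch.randn(4, 64, 32)  # (batch, image_seq_len, hidden)
keep_indices = [0, 2]                         # requests still running

filtered = image_hidden_states[keep_indices]  # keeps rows 0 and 2
assert filtered.shape == (2, 64, 32)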
@@ -315,6 +280,7 @@ class IdeficsCausalLMBatch(Batch):
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids = input_ids
         self.pixel_values = pixel_values
+        self.image_hidden_states = image_hidden_states
         self.position_ids = position_ids
         self.all_input_ids = all_input_ids
         self.input_lengths = input_lengths
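Because filtering mutates the batch in place, every per-request field has to be written back, and the newly tracked image_hidden_states now joins that list. A stripped-down sketch of the pattern (toy dataclass, not the real batch):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ToyBatch:
    input_ids: torch.Tensor
    pixel_values: Optional[torch.Tensor]
    image_hidden_states: Optional[torch.Tensor]

    def filter(self, keep_indices):
        self.input_ids = self.input_ids[keep_indices]
        if self.pixel_values is not None:
            self.pixel_values = self.pixel_values[keep_indices]
        if self.image_hidden_states is not None:
            # Forgetting this write-back would desynchronize cached image
            # states from the surviving requests.
            self.image_hidden_states = self.image_hidden_states[keep_indices]
        return self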
@@ -675,7 +641,7 @@ class IdeficsCausalLM(Model):
             # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
             # token need to attend to the encoder hidden states (i.e. the vision encoder)
             # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
-            image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
+            image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1)
         else:
             image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
 
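The fix here resolves the old TODO: with padding_right_offset columns of right padding reserved, the last column that belongs to the already-generated sequence sits at index -(padding_right_offset + 1), not -padding_right_offset. A worked mini-example (mask values invented):

import torch

# One row: 3 real positions followed by 2 reserved right-padding slots.
image_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
padding_right_offset = 2

# -padding_right_offset (= -2) lands on the first padding slot (0),
# while -(padding_right_offset + 1) (= -3) is the last real position (1).
assert image_attention_mask[0, -padding_right_offset].item() == 0
assert image_attention_mask[0, -(padding_right_offset + 1)].item() == 1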