cleaning and fixes

This commit is contained in:
VictorSanh 2023-08-16 03:43:43 +00:00
parent 8cda9ca2f7
commit 5ba141e5d9
2 changed files with 10 additions and 44 deletions

View File

@ -1040,7 +1040,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
else:
no_images = True
no_images = False
num_images = pixel_values.shape[1]
image_seq_len = image_hidden_states.shape[1] // num_images

View File

@ -24,35 +24,6 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam
tracer = trace.get_tracer(__name__)
# UTILS
def base64_to_pil(encoded_image):
decoded_image = base64.b64decode(encoded_image)
pil_image = Image.open(BytesIO(decoded_image))
return pil_image
def im_markdown_to_pil(im_markdown_str):
pattern = r'<img src="data:image/png;base64,([^"]+)" />'
match = re.search(pattern, im_markdown_str)
img_b64_str = match.group(1)
return base64_to_pil(img_b64_str)
def split_str_on_im_markdown(string_with_potential_im_markdown):
"""
Extract from a string (typically the user prompt string) the potentional images saved as a base64 representation
inside a markdown.
"""
pattern = r'<img src="data:image/png;base64,([^"]+)" />'
parts = re.split(pattern, string_with_potential_im_markdown)
result = []
for i, part in enumerate(parts):
if i % 2 == 0:
result.append(part)
else:
img_tag = f'<img src="data:image/png;base64,{part.strip()}" />'
result.append(img_tag)
return result
@dataclass
class IdeficsCausalLMBatch(Batch):
batch_id: int
@ -126,8 +97,8 @@ class IdeficsCausalLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate) #TODO: understand that
max_decode_tokens += stopping_criteria.max_new_tokens # TODO: I think it is just the maximum of tokens to generate in the WHOLE batch
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@ -145,16 +116,6 @@ class IdeficsCausalLMBatch(Batch):
)
else:
raise ValueError("Unsupported type of input")
# I initially wanted to send the images in string base64 but they are too big to send in a consistent way...
# So resorting to uploading the image to a server and pulling them back
# splitted_inp = split_str_on_im_markdown(inp)
# prompts.append(
# [
# im_markdown_to_pil(s) if s.startswith('<img src="data:image/png;base64,') else s
# for s in splitted_inp
# if s != ""
# ]
# )
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
@ -288,6 +249,10 @@ class IdeficsCausalLMBatch(Batch):
+ new_padding_right_offset,
:
]
if self.image_hidden_states is None:
image_hidden_states = None
else:
image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
@ -315,6 +280,7 @@ class IdeficsCausalLMBatch(Batch):
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.pixel_values = pixel_values
self.image_hidden_states = image_hidden_states
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
@ -675,7 +641,7 @@ class IdeficsCausalLM(Model):
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
# token need to attend to the encoder hidden states (i.e. the vision encoder)
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1)
else:
image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]