mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
cleaning and fixes
This commit is contained in:
parent
8cda9ca2f7
commit
5ba141e5d9
@ -1013,7 +1013,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
no_images = False
|
||||
|
||||
|
||||
if image_hidden_states is None:
|
||||
if pixel_values is None and image_embeddings is None:
|
||||
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
|
||||
@ -1040,7 +1040,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
||||
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
||||
else:
|
||||
no_images = True
|
||||
no_images = False
|
||||
num_images = pixel_values.shape[1]
|
||||
image_seq_len = image_hidden_states.shape[1] // num_images
|
||||
|
||||
|
@ -24,35 +24,6 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
# UTILS
|
||||
def base64_to_pil(encoded_image):
|
||||
decoded_image = base64.b64decode(encoded_image)
|
||||
pil_image = Image.open(BytesIO(decoded_image))
|
||||
return pil_image
|
||||
|
||||
def im_markdown_to_pil(im_markdown_str):
|
||||
pattern = r'<img src="data:image/png;base64,([^"]+)" />'
|
||||
match = re.search(pattern, im_markdown_str)
|
||||
img_b64_str = match.group(1)
|
||||
return base64_to_pil(img_b64_str)
|
||||
|
||||
def split_str_on_im_markdown(string_with_potential_im_markdown):
|
||||
"""
|
||||
Extract from a string (typically the user prompt string) the potentional images saved as a base64 representation
|
||||
inside a markdown.
|
||||
"""
|
||||
pattern = r'<img src="data:image/png;base64,([^"]+)" />'
|
||||
parts = re.split(pattern, string_with_potential_im_markdown)
|
||||
result = []
|
||||
for i, part in enumerate(parts):
|
||||
if i % 2 == 0:
|
||||
result.append(part)
|
||||
else:
|
||||
img_tag = f'<img src="data:image/png;base64,{part.strip()}" />'
|
||||
result.append(img_tag)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdeficsCausalLMBatch(Batch):
|
||||
batch_id: int
|
||||
@ -126,8 +97,8 @@ class IdeficsCausalLMBatch(Batch):
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_truncation = max(max_truncation, r.truncate) #TODO: understand that
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens # TODO: I think it is just the maximum of tokens to generate in the WHOLE batch
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
@ -145,16 +116,6 @@ class IdeficsCausalLMBatch(Batch):
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported type of input")
|
||||
# I initially wanted to send the images in string base64 but they are too big to send in a consistent way...
|
||||
# So resorting to uploading the image to a server and pulling them back
|
||||
# splitted_inp = split_str_on_im_markdown(inp)
|
||||
# prompts.append(
|
||||
# [
|
||||
# im_markdown_to_pil(s) if s.startswith('<img src="data:image/png;base64,') else s
|
||||
# for s in splitted_inp
|
||||
# if s != ""
|
||||
# ]
|
||||
# )
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
@ -288,6 +249,10 @@ class IdeficsCausalLMBatch(Batch):
|
||||
+ new_padding_right_offset,
|
||||
:
|
||||
]
|
||||
if self.image_hidden_states is None:
|
||||
image_hidden_states = None
|
||||
else:
|
||||
image_hidden_states = self.image_hidden_states[keep_indices]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if type(self.past_key_values[0]) == tuple:
|
||||
@ -315,6 +280,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = input_ids
|
||||
self.pixel_values = pixel_values
|
||||
self.image_hidden_states = image_hidden_states
|
||||
self.position_ids = position_ids
|
||||
self.all_input_ids = all_input_ids
|
||||
self.input_lengths = input_lengths
|
||||
@ -675,7 +641,7 @@ class IdeficsCausalLM(Model):
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
|
||||
image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1)
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user