diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
index 727f94c6..8b06ba4c 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
@@ -28,6 +28,8 @@ from transformers.image_utils import (
     to_numpy_array,
     valid_images,
 )
+from io import BytesIO
+import requests
 
 from transformers import TensorType, is_torch_available
 
@@ -162,5 +164,27 @@ class IdeficsImageProcessor(BaseImageProcessor):
         images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
 
         return images
+
+    def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+        """
+        Convert a single or a list of urls into the corresponding `PIL.Image` objects.
+        If a single url is passed, the return value will be a single object. If a list is passed, a list of objects is
+        returned.
+        """
+        headers = {
+            "User-Agent": (
+                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+                " Safari/537.36"
+            )
+        }
+        if isinstance(image_url_or_urls, list):
+            return [self.fetch_images(x) for x in image_url_or_urls]
+        elif isinstance(image_url_or_urls, str):
+            response = requests.get(image_url_or_urls, stream=True, headers=headers)
+            response.raise_for_status()
+            return Image.open(BytesIO(response.content))
+        else:
+            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
 import transformers
 transformers.IdeficsImageProcessor = IdeficsImageProcessor
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 9adc44d1..873a27d4 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -4,7 +4,7 @@ import re
 from io import BytesIO
 import base64
 from PIL import Image
-import json
+import re
 from dataclasses import dataclass
 
 from opentelemetry import trace
@@ -21,6 +21,26 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
+import re
+
+IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
+
+def split(string):
+    parts = []
+    cursor = 0
+    for pattern in IMAGES.finditer(string):
+        start = pattern.start()
+        if start != cursor:
+            parts.append(string[cursor:start])
+
+        parts.append(pattern.group(1))
+        cursor = pattern.end()
+
+    if cursor != len(string):
+        parts.append(string[cursor:])
+
+    return parts
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -106,16 +126,7 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            if isinstance(inp, str):
-                prompts.append([inp])
-            elif isinstance(inp, list):
-                if not all(isinstance(item, str) for item in inp):
-                    raise ValueError("All elements in the list must be strings (text string or image URL)")
-                prompts.append(
-                    json.load(inp)
-                )
-            else:
-                raise ValueError("Unsupported type of input")
+            prompts.append(split(inp))
 
         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 9507d331..67137aaa 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -16,8 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
-PROFILE = False
-
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -28,14 +26,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
-        if PROFILE:
-            self.prof = torch.profiler.profile(
-                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
-                on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
-                record_shapes=True,
-                with_stack=True
-            )
-            self.prof.start()
 
     async def Info(self, request, context):
         return self.model.info
@@ -90,8 +80,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )
 
         generations, next_batch = self.model.generate_token(batch)
-        if PROFILE:
-            self.prof.step()
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -119,12 +107,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
 
         generations, next_batch = self.model.generate_token(batch)
-        if PROFILE:
-            self.prof.step()
         self.cache.set(next_batch)
-        if next_batch is None:
-            if PROFILE:
-                self.prof.stop()
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
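
Note for reviewers: a minimal sketch of how the new `IdeficsImageProcessor.fetch_images` helper is meant to be called. The checkpoint name and URLs below are illustrative placeholders, not taken from this diff:

```python
# Sketch only: the checkpoint name and URLs are placeholders.
from text_generation_server.models.custom_modeling.idefics_image_processing import (
    IdeficsImageProcessor,
)

processor = IdeficsImageProcessor.from_pretrained("HuggingFaceM4/idefics-9b")

# A single URL returns a single PIL.Image ...
image = processor.fetch_images("https://example.com/cat.png")

# ... while a list of URLs returns a list of PIL.Images (handled recursively).
images = processor.fetch_images([
    "https://example.com/a.png",
    "https://example.com/b.png",
])
```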
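The new `split` helper replaces the previous JSON-list input format: a prompt containing markdown image syntax is cut into alternating text/URL chunks, and the URLs are later resolved by the processor. A quick sketch of the expected behavior (the prompt and URL are made up):

```python
# Assumes the patched module is importable; the prompt below is illustrative.
from text_generation_server.models.idefics_causal_lm import split

parts = split("User:![](https://example.com/cat.png)Describe this image.")
# The image URL is extracted from the markdown syntax; surrounding text is kept:
assert parts == ["User:", "https://example.com/cat.png", "Describe this image."]
```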
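The ad-hoc PROFILE toggle is dropped from the gRPC service. If profiling is ever needed again, the removed logic can be reproduced out of band; a sketch reusing the exact settings from the deleted code (`profile_decode`, `model`, and `batch` are hypothetical names):

```python
import torch

def profile_decode(model, batch, steps=10):
    # Same schedule and TensorBoard trace handler as the code removed above.
    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./log/idefics"),
        record_shapes=True,
        with_stack=True,
    )
    prof.start()
    for _ in range(steps):
        # generate_token returns (generations, next_batch), as in the diff.
        _, batch = model.generate_token(batch)
        prof.step()  # advance the wait/warmup/active schedule
        if batch is None:
            break
    prof.stop()
```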