Using markdown to send image.

Nicolas Patry 2023-08-16 14:26:26 +00:00
parent 5ba141e5d9
commit 24ea07cd6d
3 changed files with 46 additions and 28 deletions


@@ -28,6 +28,8 @@ from transformers.image_utils import (
     to_numpy_array,
     valid_images,
 )
+from io import BytesIO
+import requests
 from transformers import TensorType, is_torch_available
@@ -162,5 +164,27 @@ class IdeficsImageProcessor(BaseImageProcessor):
         images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
         return images
+
+    def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+        """
+        Convert a single or a list of urls into the corresponding `PIL.Image` objects.
+        If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
+        returned.
+        """
+        headers = {
+            "User-Agent": (
+                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+                " Safari/537.36"
+            )
+        }
+        if isinstance(image_url_or_urls, list):
+            return [self.fetch_images(x) for x in image_url_or_urls]
+        elif isinstance(image_url_or_urls, str):
+            response = requests.get(image_url_or_urls, stream=True, headers=headers)
+            response.raise_for_status()
+            return Image.open(BytesIO(response.content))
+        else:
+            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
 
 import transformers
 transformers.IdeficsImageProcessor = IdeficsImageProcessor
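As a standalone sketch of the same fetching technique used above (separate from the class in this diff; the URL and the timeout are placeholders of my own), a URL is downloaded with a browser-like User-Agent and decoded into a PIL.Image:

from io import BytesIO

import requests
from PIL import Image

# Browser-like User-Agent, since some hosts reject requests without one.
HEADERS = {"User-Agent": "Mozilla/5.0"}


def fetch_image(url: str) -> Image.Image:
    # Mirrors the single-URL branch: download, surface HTTP errors, decode the bytes.
    response = requests.get(url, stream=True, headers=HEADERS, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


# Placeholder URL; any reachable image URL works.
image = fetch_image("https://example.com/cat.png")
print(image.size)

Unreachable hosts raise a requests exception, and HTTP error responses are surfaced by raise_for_status() rather than being decoded as an image.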


@@ -4,7 +4,7 @@ import re
 from io import BytesIO
 import base64
 from PIL import Image
-import json
+import re
 from dataclasses import dataclass
 from opentelemetry import trace
@@ -21,6 +21,26 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+import re
+
+IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
+
+
+def split(string):
+    parts = []
+    cursor = 0
+    for pattern in IMAGES.finditer(string):
+        start = pattern.start()
+        if start != cursor:
+            parts.append(string[cursor:start])
+        parts.append(pattern.group(1))
+        cursor = pattern.end()
+    if cursor != len(string):
+        parts.append(string[cursor:])
+    return parts
+
 
 tracer = trace.get_tracer(__name__)
@@ -106,16 +126,7 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            if isinstance(inp, str):
-                prompts.append([inp])
-            elif isinstance(inp, list):
-                if not all(isinstance(item, str) for item in inp):
-                    raise ValueError("All elements in the list must be strings (text string or image URL)")
-                prompts.append(
-                    json.load(inp)
-                )
-            else:
-                raise ValueError("Unsupported type of input")
+            prompts.append(split(inp))
 
         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
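To illustrate the markdown handling, here is a minimal sketch that reproduces the IMAGES regex and split helper from this diff; the prompt text and URL are placeholders:

import re

IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')


def split(string):
    # Alternate plain-text chunks and image URLs extracted from markdown image syntax.
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append(string[cursor:start])
        parts.append(pattern.group(1))
        cursor = pattern.end()
    if cursor != len(string):
        parts.append(string[cursor:])
    return parts


prompt = "User: ![](https://example.com/cat.png) What is in this picture?"
print(split(prompt))
# ['User: ', 'https://example.com/cat.png', ' What is in this picture?']

Each request input is split this way into interleaved text segments and image URLs, and the resulting list is handed to the processor, which then fetches the images from the URLs as noted in the comments above.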


@@ -16,8 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
-PROFILE = False
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -28,14 +26,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)
-        if PROFILE:
-            self.prof = torch.profiler.profile(
-                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
-                on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
-                record_shapes=True,
-                with_stack=True
-            )
-            self.prof.start()
 
     async def Info(self, request, context):
         return self.model.info
@@ -90,8 +80,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )
 
         generations, next_batch = self.model.generate_token(batch)
-        if PROFILE:
-            self.prof.step()
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -119,12 +107,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
 
         generations, next_batch = self.model.generate_token(batch)
-        if PROFILE:
-            self.prof.step()
         self.cache.set(next_batch)
-        if next_batch is None:
-            if PROFILE:
-                self.prof.stop()
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],