Using markdown to send image.

Nicolas Patry 2023-08-16 14:26:26 +00:00
parent 5ba141e5d9
commit 24ea07cd6d
3 changed files with 46 additions and 28 deletions


@@ -28,6 +28,8 @@ from transformers.image_utils import (
    to_numpy_array,
    valid_images,
)
from io import BytesIO
import requests

from transformers import TensorType, is_torch_available
@@ -162,5 +164,27 @@ class IdeficsImageProcessor(BaseImageProcessor):
        images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]

        return images

    def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
        """
        Convert a single URL or a list of URLs into the corresponding `PIL.Image` object(s).

        If a single URL is passed, the return value is a single object; if a list is passed, a list of objects is
        returned.
        """
        headers = {
            "User-Agent": (
                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
                " Safari/537.36"
            )
        }
        if isinstance(image_url_or_urls, list):
            return [self.fetch_images(x) for x in image_url_or_urls]
        elif isinstance(image_url_or_urls, str):
            response = requests.get(image_url_or_urls, stream=True, headers=headers)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        else:
            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")


# Expose the patched processor under the canonical transformers name so downstream code picks it up.
import transformers

transformers.IdeficsImageProcessor = IdeficsImageProcessor
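
For orientation, a minimal sketch of how the new fetch_images helper behaves; instantiating the processor directly and the example URLs are assumptions for illustration, not part of this commit:

# Sketch only: fetch_images mirrors the input shape, single URL in -> single image out.
processor = IdeficsImageProcessor()
single = processor.fetch_images("https://example.com/cat.png")  # one PIL.Image
many = processor.fetch_images(
    ["https://example.com/cat.png", "https://example.com/dog.png"]
)  # list of PIL.Image objects, one per URL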


@@ -4,7 +4,7 @@ import re
from io import BytesIO
import base64
from PIL import Image
import json
import re
from dataclasses import dataclass
from opentelemetry import trace
@@ -21,6 +21,26 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
import re

# Markdown image syntax: ![alt](url "optional title"); group(1) captures the URL.
IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')


def split(string):
    """Split a prompt into an ordered list of text chunks and image URLs."""
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append(string[cursor:start])

        parts.append(pattern.group(1))
        cursor = pattern.end()

    if cursor != len(string):
        parts.append(string[cursor:])

    return parts
tracer = trace.get_tracer(__name__)
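
To see what the regex-driven split yields, consider an illustrative prompt (not from this commit): markdown image references are replaced by their bare URLs, interleaved with the surrounding text.

# Sketch only: text chunks and image URLs come back interleaved, in order.
split("Describe this image: ![cat](https://example.com/cat.png) Thanks!")
# -> ["Describe this image: ", "https://example.com/cat.png", " Thanks!"]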
@@ -106,16 +126,7 @@ class IdeficsCausalLMBatch(Batch):
        prompts = []
        for inp in inputs:
            # Each input is encoded into a list, where each element of this input list is either a string or a URL
            if isinstance(inp, str):
                prompts.append([inp])
            elif isinstance(inp, list):
                if not all(isinstance(item, str) for item in inp):
                    raise ValueError("All elements in the list must be strings (text string or image URL)")
                prompts.append(
                    json.load(inp)
                )
            else:
                raise ValueError("Unsupported type of input")
            prompts.append(split(inp))

        # The processor replaces the call to tokenizer, and
        # a/ takes care of fetching images from the URL
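
In practice this means a client can reference an image inline with markdown image syntax in the prompt it sends; the request below is an illustrative sketch (endpoint, port, prompt wording, and URL are assumptions, not part of this commit):

# Sketch only: a generate request whose input embeds an image as a markdown reference.
import requests

payload = {
    "inputs": "User: What is in this image? ![](https://example.com/cat.png)\nAssistant:",
    "parameters": {"max_new_tokens": 32},
}
response = requests.post("http://127.0.0.1:8080/generate", json=payload)
print(response.json()["generated_text"])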


@@ -16,8 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
PROFILE = False


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
        self.cache = cache
@@ -28,14 +26,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        # Force inference mode for the lifetime of TextGenerationService
        self._inference_mode_raii_guard = torch._C._InferenceMode(True)

        if PROFILE:
            self.prof = torch.profiler.profile(
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
                on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
                record_shapes=True,
                with_stack=True
            )
            self.prof.start()
    async def Info(self, request, context):
        return self.model.info
@@ -90,8 +80,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        )
        generations, next_batch = self.model.generate_token(batch)
        if PROFILE:
            self.prof.step()
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
@@ -119,12 +107,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            batch = batches[0]

        generations, next_batch = self.model.generate_token(batch)
        if PROFILE:
            self.prof.step()
        self.cache.set(next_batch)

        if next_batch is None:
            if PROFILE:
                self.prof.stop()

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],