mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Using markdown to send image.
This commit is contained in:
parent
5ba141e5d9
commit
24ea07cd6d
@ -28,6 +28,8 @@ from transformers.image_utils import (
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from io import BytesIO
|
||||
import requests
|
||||
from transformers import TensorType, is_torch_available
|
||||
|
||||
|
||||
@ -162,5 +164,27 @@ class IdeficsImageProcessor(BaseImageProcessor):
|
||||
images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
|
||||
|
||||
return images
|
||||
|
||||
def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
|
||||
"""
|
||||
Convert a single or a list of urls into the corresponding `PIL.Image` objects.
|
||||
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
||||
returned.
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
|
||||
" Safari/537.36"
|
||||
)
|
||||
}
|
||||
if isinstance(image_url_or_urls, list):
|
||||
return [self.fetch_images(x) for x in image_url_or_urls]
|
||||
elif isinstance(image_url_or_urls, str):
|
||||
response = requests.get(image_url_or_urls, stream=True, headers=headers)
|
||||
response.raise_for_status()
|
||||
return Image.open(BytesIO(response.content))
|
||||
else:
|
||||
raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
|
||||
|
||||
import transformers
|
||||
transformers.IdeficsImageProcessor = IdeficsImageProcessor
|
||||
|
@ -4,7 +4,7 @@ import re
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from PIL import Image
|
||||
import json
|
||||
import re
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
@ -21,6 +21,26 @@ from text_generation_server.models.types import (
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
||||
import re
|
||||
|
||||
IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
|
||||
|
||||
def split(string):
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append(string[cursor:start])
|
||||
|
||||
parts.append(pattern.group(1))
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append(string[cursor:])
|
||||
|
||||
return parts
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@ -106,16 +126,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
if isinstance(inp, str):
|
||||
prompts.append([inp])
|
||||
elif isinstance(inp, list):
|
||||
if not all(isinstance(item, str) for item in inp):
|
||||
raise ValueError("All elements in the list must be strings (text string or image URL)")
|
||||
prompts.append(
|
||||
json.load(inp)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported type of input")
|
||||
prompts.append(split(inp))
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
|
@ -16,8 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
|
||||
PROFILE = False
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
self.cache = cache
|
||||
@ -28,14 +26,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
# Force inference mode for the lifetime of TextGenerationService
|
||||
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
|
||||
|
||||
if PROFILE:
|
||||
self.prof = torch.profiler.profile(
|
||||
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
|
||||
record_shapes=True,
|
||||
with_stack=True
|
||||
)
|
||||
self.prof.start()
|
||||
|
||||
async def Info(self, request, context):
|
||||
return self.model.info
|
||||
@ -90,8 +80,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
)
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
if PROFILE:
|
||||
self.prof.step()
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.PrefillResponse(
|
||||
@ -119,12 +107,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
batch = batches[0]
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
if PROFILE:
|
||||
self.prof.step()
|
||||
self.cache.set(next_batch)
|
||||
if next_batch is None:
|
||||
if PROFILE:
|
||||
self.prof.stop()
|
||||
|
||||
return generate_pb2.DecodeResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
|
Loading…
Reference in New Issue
Block a user