mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

Using markdown to send image.

parent 5ba141e5d9
commit 24ea07cd6d
server/text_generation_server/models/custom_modeling/idefics_image_processing.py

@@ -28,6 +28,8 @@ from transformers.image_utils import (
     to_numpy_array,
     valid_images,
 )
+from io import BytesIO
+import requests
 from transformers import TensorType, is_torch_available
 
 
@@ -162,5 +164,27 @@ class IdeficsImageProcessor(BaseImageProcessor):
         images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
 
         return images
+
+    def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+        """
+        Convert a single or a list of urls into the corresponding `PIL.Image` objects.
+        If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
+        returned.
+        """
+        headers = {
+            "User-Agent": (
+                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+                " Safari/537.36"
+            )
+        }
+        if isinstance(image_url_or_urls, list):
+            return [self.fetch_images(x) for x in image_url_or_urls]
+        elif isinstance(image_url_or_urls, str):
+            response = requests.get(image_url_or_urls, stream=True, headers=headers)
+            response.raise_for_status()
+            return Image.open(BytesIO(response.content))
+        else:
+            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
 import transformers
 transformers.IdeficsImageProcessor = IdeficsImageProcessor
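For orientation, a hedged sketch of how the patched-in helper could be called once this module has been imported; the checkpoint name and URLs below are illustrative, not taken from the commit:

    # Assumes a transformers install exposing the IDEFICS processor and
    # network access; both URLs are placeholders.
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics-9b")
    image = processor.image_processor.fetch_images("https://example.com/cat.png")
    images = processor.image_processor.fetch_images(
        ["https://example.com/cat.png", "https://example.com/dog.png"]
    )  # a list in returns a list of PIL.Image objects

The monkey-patch on the last two lines of the hunk registers this class as transformers.IdeficsImageProcessor, so downstream code can treat it like a library-provided processor.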
server/text_generation_server/models/idefics_causal_lm.py

@@ -4,7 +4,7 @@ import re
 from io import BytesIO
 import base64
 from PIL import Image
-import json
+import re
 
 from dataclasses import dataclass
 from opentelemetry import trace
@@ -21,6 +21,26 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
+import re
+
+IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
+
+def split(string):
+    parts = []
+    cursor = 0
+    for pattern in IMAGES.finditer(string):
+        start = pattern.start()
+        if start != cursor:
+            parts.append(string[cursor:start])
+
+        parts.append(pattern.group(1))
+        cursor = pattern.end()
+
+    if cursor != len(string):
+        parts.append(string[cursor:])
+
+    return parts
+
 tracer = trace.get_tracer(__name__)
 
 
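To make the new splitter concrete, a small example; the prompt string is made up, and the expected output was worked out against the IMAGES regex above:

    # Each markdown image ![alt](url "title") contributes only group(1), the
    # URL, to the parts list; alt text and title are dropped.
    print(split('Look: ![a cat](https://example.com/cat.png "my title") Nice.'))
    # ['Look: ', 'https://example.com/cat.png', ' Nice.']

Surrounding text is kept verbatim, so interleaved text and image URLs come back in prompt order.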
@@ -106,16 +126,7 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            if isinstance(inp, str):
-                prompts.append([inp])
-            elif isinstance(inp, list):
-                if not all(isinstance(item, str) for item in inp):
-                    raise ValueError("All elements in the list must be strings (text string or image URL)")
-                prompts.append(
-                    json.load(inp)
-                )
-            else:
-                raise ValueError("Unsupported type of input")
+            prompts.append(split(inp))
 
         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
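Two things worth noting here. First, the removed json.load(inp) branch could never have worked: json.load expects a file-like object (json.loads is the string variant), so calling it on a list would raise at runtime. Second, every input is now a plain string, and from_pb derives the text/URL list itself; a hedged illustration using the split helper added above:

    # Illustration only; the request text is invented.
    inputs = ["User: ![](https://example.com/cat.png) What is shown here?"]
    prompts = [split(inp) for inp in inputs]
    # prompts == [['User: ', 'https://example.com/cat.png', ' What is shown here?']]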
server/text_generation_server/server.py

@@ -16,8 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
-PROFILE = False
-
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -28,14 +26,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
-        if PROFILE:
-            self.prof = torch.profiler.profile(
-                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
-                on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/idefics'),
-                record_shapes=True,
-                with_stack=True
-            )
-            self.prof.start()
 
     async def Info(self, request, context):
         return self.model.info
@@ -90,8 +80,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )
 
         generations, next_batch = self.model.generate_token(batch)
-        if PROFILE:
-            self.prof.step()
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -119,12 +107,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = batches[0]
 
         generations, next_batch = self.model.generate_token(batch)
-        if PROFILE:
-            self.prof.step()
         self.cache.set(next_batch)
-        if next_batch is None:
-            if PROFILE:
-                self.prof.stop()
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
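The three server.py hunks above simply strip the ad-hoc torch.profiler scaffolding from __init__, Prefill, and Decode. Should similar tracing be wanted again, here is a self-contained sketch of the same profiler pattern, gated by a hypothetical environment variable instead of a hardcoded constant (the dummy workload stands in for generate_token):

    import os
    import torch

    if os.getenv("TGI_PROFILE"):  # hypothetical opt-in switch, not from the repo
        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
            on_trace_ready=torch.profiler.tensorboard_trace_handler("./log/idefics"),
            record_shapes=True,
            with_stack=True,
        ) as prof:
            for _ in range(10):  # 10 steps = two full wait/warmup/active cycles
                torch.randn(256, 256) @ torch.randn(256, 256)  # dummy workload
                prof.step()  # advance the profiler schedule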