mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
feat: improve serde add tests and cleanup
This commit is contained in:
parent
c98a6b9948
commit
f8be8d5da7
@ -29,6 +29,7 @@ from text_generation.types import (
|
||||
ChatCompletionComplete,
|
||||
Completion,
|
||||
)
|
||||
from huggingface_hub import AsyncInferenceClient
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||
@ -225,6 +226,7 @@ class GenerousResponseComparator(ResponseComparator):
|
||||
class LauncherHandle:
|
||||
def __init__(self, port: int):
|
||||
self.client = AsyncClient(f"http://localhost:{port}")
|
||||
self.inference_client = AsyncInferenceClient(f"http://localhost:{port}")
|
||||
|
||||
def _inner_health(self):
|
||||
raise NotImplementedError
|
||||
|
@ -0,0 +1,25 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": " The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1715786475,
|
||||
"id": "",
|
||||
"model": "llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 2943,
|
||||
"total_tokens": 3043
|
||||
}
|
||||
}
|
41
integration-tests/models/test_chat_llava_next.py
Normal file
41
integration-tests/models/test_chat_llava_next.py
Normal file
@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def flash_llava_next_handle(launcher):
    """Start the LLaVA-NeXT model through ``launcher`` and yield its handle.

    The fixture is module-scoped; the launcher context is exited once the
    fixture is torn down.
    """
    model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
    with launcher(model_id) as server_handle:
        yield server_handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
async def flash_llava_chat(flash_llava_next_handle):
    """Wait for the launched server's health check, then return its inference client."""
    handle = flash_llava_next_handle
    # 3000 is passed straight to the handle's health() call.
    await handle.health(3000)
    chat_client = handle.inference_client
    return chat_client
|
||||
|
||||
|
||||
@pytest.mark.private
async def test_flash_llava_simple(flash_llava_chat, response_snapshot):
    """Ask the chat endpoint to describe a reference image and check the reply.

    One user message is sent, made of a text part followed by an image_url
    part, with a fixed seed and a 100-token cap. The completion text is
    compared to a literal string and the whole response to the snapshot.
    """
    question_part = {"type": "text", "text": "Whats in this image?"}
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
        },
    }
    user_message = {"role": "user", "content": [question_part, image_part]}
    expected_text = " The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a"

    response = await flash_llava_chat.chat_completion(
        messages=[user_message],
        seed=42,
        max_tokens=100,
    )

    generated_text = response.choices[0].message.content
    assert generated_text == expected_text
    assert response == response_snapshot
|
@ -375,8 +375,6 @@ impl ChatTemplate {
|
||||
}
|
||||
}
|
||||
|
||||
println!("{:?}", messages);
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
|
@ -8,7 +8,8 @@ mod validation;
|
||||
|
||||
use infer::{Infer, InferError, InferStreamResponse};
|
||||
use queue::{Entry, Queue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use utoipa::ToSchema;
|
||||
@ -896,16 +897,12 @@ pub(crate) struct ImageUrl {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
#[derive(Clone, Serialize, Debug, PartialEq)]
|
||||
enum ContentChunk {
|
||||
Text(String),
|
||||
ImageUrl(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
struct ContentChunks(Vec<ContentChunk>);
|
||||
|
||||
// Convert in and out of ContentChunk
|
||||
impl From<String> for ContentChunk {
|
||||
fn from(s: String) -> Self {
|
||||
ContentChunk::Text(s)
|
||||
@ -918,7 +915,29 @@ impl From<&str> for ContentChunk {
|
||||
}
|
||||
}
|
||||
|
||||
// Convert in and out of ContentChunks
|
||||
impl<'de> Deserialize<'de> for ContentChunk {
|
||||
fn deserialize<D>(deserializer: D) -> Result<ContentChunk, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum ContentType {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl { image_url: ImageUrl },
|
||||
}
|
||||
match ContentType::deserialize(deserializer)? {
|
||||
ContentType::Text { text } => Ok(ContentChunk::Text(text)),
|
||||
ContentType::ImageUrl { image_url } => Ok(ContentChunk::ImageUrl(image_url.url)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug, PartialEq)]
|
||||
struct ContentChunks(Vec<ContentChunk>);
|
||||
|
||||
impl From<Vec<ContentChunk>> for ContentChunks {
|
||||
fn from(chunks: Vec<ContentChunk>) -> Self {
|
||||
Self(chunks)
|
||||
@ -947,48 +966,63 @@ impl Serialize for ContentChunks {
|
||||
ContentChunk::ImageUrl(s) => format!("", s),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
.join("");
|
||||
serializer.serialize_str(&formatted)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_markdown_to_chunks(s: &str) -> Result<Vec<ContentChunk>, serde_json::Error> {
|
||||
let mut chunks = Vec::new();
|
||||
let re = Regex::new(r"!\[([^\]]*)\]\(([^)]+)\)").unwrap();
|
||||
|
||||
let mut last_index = 0;
|
||||
for cap in re.captures_iter(s) {
|
||||
if let Some(m) = cap.get(0) {
|
||||
if m.start() > last_index {
|
||||
chunks.push(ContentChunk::Text(s[last_index..m.start()].to_string()));
|
||||
}
|
||||
let _alt_text = cap.get(1).map(|m| m.as_str().to_string());
|
||||
let url = cap.get(2).map(|m| m.as_str().to_string()).unwrap();
|
||||
chunks.push(ContentChunk::ImageUrl(url));
|
||||
last_index = m.end();
|
||||
}
|
||||
}
|
||||
|
||||
if last_index < s.len() {
|
||||
chunks.push(ContentChunk::Text(s[last_index..].to_string()));
|
||||
}
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
mod message_content_serde {
|
||||
use super::*;
|
||||
use serde::de::{Deserialize, Deserializer, Error};
|
||||
use serde::{de::Error, Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ContentChunks>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(vec![s.into()].into())),
|
||||
match Value::deserialize(deserializer)? {
|
||||
Value::String(s) => {
|
||||
let chunks = parse_markdown_to_chunks(&s).map_err(Error::custom)?;
|
||||
Ok(Some(chunks.into()))
|
||||
}
|
||||
Value::Array(arr) => arr
|
||||
.into_iter()
|
||||
.map(|v| match v {
|
||||
Value::String(s) => Ok(ContentChunk::Text(s)),
|
||||
Value::Object(map) => match map
|
||||
.get("image_url")
|
||||
.and_then(|x| x.get("url").and_then(|u| u.as_str()))
|
||||
{
|
||||
Some(url) => Ok(ContentChunk::ImageUrl(url.to_string())),
|
||||
None => map
|
||||
.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|text| Ok(ContentChunk::Text(text.to_string())))
|
||||
.map_or_else(
|
||||
|| Err(Error::custom("Expected a string or an object")),
|
||||
|x| x,
|
||||
),
|
||||
},
|
||||
_ => Err(Error::custom("Expected a string or an object")),
|
||||
Value::Object(map) => serde_json::from_value(Value::Object(map))
|
||||
.map_err(|e| Error::custom(e.to_string())),
|
||||
_ => Err(Error::custom("Expected string or object")),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|chunks| Some(chunks.into())),
|
||||
Value::Null => Ok(None),
|
||||
_ => Err(Error::custom("Invalid content format")),
|
||||
_ => Err(Error::custom(
|
||||
"Expected string or array of text/image_url objects",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1169,6 +1203,7 @@ pub(crate) struct ErrorResponse {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use serde_json::json;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
||||
@ -1236,4 +1271,58 @@ mod tests {
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
}
|
||||
|
||||
// A plain string that embeds a Markdown image must split into a text chunk
// followed by an image chunk, then flatten back to the same string.
#[test]
fn test_message_content_chunks_serde() {
    // The input must contain the `![](url)` reference, otherwise no
    // ImageUrl chunk can be produced and the assertion below cannot hold.
    let content = json!("Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
    let chunks = message_content_serde::deserialize(content)
        .expect("Failed to deserialize")
        .unwrap();

    assert_eq!(
        chunks,
        ContentChunks(vec![
            ContentChunk::Text("Whats in this image?".to_string()),
            ContentChunk::ImageUrl("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string())
        ])
    );

    let flattened = serde_json::to_string(&chunks).unwrap();

    // NOTE(review): expects image chunks to serialize as `![](url)` with no
    // separator between chunks; confirm against the Serialize impl.
    assert_eq!(
        flattened,
        r#""Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)""#
    );
}
|
||||
|
||||
// An array of typed `text` / `image_url` objects must deserialize into the
// matching chunks, then flatten into one string.
#[test]
fn test_typed_message_content_chunks_serde() {
    let content = json!([
        {"type": "text", "text": "Whats in this image?"},
        {
            "type": "image_url",
            "image_url": {
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
            },
        },
    ]);
    let chunks: ContentChunks = message_content_serde::deserialize(content)
        .expect("Failed to deserialize")
        .unwrap();

    assert_eq!(
        chunks,
        ContentChunks(vec![
            ContentChunk::Text("Whats in this image?".to_string()),
            ContentChunk::ImageUrl("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string())
        ])
    );

    let flattened = serde_json::to_string(&chunks).unwrap();

    // NOTE(review): expects image chunks to serialize as `![](url)` with no
    // separator between chunks; confirm against the Serialize impl.
    assert_eq!(
        flattened,
        r#""Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)""#
    );
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user