From f8be8d5da7d929d1ce0d7144b0f475a1e99dfcf3 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 15 May 2024 17:26:50 +0000 Subject: [PATCH] feat: improve serde add tests and cleanup --- integration-tests/conftest.py | 2 + .../test_flash_llava_simple.json | 25 +++ .../models/test_chat_llava_next.py | 41 +++++ router/src/infer.rs | 2 - router/src/lib.rs | 147 ++++++++++++++---- 5 files changed, 186 insertions(+), 31 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_chat_llava_next/test_flash_llava_simple.json create mode 100644 integration-tests/models/test_chat_llava_next.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index ae3f977b..b15c747c 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -29,6 +29,7 @@ from text_generation.types import ( ChatCompletionComplete, Completion, ) +from huggingface_hub import AsyncInferenceClient DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @@ -225,6 +226,7 @@ class GenerousResponseComparator(ResponseComparator): class LauncherHandle: def __init__(self, port: int): self.client = AsyncClient(f"http://localhost:{port}") + self.inference_client = AsyncInferenceClient(f"http://localhost:{port}") def _inner_health(self): raise NotImplementedError diff --git a/integration-tests/models/__snapshots__/test_chat_llava_next/test_flash_llava_simple.json b/integration-tests/models/__snapshots__/test_chat_llava_next/test_flash_llava_simple.json new file mode 100644 index 00000000..fd9f8918 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_chat_llava_next/test_flash_llava_simple.json @@ -0,0 +1,25 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": " The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. 
The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a", + "name": null, + "role": "assistant", + "tool_calls": null + } + } + ], + "created": 1715786475, + "id": "", + "model": "llava-hf/llava-v1.6-mistral-7b-hf", + "object": "text_completion", + "system_fingerprint": "2.0.2-native", + "usage": { + "completion_tokens": 100, + "prompt_tokens": 2943, + "total_tokens": 3043 + } +} diff --git a/integration-tests/models/test_chat_llava_next.py b/integration-tests/models/test_chat_llava_next.py new file mode 100644 index 00000000..e90b5d20 --- /dev/null +++ b/integration-tests/models/test_chat_llava_next.py @@ -0,0 +1,41 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llava_next_handle(launcher): + with launcher("llava-hf/llava-v1.6-mistral-7b-hf") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llava_chat(flash_llava_next_handle): + await flash_llava_next_handle.health(3000) + return flash_llava_next_handle.inference_client + + +@pytest.mark.private +async def test_flash_llava_simple(flash_llava_chat, response_snapshot): + response = await flash_llava_chat.chat_completion( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + ], + }, + ], + seed=42, + max_tokens=100, + ) + + assert ( + response.choices[0].message.content + == " The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. 
The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a"
+    )
+    assert response == response_snapshot
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 58d96b3e..8613d4ef 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -375,8 +375,6 @@ impl ChatTemplate {
             }
         }
 
-        println!("{:?}", messages);
-
         self.template
             .render(ChatTemplateInputs {
                 messages,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 81d68fd1..22189ed2 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -8,7 +8,8 @@ mod validation;
 
 use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
-use serde::{Deserialize, Serialize};
+use regex::Regex;
+use serde::{Deserialize, Deserializer, Serialize};
 use tokio::sync::OwnedSemaphorePermit;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
@@ -896,16 +897,12 @@ pub(crate) struct ImageUrl {
     pub url: String,
 }
 
-#[derive(Clone, Deserialize, Serialize, Debug)]
+#[derive(Clone, Serialize, Debug, PartialEq)]
 enum ContentChunk {
     Text(String),
     ImageUrl(String),
 }
 
-#[derive(Clone, Deserialize, Debug)]
-struct ContentChunks(Vec<ContentChunk>);
-
-// Convert in and out of ContentChunk
 impl From<String> for ContentChunk {
     fn from(s: String) -> Self {
         ContentChunk::Text(s)
@@ -918,7 +915,29 @@ impl From<&str> for ContentChunk {
     }
 }
 
-// Convert in and out of ContentChunks
+impl<'de> Deserialize<'de> for ContentChunk {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        #[serde(tag = "type")]
+        enum ContentType {
+            #[serde(rename = "text")]
+            Text { text: String },
+            #[serde(rename = "image_url")]
+            ImageUrl { image_url: ImageUrl },
+        }
+        match ContentType::deserialize(deserializer)? {
+            ContentType::Text { text } => Ok(ContentChunk::Text(text)),
+            ContentType::ImageUrl { image_url } => Ok(ContentChunk::ImageUrl(image_url.url)),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Debug, PartialEq)]
+struct ContentChunks(Vec<ContentChunk>);
+
 impl From<Vec<ContentChunk>> for ContentChunks {
     fn from(chunks: Vec<ContentChunk>) -> Self {
         Self(chunks)
@@ -947,48 +966,63 @@ impl Serialize for ContentChunks {
                 ContentChunk::ImageUrl(s) => format!("![]({})", s),
             })
             .collect::<Vec<String>>()
-            .join(" ");
+            .join("");
         serializer.serialize_str(&formatted)
     }
 }
 
+fn parse_markdown_to_chunks(s: &str) -> Result<Vec<ContentChunk>, serde_json::Error> {
+    let mut chunks = Vec::new();
+    let re = Regex::new(r"!\[([^\]]*)\]\(([^)]+)\)").unwrap();
+
+    let mut last_index = 0;
+    for cap in re.captures_iter(s) {
+        if let Some(m) = cap.get(0) {
+            if m.start() > last_index {
+                chunks.push(ContentChunk::Text(s[last_index..m.start()].to_string()));
+            }
+            let _alt_text = cap.get(1).map(|m| m.as_str().to_string());
+            let url = cap.get(2).map(|m| m.as_str().to_string()).unwrap();
+            chunks.push(ContentChunk::ImageUrl(url));
+            last_index = m.end();
+        }
+    }
+
+    if last_index < s.len() {
+        chunks.push(ContentChunk::Text(s[last_index..].to_string()));
+    }
+
+    Ok(chunks)
+}
+
 mod message_content_serde {
     use super::*;
-    use serde::de::{Deserialize, Deserializer, Error};
+    use serde::{de::Error, Deserialize, Deserializer};
     use serde_json::Value;
 
     pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ContentChunks>, D::Error>
     where
         D: Deserializer<'de>,
     {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(Some(vec![s.into()].into())),
+        match Value::deserialize(deserializer)? {
+            Value::String(s) => {
+                let chunks = parse_markdown_to_chunks(&s).map_err(Error::custom)?;
+                Ok(Some(chunks.into()))
+            }
             Value::Array(arr) => arr
                 .into_iter()
                 .map(|v| match v {
                     Value::String(s) => Ok(ContentChunk::Text(s)),
-                    Value::Object(map) => match map
-                        .get("image_url")
-                        .and_then(|x| x.get("url").and_then(|u| u.as_str()))
-                    {
-                        Some(url) => Ok(ContentChunk::ImageUrl(url.to_string())),
-                        None => map
-                            .get("text")
-                            .and_then(|t| t.as_str())
-                            .map(|text| Ok(ContentChunk::Text(text.to_string())))
-                            .map_or_else(
-                                || Err(Error::custom("Expected a string or an object")),
-                                |x| x,
-                            ),
-                    },
-                    _ => Err(Error::custom("Expected a string or an object")),
+                    Value::Object(map) => serde_json::from_value(Value::Object(map))
+                        .map_err(|e| Error::custom(e.to_string())),
+                    _ => Err(Error::custom("Expected string or object")),
                 })
                 .collect::<Result<Vec<_>, _>>()
                 .map(|chunks| Some(chunks.into())),
             Value::Null => Ok(None),
-            _ => Err(Error::custom("Invalid content format")),
+            _ => Err(Error::custom(
+                "Expected string or array of text/image_url objects",
+            )),
         }
     }
 }
@@ -1169,6 +1203,7 @@ pub(crate) struct ErrorResponse {
 mod tests {
     use super::*;
 
+    use serde_json::json;
     use tokenizers::Tokenizer;
 
     pub(crate) async fn get_tokenizer() -> Tokenizer {
@@ -1236,4 +1271,58 @@ mod tests {
         );
         assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
     }
+
+    #[test]
+    fn test_message_content_chunks_serde() {
+        let content = json!("Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
+        let chunks = message_content_serde::deserialize(content)
+            .expect("Failed to deserialize")
+            .unwrap();
+
+        assert_eq!(
+            chunks,
+            ContentChunks(vec![
+                ContentChunk::Text("Whats in this image?".to_string()),
+                ContentChunk::ImageUrl("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string())
+            ])
+        );
+
+        let flattened = serde_json::to_string(&chunks).unwrap();
+
+        assert_eq!(
+            flattened,
+            r#""Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)""#
+        );
+    }
+
+    #[test]
+    fn test_typed_message_content_chunks_serde() {
+        let content = json!([
+            {"type": "text", "text": "Whats in this image?"},
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                },
+            },
+        ]);
+        let chunks: ContentChunks = message_content_serde::deserialize(content)
+            .expect("Failed to deserialize")
+            .unwrap();
+
+        assert_eq!(
+            chunks,
+            ContentChunks(vec![
+                ContentChunk::Text("Whats in this image?".to_string()),
+                ContentChunk::ImageUrl("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string())
+            ])
+        );
+
+        let flattened = serde_json::to_string(&chunks).unwrap();
+
+        assert_eq!(
+            flattened,
+            r#""Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)""#
+        );
+    }
+}
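
Note on the serde round-trip in this patch: message_content_serde::deserialize now accepts either a plain string (split on markdown image syntax by parse_markdown_to_chunks) or an OpenAI-style array of text/image_url objects, and Serialize for ContentChunks flattens the chunks back into a single markdown string joined with an empty separator, which is why the new tests assert an exact round-trip. The snippet below is a minimal, self-contained sketch of just the markdown-splitting step, not the router's code or public API (ContentChunk/ContentChunks stay crate-private); it only assumes the regex crate, which the patch imports via `use regex::Regex;`.

// Sketch: split a message string on markdown image syntax, mirroring the
// parse_markdown_to_chunks logic added in router/src/lib.rs by this patch.
use regex::Regex;

#[derive(Debug, PartialEq)]
enum ContentChunk {
    Text(String),
    ImageUrl(String),
}

fn parse_markdown_to_chunks(s: &str) -> Vec<ContentChunk> {
    // `![alt](url)` -> capture group 1 is the alt text, group 2 is the URL.
    let re = Regex::new(r"!\[([^\]]*)\]\(([^)]+)\)").unwrap();
    let mut chunks = Vec::new();
    let mut last_index = 0;
    for cap in re.captures_iter(s) {
        let m = cap.get(0).unwrap();
        // Text preceding the image becomes a Text chunk.
        if m.start() > last_index {
            chunks.push(ContentChunk::Text(s[last_index..m.start()].to_string()));
        }
        // The URL (group 2) becomes an ImageUrl chunk; the alt text is ignored.
        chunks.push(ContentChunk::ImageUrl(cap[2].to_string()));
        last_index = m.end();
    }
    // Any trailing text after the last image.
    if last_index < s.len() {
        chunks.push(ContentChunk::Text(s[last_index..].to_string()));
    }
    chunks
}

fn main() {
    let chunks = parse_markdown_to_chunks(
        "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)",
    );
    assert_eq!(
        chunks,
        vec![
            ContentChunk::Text("Whats in this image?".to_string()),
            ContentChunk::ImageUrl(
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string()
            ),
        ]
    );
}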