feat: improve serde handling, add tests, and clean up

This commit is contained in:
drbh 2024-05-15 17:26:50 +00:00
parent c98a6b9948
commit f8be8d5da7
5 changed files with 186 additions and 31 deletions

View File

@ -29,6 +29,7 @@ from text_generation.types import (
ChatCompletionComplete,
Completion,
)
from huggingface_hub import AsyncInferenceClient
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@ -225,6 +226,7 @@ class GenerousResponseComparator(ResponseComparator):
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}")
self.inference_client = AsyncInferenceClient(f"http://localhost:{port}")
def _inner_health(self):
raise NotImplementedError

View File

@ -0,0 +1,25 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": " The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a",
"name": null,
"role": "assistant",
"tool_calls": null
}
}
],
"created": 1715786475,
"id": "",
"model": "llava-hf/llava-v1.6-mistral-7b-hf",
"object": "text_completion",
"system_fingerprint": "2.0.2-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 2943,
"total_tokens": 3043
}
}

View File

@ -0,0 +1,41 @@
import pytest
@pytest.fixture(scope="module")
def flash_llava_next_handle(launcher):
    """Module-scoped fixture yielding the handle produced by ``launcher``.

    ``launcher`` is used as a context manager for the
    ``llava-hf/llava-v1.6-mistral-7b-hf`` model; the handle stays valid until
    the fixture is torn down and the ``with`` block exits.
    """
    model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
    with launcher(model_id) as server_handle:
        yield server_handle
@pytest.fixture(scope="module")
async def flash_llava_chat(flash_llava_next_handle):
    """Module-scoped fixture returning the handle's ``inference_client``.

    The client is only handed out after ``health(3000)`` has been awaited on
    the launcher handle.
    """
    server_handle = flash_llava_next_handle
    await server_handle.health(3000)
    chat_client = server_handle.inference_client
    return chat_client
@pytest.mark.private
async def test_flash_llava_simple(flash_llava_chat, response_snapshot):
    """Send one text + image chat request and check the generated answer.

    The reply must match both the hard-coded expected text and the stored
    response snapshot.
    """
    # User turn made of a text part followed by an image part.
    text_part = {"type": "text", "text": "Whats in this image?"}
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
        },
    }
    conversation = [
        {
            "role": "user",
            "content": [text_part, image_part],
        },
    ]

    response = await flash_llava_chat.chat_completion(
        messages=conversation,
        seed=42,
        max_tokens=100,
    )

    expected_content = " The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a"
    generated_content = response.choices[0].message.content
    assert generated_content == expected_content
    assert response == response_snapshot

View File

@ -375,8 +375,6 @@ impl ChatTemplate {
}
}
println!("{:?}", messages);
self.template
.render(ChatTemplateInputs {
messages,

View File

@ -8,7 +8,8 @@ mod validation;
use infer::{Infer, InferError, InferStreamResponse};
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize};
use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream;
use utoipa::ToSchema;
@ -896,16 +897,12 @@ pub(crate) struct ImageUrl {
pub url: String,
}
#[derive(Clone, Deserialize, Serialize, Debug)]
#[derive(Clone, Serialize, Debug, PartialEq)]
enum ContentChunk {
Text(String),
ImageUrl(String),
}
#[derive(Clone, Deserialize, Debug)]
struct ContentChunks(Vec<ContentChunk>);
// Convert in and out of ContentChunk
impl From<String> for ContentChunk {
fn from(s: String) -> Self {
ContentChunk::Text(s)
@ -918,7 +915,29 @@ impl From<&str> for ContentChunk {
}
}
// Convert in and out of ContentChunks
// Deserializes a single typed content object into a `ContentChunk`.
//
// The object is dispatched on its `type` field:
//   {"type": "text", "text": "..."}                        -> ContentChunk::Text
//   {"type": "image_url", "image_url": {"url": "..."}}     -> ContentChunk::ImageUrl
// For images only the `url` string of the nested `ImageUrl` is kept.
impl<'de> Deserialize<'de> for ContentChunk {
    fn deserialize<D>(deserializer: D) -> Result<ContentChunk, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Wire-format mirror of `ContentChunk`, internally tagged by `type`;
        // `snake_case` maps the variants to the tags "text" and "image_url".
        #[derive(Deserialize)]
        #[serde(tag = "type", rename_all = "snake_case")]
        enum ContentType {
            Text { text: String },
            ImageUrl { image_url: ImageUrl },
        }

        let chunk = match ContentType::deserialize(deserializer)? {
            ContentType::Text { text } => ContentChunk::Text(text),
            ContentType::ImageUrl { image_url } => ContentChunk::ImageUrl(image_url.url),
        };
        Ok(chunk)
    }
}
/// Newtype wrapper around a `Vec<ContentChunk>`, holding the chunks of one
/// message content in order.
#[derive(Clone, Deserialize, Debug, PartialEq)]
struct ContentChunks(Vec<ContentChunk>);
impl From<Vec<ContentChunk>> for ContentChunks {
fn from(chunks: Vec<ContentChunk>) -> Self {
Self(chunks)
@ -947,48 +966,63 @@ impl Serialize for ContentChunks {
ContentChunk::ImageUrl(s) => format!("![]({})", s),
})
.collect::<Vec<_>>()
.join(" ");
.join("");
serializer.serialize_str(&formatted)
}
}
fn parse_markdown_to_chunks(s: &str) -> Result<Vec<ContentChunk>, serde_json::Error> {
let mut chunks = Vec::new();
let re = Regex::new(r"!\[([^\]]*)\]\(([^)]+)\)").unwrap();
let mut last_index = 0;
for cap in re.captures_iter(s) {
if let Some(m) = cap.get(0) {
if m.start() > last_index {
chunks.push(ContentChunk::Text(s[last_index..m.start()].to_string()));
}
let _alt_text = cap.get(1).map(|m| m.as_str().to_string());
let url = cap.get(2).map(|m| m.as_str().to_string()).unwrap();
chunks.push(ContentChunk::ImageUrl(url));
last_index = m.end();
}
}
if last_index < s.len() {
chunks.push(ContentChunk::Text(s[last_index..].to_string()));
}
Ok(chunks)
}
// Deserialization helper that turns a JSON message `content` value into
// `Option<ContentChunks>`.
//
// NOTE(review): this span is a rendered diff. Lines from before and after the
// change are interleaved (two `use serde::...` lines, two `match` heads, two
// fallback arms), so it is not compilable exactly as shown.
mod message_content_serde {
use super::*;
use serde::de::{Deserialize, Deserializer, Error};
use serde::{de::Error, Deserialize, Deserializer};
use serde_json::Value;
/// Reads the input as a generic `serde_json::Value` and converts it:
/// a JSON string or a JSON array yields `Some(ContentChunks)`, JSON `null`
/// yields `None`, and any other JSON value is rejected with a custom error.
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ContentChunks>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
// Here the whole string is wrapped as a single chunk via `From<String>`.
Value::String(s) => Ok(Some(vec![s.into()].into())),
match Value::deserialize(deserializer)? {
// Here the string is split into text and image chunks by `parse_markdown_to_chunks`.
Value::String(s) => {
let chunks = parse_markdown_to_chunks(&s).map_err(Error::custom)?;
Ok(Some(chunks.into()))
}
// Each array element is converted to a `ContentChunk`; the first failing
// element aborts the whole conversion through the `Result` collect below.
Value::Array(arr) => arr
.into_iter()
.map(|v| match v {
Value::String(s) => Ok(ContentChunk::Text(s)),
// Manual lookup: prefer `image_url.url`, otherwise fall back to `text`.
Value::Object(map) => match map
.get("image_url")
.and_then(|x| x.get("url").and_then(|u| u.as_str()))
{
Some(url) => Ok(ContentChunk::ImageUrl(url.to_string())),
None => map
.get("text")
.and_then(|t| t.as_str())
.map(|text| Ok(ContentChunk::Text(text.to_string())))
.map_or_else(
|| Err(Error::custom("Expected a string or an object")),
|x| x,
),
},
_ => Err(Error::custom("Expected a string or an object")),
// The object is handed to `serde_json::from_value`; its error is forwarded as text.
Value::Object(map) => serde_json::from_value(Value::Object(map))
.map_err(|e| Error::custom(e.to_string())),
_ => Err(Error::custom("Expected string or object")),
})
.collect::<Result<Vec<_>, _>>()
.map(|chunks| Some(chunks.into())),
// An explicit JSON `null` maps to "no content".
Value::Null => Ok(None),
_ => Err(Error::custom("Invalid content format")),
_ => Err(Error::custom(
"Expected string or array of text/image_url objects",
)),
}
}
}
@ -1169,6 +1203,7 @@ pub(crate) struct ErrorResponse {
mod tests {
use super::*;
use serde_json::json;
use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer {
@ -1236,4 +1271,58 @@ mod tests {
);
assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
}
/// A plain string containing a markdown image must deserialize into a text
/// chunk followed by an image chunk, and serialize back to the same string.
#[test]
fn test_message_content_chunks_serde() {
    let raw = json!("Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
    let parsed = message_content_serde::deserialize(raw)
        .expect("Failed to deserialize")
        .unwrap();

    let expected_chunks = ContentChunks(vec![
        ContentChunk::Text("Whats in this image?".to_string()),
        ContentChunk::ImageUrl("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string()),
    ]);
    assert_eq!(parsed, expected_chunks);

    // Round trip: the chunks flatten back into one JSON string.
    let expected_json = r#""Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)""#;
    let serialized = serde_json::to_string(&parsed).unwrap();
    assert_eq!(serialized, expected_json);
}
/// An array of typed `text` / `image_url` objects must deserialize into the
/// matching chunks and serialize to a single markdown-style string.
#[test]
fn test_typed_message_content_chunks_serde() {
    let raw = json!([
        {"type": "text", "text": "Whats in this image?"},
        {
            "type": "image_url",
            "image_url": {
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
            },
        },
    ]);
    let parsed: ContentChunks = message_content_serde::deserialize(raw)
        .expect("Failed to deserialize")
        .unwrap();

    let expected_chunks = ContentChunks(vec![
        ContentChunk::Text("Whats in this image?".to_string()),
        ContentChunk::ImageUrl("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string()),
    ]);
    assert_eq!(parsed, expected_chunks);

    // The typed parts flatten into one JSON string with the image as markdown.
    let expected_json = r#""Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)""#;
    let serialized = serde_json::to_string(&parsed).unwrap();
    assert_eq!(serialized, expected_json);
}
}