diff --git a/router/src/lib.rs b/router/src/lib.rs index ddb28848..3c9177c0 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -280,6 +280,31 @@ fn default_parameters() -> GenerateParameters { } } +mod prompt_serde { + use serde::{self, Deserialize, Deserializer}; + use serde_json::Value; + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + match value { + Value::String(s) => Ok(s), + Value::Array(arr) => arr + .first() + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| { + serde::de::Error::custom("array is empty or contains non-string elements") + }), + _ => Err(serde::de::Error::custom( + "expected a string or an array of strings", + )), + } + } +} + #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] pub struct CompletionRequest { /// UNUSED @@ -289,6 +314,7 @@ pub struct CompletionRequest { /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] + #[serde(deserialize_with = "prompt_serde::deserialize")] pub prompt: String, /// The maximum number of tokens that can be generated in the chat completion.