From 424b24f5fabb4ad70ccb22492ae6295c1e4addaa Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 3 Apr 2024 20:07:02 +0000
Subject: [PATCH] feat: accept list as prompt and use first string

---
 router/src/lib.rs | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index ddb28848..3c9177c0 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -280,6 +280,31 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+mod prompt_serde {
+    use serde::{self, Deserialize, Deserializer};
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+        match value {
+            Value::String(s) => Ok(s),
+            Value::Array(arr) => arr
+                .first()
+                .and_then(|v| v.as_str())
+                .map(|s| s.to_string())
+                .ok_or_else(|| {
+                    serde::de::Error::custom("array is empty or contains non-string elements")
+                }),
+            _ => Err(serde::de::Error::custom(
+                "expected a string or an array of strings",
+            )),
+        }
+    }
+}
+
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
 pub struct CompletionRequest {
     /// UNUSED
@@ -289,6 +314,7 @@ pub struct CompletionRequest {
 
     /// The prompt to generate completions for.
     #[schema(example = "What is Deep Learning?")]
+    #[serde(deserialize_with = "prompt_serde::deserialize")]
    pub prompt: String,
 
     /// The maximum number of tokens that can be generated in the chat completion.
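
Note (not part of the commit): below is a minimal, self-contained sketch of how the new
deserializer is expected to behave, assuming serde (with the derive feature) and serde_json
are available. ExamplePrompt and first_string are hypothetical names used only for this
illustration; the real patch attaches prompt_serde::deserialize to CompletionRequest::prompt.
A JSON string passes through unchanged, a JSON array of strings collapses to its first
element, and an empty array or an array of non-strings is rejected with a custom error.

    use serde::{Deserialize, Deserializer};
    use serde_json::Value;

    // Mirrors the logic of prompt_serde::deserialize from the patch above.
    fn first_string<'de, D>(deserializer: D) -> Result<String, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = Value::deserialize(deserializer)?;
        match value {
            Value::String(s) => Ok(s),
            Value::Array(arr) => arr
                .first()
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
                .ok_or_else(|| {
                    serde::de::Error::custom("array is empty or contains non-string elements")
                }),
            _ => Err(serde::de::Error::custom(
                "expected a string or an array of strings",
            )),
        }
    }

    // Hypothetical request type standing in for CompletionRequest's prompt field.
    #[derive(Deserialize, Debug)]
    struct ExamplePrompt {
        #[serde(deserialize_with = "first_string")]
        prompt: String,
    }

    fn main() {
        // A plain string is accepted as-is.
        let single: ExamplePrompt =
            serde_json::from_str(r#"{"prompt": "What is Deep Learning?"}"#).unwrap();
        assert_eq!(single.prompt, "What is Deep Learning?");

        // A list of strings keeps only the first element.
        let list: ExamplePrompt =
            serde_json::from_str(r#"{"prompt": ["What is Deep Learning?", "ignored"]}"#).unwrap();
        assert_eq!(list.prompt, "What is Deep Learning?");

        // An empty list fails with the custom error message.
        let err = serde_json::from_str::<ExamplePrompt>(r#"{"prompt": []}"#).unwrap_err();
        println!("rejected: {}", err);
    }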