Use Default trait when parameters: null

Alvaro Bartolome 2024-09-23 21:23:39 +02:00
parent 8ef3da72e1
commit 4ac0cd2339
2 changed files with 18 additions and 34 deletions

@@ -62,7 +62,7 @@ pub(crate) struct GenerateVertexInstance {
     pub parameters: Option<GenerateParameters>,
 }
 
-#[derive(Clone, Deserialize, ToSchema, Serialize)]
+#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
 pub(crate) struct ChatRequestParameters {
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.

@@ -1452,43 +1452,27 @@ async fn vertex_compatibility(
         },
         VertexInstance::Chat(instance) => {
             let messages = instance.messages;
-            let ChatRequestParameters {
-                model,
-                max_tokens,
-                seed,
-                stop,
-                stream,
-                tools,
-                tool_choice,
-                tool_prompt,
-                temperature,
-                response_format,
-                guideline,
-                presence_penalty,
-                frequency_penalty,
-                top_p,
-                top_logprobs,
-                ..
-            } = instance.parameters.unwrap();
+            let chat_request: ChatRequestParameters = instance.parameters.unwrap_or_default();
 
-            let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-            let max_new_tokens = max_tokens.or(Some(100));
-            let tool_prompt = tool_prompt
+            let repetition_penalty = chat_request.presence_penalty.map(|x| x + 2.0);
+            let max_new_tokens = chat_request.max_tokens.or(Some(100));
+            let tool_prompt = chat_request
+                .tool_prompt
                 .filter(|s| !s.is_empty())
                 .unwrap_or_else(default_tool_prompt);
-            let stop = stop.unwrap_or_default();
+            let stop = chat_request.stop.unwrap_or_default();
             // enable greedy only when temperature is 0
-            let (do_sample, temperature) = match temperature {
+            let (do_sample, temperature) = match chat_request.temperature {
                 Some(temperature) if temperature == 0.0 => (false, None),
                 other => (true, other),
             };
             let (inputs, grammar, _using_tools) = match prepare_chat_input(
                 &infer,
-                response_format,
-                tools,
-                tool_choice,
+                chat_request.response_format,
+                chat_request.tools,
+                chat_request.tool_choice,
                 &tool_prompt,
-                guideline,
+                chat_request.guideline,
                 messages,
             ) {
                 Ok(result) => result,
@@ -1510,9 +1494,9 @@ async fn vertex_compatibility(
                     best_of: None,
                     temperature,
                     repetition_penalty,
-                    frequency_penalty,
+                    frequency_penalty: chat_request.frequency_penalty,
                     top_k: None,
-                    top_p,
+                    top_p: chat_request.top_p,
                     typical_p: None,
                     do_sample,
                     max_new_tokens,
@@ -1521,11 +1505,11 @@ async fn vertex_compatibility(
                     truncate: None,
                     watermark: false,
                     details: true,
-                    decoder_input_details: !stream,
-                    seed,
-                    top_n_tokens: top_logprobs,
+                    decoder_input_details: !chat_request.stream,
+                    seed: chat_request.seed,
+                    top_n_tokens: chat_request.top_logprobs,
                     grammar,
-                    adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    adapter_id: chat_request.model.filter(|m| *m != "tgi").map(String::from),
                 },
             }
         }
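For reference, below is a minimal sketch of the pattern this commit adopts, using simplified stand-in types: the struct and field names are illustrative, not the full ChatRequestParameters, and serde plus serde_json are assumed as dependencies. Deriving Default on a struct whose fields are all Option<T> yields a value with every field set to None, so a request whose parameters field is null (or absent) no longer panics in .unwrap() and instead falls through to the per-field fallbacks already applied downstream.

// A sketch of the Default + unwrap_or_default() pattern; struct and field
// names are simplified stand-ins for the real ChatRequestParameters.
use serde::Deserialize;

// Deriving Default on a struct of Option fields yields an instance with
// every field set to None.
#[derive(Debug, Deserialize, Default)]
struct ChatParams {
    max_tokens: Option<u32>,
    temperature: Option<f32>,
}

#[derive(Debug, Deserialize)]
struct Instance {
    parameters: Option<ChatParams>,
}

fn main() {
    // "parameters": null (or an absent field) deserializes to None; the old
    // .unwrap() call would panic on exactly this input.
    let instance: Instance = serde_json::from_str(r#"{"parameters": null}"#).unwrap();

    // unwrap_or_default() substitutes ChatParams::default() instead, so the
    // downstream per-field fallbacks still apply (e.g. max_tokens -> 100).
    let params = instance.parameters.unwrap_or_default();
    assert_eq!(params.max_tokens.or(Some(100)), Some(100));
    assert!(params.temperature.is_none());
}

With the previous .unwrap(), the same payload aborted the whole Vertex request with a panic; with unwrap_or_default() it degrades gracefully to the server-side defaults.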