fix: update temperature and sampling logic in chat

This commit is contained in:
drbh 2024-04-15 19:13:18 +00:00
parent 0520bde039
commit 27cd254b89
4 changed files with 6 additions and 8 deletions

View File

@ -709,10 +709,6 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub do_sample: bool,
}
fn default_tool_prompt() -> Option<String> {

View File

@ -1054,13 +1054,13 @@ async fn chat_completions(
inputs: inputs.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
temperature: req.temperature.map(|t| 1.0 - t),
repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: req.do_sample,
do_sample: req.temperature.map_or(true, |t| t > 0.0),
max_new_tokens,
return_full_text: None,
stop,

View File

@ -217,7 +217,7 @@ impl Validation {
}
let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
if temperature < 0.0 {
return Err(ValidationError::Temperature);
}

View File

@ -275,7 +275,9 @@ class HeterogeneousNextTokenChooser:
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
# 1 and 0 both mean no sampling in different contexts
sample or x == 1.0 or x == 0.0 or math.isclose(x, 0.0)
for x, sample in zip(temperature, do_sample)
]
warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)