Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00
fix: reduce and refactor changes

parent 27cd254b89
commit 24a5588735
@@ -1000,6 +1000,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         tool_prompt,
+        temperature,
         ..
     } = req;
 
@@ -1008,6 +1009,10 @@ async fn chat_completions(
     let logprobs = logprobs.unwrap_or(false);
     let tool_prompt = tool_prompt.unwrap_or_default();
     let stop = stop.unwrap_or_default();
+    // rescale where 0 is deterministic and 1 is random (this is the opposite of other endpoints)
+    let adjusted_temperature = temperature.map_or(1.0, |t| 1.0 - t);
+    let do_sample = adjusted_temperature > 0.0;
+    let temperature = Some(adjusted_temperature);
 
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
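
The four added lines centralize the chat endpoint's temperature handling: the request value is rescaled to 1.0 - t (defaulting to 1.0 when the field is absent) and do_sample is derived from the rescaled value. A minimal Python sketch of that mapping, with an illustrative helper name and sample inputs that are not part of the commit:

    # Sketch of the rescaling added above; `rescale` is a hypothetical helper.
    def rescale(temperature):
        # A missing temperature defaults the adjusted value to 1.0.
        adjusted = 1.0 if temperature is None else 1.0 - temperature
        # Sampling stays enabled only while the adjusted temperature is positive.
        return adjusted, adjusted > 0.0

    for t in (None, 0.0, 0.25, 1.0):
        print(t, rescale(t))
    # None -> (1.0, True), 0.0 -> (1.0, True), 0.25 -> (0.75, True), 1.0 -> (0.0, False)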
@@ -1054,13 +1059,13 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: req.temperature.map(|t| 1.0 - t),
+            temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
-            do_sample: req.temperature.map_or(true, |t| t > 0.0),
+            do_sample,
             max_new_tokens,
             return_full_text: None,
             stop,
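
The hunk above swaps the inline closures out of GenerateParameters, reusing the locals computed earlier. One consequence worth checking: do_sample now keys off the rescaled temperature rather than the raw request value, so the old and new expressions disagree at the endpoints. An illustrative Python comparison, not commit code:

    # Old inline expression vs. new precomputed value, per the diff above.
    def old_do_sample(t):
        return True if t is None else t > 0.0

    def new_do_sample(t):
        adjusted = 1.0 if t is None else 1.0 - t
        return adjusted > 0.0

    for t in (None, 0.0, 0.5, 1.0):
        print(t, old_do_sample(t), new_do_sample(t))
    # They agree except at t=0.0 (old False, new True) and t=1.0 (old True, new False).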
@@ -273,25 +273,23 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
-        if any([x != 1.0 for x in temperature]):
+        if any(x != 1.0 for x in temperature):
             do_sample = [
-                # 1 and 0 both mean no sampling in different contexts
-                sample or x == 1.0 or x == 0.0 or math.isclose(x, 0.0)
-                for x, sample in zip(temperature, do_sample)
+                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
             ]
             warpers.append(
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )
 
-        if any([x != 0 for x in top_k]):
+        if any(x != 0 for x in top_k):
             do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
 
-        if any([x < 1.0 for x in top_p]):
+        if any(x < 1.0 for x in top_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
             warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
 
-        if any([x < 1.0 for x in typical_p]):
+        if any(x < 1.0 for x in typical_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
             warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
 
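
The Python hunk simplifies do_sample back to a one-line update: a request in the batch samples if it already sampled or if its temperature warper is active (temperature != 1.0), and the same pattern repeats for top_k, top_p, and typical_p. A minimal sketch with made-up batch values:

    # Hypothetical batch of three requests with different temperatures.
    temperature = [1.0, 0.7, 1.0]
    do_sample = [False, False, True]
    if any(x != 1.0 for x in temperature):
        do_sample = [
            sample or x != 1.0 for x, sample in zip(temperature, do_sample)
        ]
    print(do_sample)  # [False, True, True]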
|