Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Fix the OpenAI backend + evade the @property of tokenizer.eos_token_id.

parent fd705ef292
commit 1f1885d911
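The core trick on the Python side: transformers exposes eos_token_id as a @property, so it cannot simply be reassigned to a set of ids; the commit stashes the set under a new private attribute and reads it back with getattr. A minimal sketch of the pattern with a toy tokenizer class (only the _eos_token_ids name and the getattr fallback come from the diff; everything else is illustrative):

    class ToyTokenizer:
        """Stand-in for the transformers tokenizer; not the real class."""

        def __init__(self, eos_token_id: int):
            self._eos = eos_token_id

        @property
        def eos_token_id(self) -> int:
            # A property intercepts attribute access: here assignment would
            # raise, and in transformers the property logic is written for a
            # single id, not a set.
            return self._eos


    tokenizer = ToyTokenizer(eos_token_id=2)

    # Write side (the FlashLlama hunk): stash the set under a *different*
    # name, sidestepping the property machinery entirely.
    tokenizer._eos_token_ids = {2, 32000}

    # Read side (the StoppingCriteria hunk): prefer the stash, fall back
    # to the property.
    eos = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
    assert eos == {2, 32000}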
@@ -589,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
+    // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
@@ -623,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -631,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
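With role and content now Option and tagged skip_serializing_if, the three match arms above serialize to OpenAI-style deltas. A hedged illustration of the rough wire shapes (hand-written, not captured from the router):

    # (Some(delta), _): a plain text chunk
    text_delta = {"role": "assistant", "content": "Deep"}

    # (None, Some(tool_calls)): a tool-call chunk; None fields are skipped
    tool_delta = {
        "role": "assistant",
        "tool_calls": {
            "index": 0,
            "id": "",
            "type": "function",
            "function": {"arguments": "..."},
        },
    }

    # (None, None): every field skipped, leaving the empty delta that
    # OpenAI clients expect on a chunk carrying only a finish_reason
    empty_delta = {}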
@@ -1097,7 +1097,13 @@ async fn chat_completions(
                     let (content, tool_calls) = if tool_grammar.is_some() {
                         (None, Some(vec![stream_token.token.text]))
                     } else {
-                        (Some(stream_token.token.text), None)
+                        let content = if !stream_token.token.special {
+                            Some(stream_token.token.text)
+                        } else {
+                            None
+                        };
+
+                        (content, None)
                     };

                     event
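The non-tool branch now suppresses special tokens instead of streaming them verbatim. The same logic as a tiny self-contained Python sketch (names are illustrative):

    def delta_content(text: str, special: bool):
        # Special tokens (eos, chat-template markers, ...) become None,
        # and a None content is dropped from the serialized delta.
        return text if not special else None

    assert delta_content("Hello", special=False) == "Hello"
    assert delta_content("</s>", special=True) is None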
@@ -56,7 +56,8 @@ class FlashLlama(FlashCausalLM):
                 model_id, revision=revision, trust_remote_code=trust_remote_code
             )
             if isinstance(generation_config.eos_token_id, (list, set)):
-                tokenizer.eos_token_id = set(tokenizer.eos_token_id)
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)

         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
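The comprehension over added_tokens_decoder can be checked in isolation against any Hugging Face tokenizer; a quick sketch (the model name is just an example, and the exact ids depend on the tokenizer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    other_special_ids = {
        id for id, token in tokenizer.added_tokens_decoder.items() if token.special
    }
    print(other_special_ids)  # e.g. {50256} for gpt2's <|endoftext|>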
@@ -197,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
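The diff does not show how StoppingCriteria consumes the value, so the following is only a sketch of a membership test that would accept both the old single-int form and the new set form:

    def is_eos(last_token: int, eos) -> bool:
        # eos is either a single id (legacy path) or the set stashed in
        # tokenizer._eos_token_ids by the FlashLlama hunk above.
        if isinstance(eos, set):
            return last_token in eos
        return last_token == eos

    assert is_eos(2, 2)
    assert is_eos(32000, {2, 32000})
    assert not is_eos(5, {2, 32000})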