Mirror of https://github.com/huggingface/text-generation-inference.git
Fix the OpenAI backend + evade the @property of tokenizer.eos_token_id.
This commit is contained in:
parent
fd705ef292
commit
1f1885d911
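
For context, the evasion works because, in recent transformers versions, tokenizer.eos_token_id is a @property whose setter converts the assigned id back into a single token string, so a set of ids cannot be pushed through it. A minimal sketch of the trick, assuming a transformers-style tokenizer (the model id and token ids below are only placeholders):

    # Minimal sketch of the evasion, assuming transformers' SpecialTokensMixin,
    # where eos_token_id is a @property backed by a single eos token string.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id

    # tokenizer.eos_token_id = {2, 32000}  # would go through the property setter,
    #                                      # which expects one id, not a set.

    # Evasion: stash the set on a plain attribute the property never touches.
    tokenizer._eos_token_ids = {2, 32000}  # illustrative ids

    # Readers fall back to the single id when the stash is absent:
    eos_ids = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
    print(eos_ids)  # {2, 32000}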
@@ -589,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
     // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
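
Making role and content optional, with serde skipping unset fields, matches OpenAI's streaming wire format, where each chunk carries only the delta fields that are present and the final chunk may carry none. A hedged Python sketch of the resulting shapes (field values are illustrative):

    import json

    def serialize_delta(role=None, content=None, tool_calls=None):
        # Mirrors serde's skip_serializing_if = "Option::is_none":
        # unset fields are dropped from the serialized delta.
        fields = {"role": role, "content": content, "tool_calls": tool_calls}
        return json.dumps({k: v for k, v in fields.items() if v is not None})

    print(serialize_delta(role="assistant", content="Deep"))  # first chunk: role + text
    print(serialize_delta(content=" Learning"))               # later chunks: content only
    print(serialize_delta())                                  # {} -- a chunk may carry no fields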
@@ -623,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -631,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
@@ -1097,7 +1097,13 @@ async fn chat_completions(
             let (content, tool_calls) = if tool_grammar.is_some() {
                 (None, Some(vec![stream_token.token.text]))
             } else {
-                (Some(stream_token.token.text), None)
+                let content = if !stream_token.token.special {
+                    Some(stream_token.token.text)
+                } else {
+                    None
+                };
+
+                (content, None)
             };
 
             event
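
The reworked else-branch keeps special tokens out of the streamed assistant text. A minimal sketch of the filtering rule, assuming a token flagged special as in the diff (the helper name is hypothetical):

    def delta_content(token_text: str, is_special: bool):
        # Special tokens (eos markers, chat-template control tokens) are
        # control signals, not assistant text, so they are withheld here.
        return None if is_special else token_text

    assert delta_content("Hello", False) == "Hello"
    assert delta_content("</s>", True) is None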
@@ -56,7 +56,8 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         if isinstance(generation_config.eos_token_id, (list, set)):
-            tokenizer.eos_token_id = set(tokenizer.eos_token_id)
+            # TODO Huge hack
+            tokenizer._eos_token_ids = set(generation_config.eos_token_id)
 
         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
@@ -197,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
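
With the getattr fallback, StoppingCriteria can now receive either a single eos id or the set stashed in _eos_token_ids. A sketch of a membership test covering both cases, assuming the criteria compares the last generated token id against eos (the helper below is hypothetical, not the repository's actual check):

    def is_eos(last_token_id: int, eos) -> bool:
        # Normalize: `eos` is either one id or a set of ids.
        eos_ids = eos if isinstance(eos, (set, frozenset)) else {eos}
        return last_token_id in eos_ids

    assert is_eos(2, 2)
    assert is_eos(32000, {2, 32000})
    assert not is_eos(5, {2, 32000})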