Fix the openai backend + evading the @property of tokenizer.eos_token_id.

Nicolas Patry 2024-04-25 16:20:25 +00:00
parent fd705ef292
commit 1f1885d911
5 changed files with 48 additions and 17 deletions


@@ -589,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
+    // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
@@ -623,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -631,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
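
Not part of the diff: a sketch of the three delta shapes the match above can now produce, written as Python dicts with illustrative values. Assuming tool_calls is also skipped when absent, making role an Option<String> with skip_serializing_if means a chunk that carries only a finish_reason serializes an empty delta, which is what OpenAI-style streaming clients expect.

    # Illustrative streaming deltas (hypothetical values, not actual TGI output):
    content_delta = {"role": "assistant", "content": "Hello"}  # (Some(delta), _)
    tool_call_delta = {                                        # (None, Some(tool_calls))
        "role": "assistant",
        "tool_calls": {
            "index": 0,
            "id": "",
            "type": "function",
            "function": {"name": None, "arguments": '"Paris"'},
        },
    }
    final_delta = {}                                           # (None, None): empty delta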


@@ -1097,7 +1097,13 @@ async fn chat_completions(
                     let (content, tool_calls) = if tool_grammar.is_some() {
                         (None, Some(vec![stream_token.token.text]))
                     } else {
-                        (Some(stream_token.token.text), None)
+                        let content = if !stream_token.token.special {
+                            Some(stream_token.token.text)
+                        } else {
+                            None
+                        };
+                        (content, None)
                     };
                     event
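
Again not part of the diff, a minimal Python sketch of the gating this hunk adds: when no tool grammar is active, tokens flagged as special (an eos such as <|eot_id|>, say) no longer leak into the streamed content.

    def split_stream_token(text: str, special: bool, tool_grammar_active: bool):
        """Mirror of the router logic above: returns (content, tool_calls)."""
        if tool_grammar_active:
            return None, [text]
        return (None if special else text), None

    assert split_stream_token("Hello", False, False) == ("Hello", None)
    assert split_stream_token("<|eot_id|>", True, False) == (None, None)
    assert split_stream_token('{"city": "Paris"}', False, True) == (None, ['{"city": "Paris"}'])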


@@ -56,7 +56,8 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         if isinstance(generation_config.eos_token_id, (list, set)):
-            tokenizer.eos_token_id = set(tokenizer.eos_token_id)
+            # TODO Huge hack
+            tokenizer._eos_token_ids = set(generation_config.eos_token_id)
 
         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
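
The commit title's "evading the @property" refers to this hunk. In transformers, eos_token_id is exposed through a @property, so it cannot simply be assigned a set of ids; stashing the set on a new private attribute sidesteps the property machinery entirely. A toy sketch of the pattern (not transformers code):

    class ToyTokenizer:
        """Stand-in for a transformers tokenizer with a property-backed eos id."""

        def __init__(self, eos_id: int):
            self._eos_id = eos_id

        @property
        def eos_token_id(self) -> int:
            return self._eos_id

    tok = ToyTokenizer(128001)
    # tok.eos_token_id = {128001, 128009} would raise AttributeError here (and
    # in transformers it routes through setter logic built for a single id),
    # so the commit writes to a fresh private attribute instead:
    tok._eos_token_ids = {128001, 128009}  # e.g. Llama 3's two eos token ids

Downstream code then reads the attribute back with getattr, falling back to the plain property when it was never set, which is exactly what the StoppingCriteria hunk below does.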


@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
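
For context: added_tokens_decoder maps token ids to AddedToken objects, each carrying a special flag, so the ids the fast (Rust) tokenizer fails to surface in all_special_ids can be recovered from it. A standalone sketch of the same recovery, assuming any Hugging Face tokenizer (the model id is illustrative):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    # Collect every added token flagged special, then union it with whatever
    # all_special_ids already reports, as in the hunk above.
    other_special_ids = {
        token_id
        for token_id, token in tok.added_tokens_decoder.items()
        if token.special
    }
    all_special_ids = set(tok.all_special_ids) | other_special_ids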


@@ -197,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
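
StoppingCriteria's body is not shown in this diff, but once eos_token_id can be a set, the eos comparison downstream has to become a membership test. A hedged sketch of what that check might look like:

    from typing import Set, Union

    def is_eos(last_token: int, eos_token_id: Union[int, Set[int]]) -> bool:
        # With the getattr fallback above, eos_token_id is either the single
        # id from the tokenizer property or the set stashed in _eos_token_ids.
        if isinstance(eos_token_id, set):
            return last_token in eos_token_id
        return last_token == eos_token_id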