From 1f1885d911baa4ab16160b873896ec97ef13febc Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 25 Apr 2024 16:20:25 +0000
Subject: [PATCH] Fix the openai backend + evading the @property of
 tokenizer.eos_token_id.

---
 router/src/lib.rs                             | 43 +++++++++++++------
 router/src/server.rs                          |  8 +++-
 .../models/flash_llama.py                     |  3 +-
 server/text_generation_server/models/model.py |  7 +++
 server/text_generation_server/utils/tokens.py |  4 +-
 5 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index ecd8e2e0..9b9097f6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -589,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
+    // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
@@ -623,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -631,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
diff --git a/router/src/server.rs b/router/src/server.rs
index 302a4753..089de7be 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1097,7 +1097,13 @@ async fn chat_completions(
                     let (content, tool_calls) = if tool_grammar.is_some() {
                         (None, Some(vec![stream_token.token.text]))
                     } else {
-                        (Some(stream_token.token.text), None)
+                        let content = if !stream_token.token.special {
+                            Some(stream_token.token.text)
+                        } else {
+                            None
+                        };
+
+                        (content, None)
                     };
 
                     event
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 612a071d..f638a95e 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -56,7 +56,8 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         if isinstance(generation_config.eos_token_id, (list, set)):
-            tokenizer.eos_token_id = set(tokenizer.eos_token_id)
+            # TODO Huge hack
+            tokenizer._eos_token_ids = set(generation_config.eos_token_id)
 
         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index cec9eafa..4f35b0aa 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 49ef2d3b..520f3452 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -197,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
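
Reviewer note on the router hunks: with `role` now an `Option<String>` and the
fields of `ChatCompletionDelta` guarded by `skip_serializing_if =
"Option::is_none"`, the three arms of the new `match` correspond to three wire
shapes: a plain content chunk, a tool-call chunk, and an empty `{}` delta for a
token whose text was suppressed. A minimal Python sketch of those shapes,
assuming `tool_calls` carries the same serde skip attribute as `role` and
`content` (not shown in this hunk):

import json


def delta_payload(content=None, tool_call_args=None):
    """Mirror of the three match arms in ChatCompletionChunk::new."""
    if content is not None:
        # (Some(delta), _): plain streamed text
        delta = {"role": "assistant", "content": content}
    elif tool_call_args is not None:
        # (None, Some(tool_calls)): a function-call fragment
        delta = {
            "role": "assistant",
            "tool_calls": {
                "index": 0,
                "id": "",
                "type": "function",
                "function": {"name": None, "arguments": tool_call_args},
            },
        }
    else:
        # (None, None): token text filtered out upstream -> empty delta,
        # because serde skips every None field
        delta = {}
    return json.dumps(delta)


print(delta_payload(content="Deep"))           # {"role": "assistant", "content": "Deep"}
print(delta_payload(tool_call_args='{"loc":')) # tool-call delta
print(delta_payload())                         # {} — what a special token now yields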
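
On the Python side, "evading the @property" in the subject refers to
`tokenizer.eos_token_id` being a `@property` in transformers: the old line
`tokenizer.eos_token_id = set(tokenizer.eos_token_id)` went through the
property setter (which tries to turn the value back into a token, so a set
cannot work) and also read from the tokenizer instead of the generation
config. The patch writes the set from `generation_config.eos_token_id` onto a
private attribute, `_eos_token_ids`, and `StoppingCriteria.from_pb` picks it
up with a `getattr` fallback. A toy sketch of the pattern; the `ToyTokenizer`
class and the token ids are illustrative, not the real transformers API:

class ToyTokenizer:
    """Stand-in for a HF tokenizer whose eos_token_id is a @property."""

    def __init__(self, eos_token_id):
        self._eos = eos_token_id

    @property
    def eos_token_id(self):
        return self._eos


tok = ToyTokenizer(eos_token_id=128001)

# Llama 3-style generation_config carrying several stop ids (illustrative).
generation_config_eos = [128001, 128009]

# flash_llama.py hack: bypass the property by stashing a plain attribute.
tok._eos_token_ids = set(generation_config_eos)

# tokens.py side: prefer the stashed set, fall back to the scalar property.
eos_token_id = getattr(tok, "_eos_token_ids", tok.eos_token_id)
assert eos_token_id == {128001, 128009}

# A tokenizer that never got the hack still yields its single id.
plain = ToyTokenizer(eos_token_id=2)
assert getattr(plain, "_eos_token_ids", plain.eos_token_id) == 2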
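
Finally, the model.py hunk widens `all_special_ids` with added tokens flagged
`special` in `tokenizer.added_tokens_decoder` (the in-diff comment blames the
unpacked Rust tokenizer); together with the server.rs hunk that blanks
`content` for `token.special`, this is what keeps chat-template tokens out of
streamed chat responses. A self-contained sketch of the union, with a faked
`AddedToken` standing in for the `tokenizers` one and invented example ids:

from dataclasses import dataclass


@dataclass
class AddedToken:
    """Stand-in for tokenizers.AddedToken; only the fields used here."""

    content: str
    special: bool


class ToyTokenizer:
    all_special_ids = [0, 1, 2]
    # Added tokens as id -> AddedToken; some special, some not (illustrative).
    added_tokens_decoder = {
        128000: AddedToken("<|begin_of_text|>", special=True),
        128006: AddedToken("<|start_header_id|>", special=True),
        128010: AddedToken("<|plain_added|>", special=False),
    }


tokenizer = ToyTokenizer()

# Mirrors Model.__init__ after the patch: union the declared special ids with
# every added token that is flagged special.
other_special_ids = {
    id for id, token in tokenizer.added_tokens_decoder.items() if token.special
}
all_special_ids = set(tokenizer.all_special_ids)
all_special_ids.update(other_special_ids)

assert all_special_ids == {0, 1, 2, 128000, 128006}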