diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
new file mode 100644
index 00000000..b406d43b
--- /dev/null
+++ b/.github/workflows/trufflehog.yml
@@ -0,0 +1,18 @@
+on:
+  push:
+
+name: Secret Leaks
+
+permissions:
+  contents: read
+
+jobs:
+  trufflehog:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Secret Scanning
+        uses: trufflesecurity/trufflehog@main
diff --git a/Cargo.lock b/Cargo.lock
index b5de8576..b9bd7363 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1856,12 +1856,23 @@ dependencies = [

 [[package]]
 name = "minijinja"
-version = "1.0.12"
-source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4"
 dependencies = [
  "serde",
 ]

+[[package]]
+name = "minijinja-contrib"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07"
+dependencies = [
+ "minijinja",
+ "serde",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -3604,6 +3615,7 @@ dependencies = [
  "metrics",
  "metrics-exporter-prometheus",
  "minijinja",
+ "minijinja-contrib",
  "ngrok",
  "nohash-hasher",
  "once_cell",
diff --git a/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json
new file mode 100644
index 00000000..83390832
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json
@@ -0,0 +1,23 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1718044128,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "2.0.5-dev0-native",
+  "usage": {
+    "completion_tokens": 39,
+    "prompt_tokens": 136,
+    "total_tokens": 175
+  }
+}
diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py
new file mode 100644
index 00000000..9c4c048e
--- /dev/null
+++ b/integration-tests/models/test_grammar_response_format_llama.py
@@ -0,0 +1,101 @@
+import pytest
+import requests
+from pydantic import BaseModel
+from typing import List
+
+
+@pytest.fixture(scope="module")
+def llama_grammar_handle(launcher):
+    with launcher(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        num_shard=1,
+        disable_grammar_support=False,
+        use_flash_attention=False,
+        max_batch_prefill_tokens=3000,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def llama_grammar(llama_grammar_handle):
+    await llama_grammar_handle.health(300)
+    return llama_grammar_handle.client
+
+
+@pytest.mark.asyncio
+async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
+
+    class Weather(BaseModel):
+        unit: str
+        temperature: List[int]
+
+    # send the request
+    response = requests.post(
+        f"{llama_grammar.base_url}/v1/chat/completions",
+        headers=llama_grammar.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + "seed": 42, + "max_tokens": 500, + "response_format": {"type": "json_object", "value": Weather.schema()}, + }, + ) + + chat_completion = response.json() + called = chat_completion["choices"][0]["message"]["content"] + + assert response.status_code == 200 + assert ( + called + == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}' + ) + assert chat_completion == response_snapshot + + +@pytest.mark.asyncio +async def test_grammar_response_format_llama_error_if_tools_not_installed( + llama_grammar, +): + class Weather(BaseModel): + unit: str + temperature: List[int] + + # send the request + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + "seed": 42, + "max_tokens": 500, + "tools": [], + "response_format": {"type": "json_object", "value": Weather.schema()}, + }, + ) + + # 422 means the server was unable to process the request because it contains invalid data. + assert response.status_code == 422 + assert response.json() == { + "error": "Grammar and tools are mutually exclusive", + "error_type": "grammar and tools", + } diff --git a/router/Cargo.toml b/router/Cargo.toml index 2e6264be..5bf4c00c 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -44,7 +44,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } -minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" } +minijinja = { version = "2.0.2" } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" @@ -58,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } default = ["ngrok"] ngrok = ["dep:ngrok"] google = [] +kserve = [] diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 20630c1b..07c334a3 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -12,6 +12,8 @@ use crate::{ use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; +use minijinja_contrib::pycompat; + use serde_json::{json, Map, Value}; use std::collections::HashMap; use std::sync::Arc; @@ -62,14 +64,7 @@ impl Infer { .find(|t| t.name == "default") .map(|t| t.template), }) - .map(|t| { - // .strip() is not supported in minijinja - // .capitalize() is not supported in minijinja but we can use | capitalize - let t = t - .replace(".strip()", " | trim") - .replace(".capitalize()", " | capitalize"); - ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) - }); + .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); @@ -277,6 +272,8 @@ struct 
diff --git a/router/src/kserve.rs b/router/src/kserve.rs
new file mode 100644
index 00000000..b64efd38
--- /dev/null
+++ b/router/src/kserve.rs
@@ -0,0 +1,247 @@
+use crate::{
+    default_parameters,
+    server::{generate_internal, ComputeType},
+    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
+};
+use axum::extract::{Extension, Path};
+use axum::response::{IntoResponse, Response};
+use axum::Json;
+use futures::stream::FuturesUnordered;
+use futures::TryStreamExt;
+use reqwest::header::HeaderMap;
+use reqwest::StatusCode;
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct OutputChunk {
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct InferenceOutput {
+    pub id: String,
+    pub outputs: Vec<OutputChunk>,
+}
+
+#[derive(Debug, Deserialize, ToSchema)]
+pub(crate) struct InferenceRequest {
+    pub id: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
+    pub inputs: Vec<Input>,
+    pub outputs: Vec<Output>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub(crate) struct Input {
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub(crate) struct Output {
+    pub name: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct LiveResponse {
+    pub live: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct ReadyResponse {
+    pub live: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct MetadataServerResponse {
+    pub name: String,
+    pub version: String,
+    pub extensions: Vec<String>,
+}
+
+// Routes
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/health/live",
+    responses(
+        (status = 200, description = "Service is live", body = LiveResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = LiveResponse { live: true };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/health/ready",
+    responses(
+        (status = 200, description = "Service is ready", body = ReadyResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = ReadyResponse { live: true };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2",
+    responses(
+        (status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = MetadataServerResponse {
+        name: "text-generation-inference".to_string(),
+        version: env!("CARGO_PKG_VERSION").to_string(),
+        extensions: vec![
+            "health".to_string(),
+            "models".to_string(),
+            "metrics".to_string(),
+        ],
+    };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}",
+    responses(
+        (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_metadata(
+    Path((model_name, model_version)): Path<(String, String)>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = MetadataServerResponse {
+        name: model_name,
+        version: model_version,
+        extensions: vec!["infer".to_string(), "ready".to_string()],
+    };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/infer",
+    request_body = Json<InferenceRequest>,
+    responses(
+        (status = 200, description = "Inference executed successfully", body = InferenceOutput),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_infer(
+    infer: Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Json(payload): Json<InferenceRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let id = payload.id.clone();
+    let str_inputs = payload
+        .inputs
+        .iter()
+        .map(|input| {
+            std::str::from_utf8(&input.data).map_err(|e| {
+                (
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    Json(ErrorResponse {
+                        error: e.to_string(),
+                        error_type: "utf8".to_string(),
+                    }),
+                )
+            })
+        })
+        .collect::<Result<Vec<&str>, _>>()?;
+
+    if str_inputs.len() != payload.outputs.len() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Inputs and outputs length mismatch".to_string(),
+                error_type: "length mismatch".to_string(),
+            }),
+        ));
+    }
+
+    let output_chunks = str_inputs
+        .iter()
+        .zip(&payload.outputs)
+        .map(|(str_input, output)| {
+            let generate_request = GenerateRequest {
+                inputs: str_input.to_string(),
+                parameters: payload.parameters.clone(),
+            };
+            let infer = infer.clone();
+            let compute_type = compute_type.clone();
+            let span = tracing::Span::current();
+            async move {
+                generate_internal(infer, compute_type, Json(generate_request), span)
+                    .await
+                    .map(|(_, Json(generation))| {
+                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
+                        OutputChunk {
+                            name: output.name.clone(),
+                            shape: vec![1, generation_as_bytes.len()],
+                            datatype: "BYTES".to_string(),
+                            data: generation_as_bytes,
+                        }
+                    })
+                    .map_err(|_| {
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(ErrorResponse {
+                                error: "Incomplete generation".into(),
+                                error_type: "Incomplete generation".into(),
+                            }),
+                        )
+                    })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<OutputChunk>>()
+        .await?;
+
+    let inference_output = InferenceOutput {
+        id: id.clone(),
+        outputs: output_chunks,
+    };
+
+    Ok((HeaderMap::new(), Json(inference_output)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/ready",
+    responses(
+        (status = 200, description = "Model version is ready", body = ReadyResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_metadata_ready(
+    Path((_model_name, _model_version)): Path<(String, String)>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = ReadyResponse { live: true };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
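To make the new infer route concrete, here is one way a client could call `/v2/models/{model_name}/versions/{model_version}/infer`. This is an illustrative sketch only: the server address, model name/version, and the `reqwest` (with its `blocking` and `json` features) and `serde_json` dependencies are assumptions, not part of the change:

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Mirrors the InferenceRequest shape above: `data` carries the prompt as
    // raw UTF-8 bytes, which kserve_model_infer decodes via std::str::from_utf8.
    let prompt = "What is Deep Learning?".as_bytes().to_vec();
    let body = json!({
        "id": "1",
        "inputs": [{
            "name": "input_0",
            "shape": [1, prompt.len()],
            "datatype": "BYTES",
            "data": prompt,
        }],
        // One output entry per input; a length mismatch is rejected with a 422.
        "outputs": [{ "name": "output_0" }],
    });
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:3000/v2/models/tgi/versions/1/infer")
        .json(&body)
        .send()?;
    println!("{}", response.text()?);
    Ok(())
}
```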
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 08d57873..bb407c5f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -4,6 +4,9 @@ mod infer;
 pub mod server;
 mod validation;

+#[cfg(feature = "kserve")]
+mod kserve;
+
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use utoipa::ToSchema;
@@ -89,6 +92,7 @@ pub(crate) enum GrammarType {
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
     #[serde(rename = "json")]
    #[serde(alias = "json_object")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
     Json(serde_json::Value),
     #[serde(rename = "regex")]
@@ -797,6 +801,13 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
 }

 fn default_tool_prompt() -> Option<String> {
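The `json_object` alias added to `GrammarType` is what lets the chat endpoint accept OpenAI-style `response_format` payloads, as the integration test above does. A self-contained sketch of the serde behavior, using a simplified stand-in enum rather than the router's actual type (the `tag`/`content` layout is inferred from the request shape, so treat it as an assumption):

```rust
use serde::Deserialize;

// Simplified stand-in for GrammarType: on the wire the payload looks like
// {"type": "...", "value": {...}}.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(tag = "type", content = "value")]
enum Grammar {
    #[serde(rename = "json", alias = "json_object")]
    Json(serde_json::Value),
}

fn main() {
    let a: Grammar = serde_json::from_str(r#"{"type": "json", "value": {}}"#).unwrap();
    let b: Grammar = serde_json::from_str(r#"{"type": "json_object", "value": {}}"#).unwrap();
    // Both spellings deserialize to the same variant.
    assert_eq!(a, b);
}
```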
diff --git a/router/src/server.rs b/router/src/server.rs
index fee8ebdf..6e6b93b6 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -4,6 +4,11 @@ use crate::infer::v2::SchedulerV2;
 use crate::infer::v3::SchedulerV3;
 use crate::infer::{HealthCheck, Scheduler};
 use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+#[cfg(feature = "kserve")]
+use crate::kserve::{
+    kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata,
+    kserve_model_metadata_ready, kserve_server_metadata,
+};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -172,7 +177,7 @@ async fn generate(
     generate_internal(infer, ComputeType(compute_type), Json(req), span).await
 }

-async fn generate_internal(
+pub(crate) async fn generate_internal(
     infer: Extension<Infer>,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
@@ -1017,6 +1022,7 @@ async fn chat_completions(
         tool_choice,
         tool_prompt,
         temperature,
+        response_format,
         ..
     } = req;

@@ -1031,6 +1037,18 @@
         other => (true, other),
     };

+    // response_format and tools are mutually exclusive
+    if response_format.is_some() && tools.as_ref().is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Grammar and tools are mutually exclusive".to_string(),
+                error_type: "grammar and tools".to_string(),
+            }),
+        ));
+    }
+
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
         Ok(grammar) => grammar,
@@ -1047,16 +1065,21 @@
         }
     };

-    let grammar_with_prompt = tool_grammar
+    // determine the appropriate arguments for apply_chat_template
+    let tools_grammar_prompt = tool_grammar
         .as_ref()
         .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));

-    let typed_grammar = grammar_with_prompt
-        .as_ref()
-        .map(|(grammar, _)| grammar.clone());
+    let (tools_grammar_prompt, grammar) = match response_format {
+        Some(response_format) => (None, Some(response_format)),
+        None => (
+            tools_grammar_prompt.clone(),
+            tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
+        ),
+    };

     // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
+    let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1092,7 +1115,7 @@
             decoder_input_details: !stream,
             seed,
             top_n_tokens: req.top_logprobs,
-            grammar: typed_grammar,
+            grammar,
             ..Default::default()
         },
     };
@@ -1711,28 +1734,58 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };

-    // Define VertextApiDoc conditionally only if the "google" feature is enabled
-    let doc = {
-        // avoid `mut` if possible
-        #[cfg(feature = "google")]
-        {
-            use crate::VertexInstance;
+    #[allow(unused_mut)] // mut is needed for conditional compilation
+    let mut doc = ApiDoc::openapi();

-            #[derive(OpenApi)]
-            #[openapi(
-                paths(vertex_compatibility),
-                components(schemas(VertexInstance, VertexRequest, VertexResponse))
-            )]
-            struct VertextApiDoc;
+    #[cfg(feature = "google")]
+    {
+        use crate::VertexInstance;

-            // limiting mutability to the smallest scope necessary
-            let mut doc = ApiDoc::openapi();
-            doc.merge(VertextApiDoc::openapi());
-            doc
-        }
-        #[cfg(not(feature = "google"))]
-        ApiDoc::openapi()
-    };
+        #[derive(OpenApi)]
+        #[openapi(
+            paths(vertex_compatibility),
+            components(schemas(VertexInstance, VertexRequest, VertexResponse))
+        )]
+        struct VertexApiDoc;
+
+        doc.merge(VertexApiDoc::openapi());
+    }
+
+    #[cfg(feature = "kserve")]
+    {
+        use crate::kserve::{
+            InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
+            ReadyResponse,
+        };
+        use crate::kserve::{
+            __path_kserve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
+            __path_kserve_model_infer, __path_kserve_model_metadata,
+            __path_kserve_model_metadata_ready,
+        };
+
+        #[derive(OpenApi)]
+        #[openapi(
+            paths(
+                kserve_model_infer,
+                kserve_health_live,
+                kserve_health_ready,
+                kserve_server_metadata,
+                kserve_model_metadata,
+                kserve_model_metadata_ready,
+            ),
+            components(schemas(
+                InferenceOutput,
+                InferenceRequest,
+                LiveResponse,
+                MetadataServerResponse,
+                OutputChunk,
+                ReadyResponse,
+            ))
+        )]
+        struct KServeApiDoc;
+
+        doc.merge(KServeApiDoc::openapi());
+    }

     // Configure Swagger UI
     let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
@@ -1782,6 +1835,27 @@ pub async fn run(
         }
     }

+    #[cfg(feature = "kserve")]
+    {
+        tracing::info!("Built with `kserve` feature");
+        app = app
+            .route(
+                "/v2/models/:model_name/versions/:model_version/infer",
+                post(kserve_model_infer),
+            )
+            .route(
+                "/v2/models/:model_name/versions/:model_version",
+                get(kserve_model_metadata),
+            )
+            .route("/v2/health/ready", get(kserve_health_ready))
+            .route("/v2/health/live", get(kserve_health_live))
+            .route("/v2", get(kserve_server_metadata))
+            .route(
+                "/v2/models/:model_name/versions/:model_version/ready",
+                get(kserve_model_metadata_ready),
+            );
+    }
+
     // add layers after routes
     app = app
         .layer(Extension(info))
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index 8c0437ea..2f2b5ef6 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,5 +1,5 @@
 commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
+commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 648d28ab..c2f12189 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
-            if seqlen > self.original_max_position_embeddings:
-                inv_freq = self.long_inv_freq
-            else:
-                inv_freq = self.short_inv_freq
-            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
-            if self.scaling_factor is not None:
-                t /= self.scaling_factor
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
+            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+            short_freqs = torch.outer(
+                t[: self.original_max_position_embeddings],
+                self.short_inv_freq.to(device=t.device),
+            )
+            long_freqs = torch.outer(
+                t[self.original_max_position_embeddings :],
+                self.long_inv_freq.to(device=t.device),
+            )
+
+            freqs = torch.cat([short_freqs, long_freqs])
+
+            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py
index 83d62dea..9b2d01e0 100644
--- a/server/text_generation_server/models/custom_modeling/opt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py
@@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
             return_dict=return_dict,
         )

-        logits, speculative_logits = self.lm_head(outputs)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)

         loss = None
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index 709c15c5..ec325f08 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -8,7 +8,6 @@ from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_phi_modeling import (
     FlashPhiForCausalLM,
-    PhiConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

-        config = PhiConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 0131bffc..880e0801 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -86,5 +86,4 @@ class GPTNeoxSharded(CausalLM):
             use_cache=True,
         )

-        logits = outputs.logits
-        return logits, speculative_logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index aaf81f21..83bd209c 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -76,11 +76,11 @@ class OPTSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=True,
         )

-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index bc2b6db2..37ca277b 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -72,11 +72,13 @@ class RW(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ):
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
+            use_cache=True,
         )
-        return outputs.logits, outputs.past_key_values
+
+        return outputs.logits, speculative_logits, outputs.past_key_values