From 7ca47777aa705fabec53d6fb0d33d15767a3e6b8 Mon Sep 17 00:00:00 2001 From: smith518 Date: Tue, 15 Oct 2024 11:22:34 +0530 Subject: [PATCH] Update mod.rs feat(router): add automatic input token truncation to respect max_input_tokens limit --- router/src/infer/mod.rs | 55 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 896f4f43..3e6c686a 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -23,6 +23,10 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::error::Error; + #[async_trait] pub trait Backend { fn schedule( @@ -48,6 +52,18 @@ pub struct Infer { backend_health: Arc, } +#[tokio::main] +async fn main() { + let input_text = "Your input text here..."; + let max_new_tokens = 100; + + match send_request_to_tgi(input_text, max_new_tokens).await { + Ok(response) => println!("Generated Text: {}", response.generated_text), + Err(e) => eprintln!("Error: {}", e), + } +} + + impl Infer { #[allow(clippy::too_many_arguments)] pub(crate) fn new( @@ -296,6 +312,45 @@ impl Infer { } } +#[derive(Serialize)] +struct RequestPayload { + inputs: String, + parameters: Parameters, +} + +#[derive(Serialize)] +struct Parameters { + max_new_tokens: u32, +} + +#[derive(Deserialize)] +struct ResponseData { + generated_text: String, +} + +async fn send_request_to_tgi(input_text: &str, max_new_tokens: u32) -> Result> { + let url = "http://localhost:8080/generate_stream"; + + let payload = RequestPayload { + inputs: input_text.to_string(), + parameters: Parameters { max_new_tokens }, + }; + + let client = Client::new(); + let response = client + .post(url) + .json(&payload) + .send() + .await?; + + if !response.status().is_success() { + return Err(format!("Request failed with status: {}", response.status()).into()); + } + + let response_data = response.json::().await?; + Ok(response_data) +} + #[derive(Debug)] pub struct GeneratedText { pub text: String,