From a7bf3196d458c4f71b287cbc6d16754717d32776 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 17 Apr 2024 02:44:47 +0000 Subject: [PATCH] feat: improve logs by passing span to internal functions --- router/src/server.rs | 83 ++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index f0cf37d2..302a4753 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -164,6 +164,15 @@ async fn generate( Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); + generate_internal(infer, ComputeType(compute_type), Json(req), span).await +} + +async fn generate_internal( + infer: Extension, + ComputeType(compute_type): ComputeType, + Json(req): Json, + span: tracing::Span, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -362,12 +371,13 @@ async fn generate_stream( HeaderMap, Sse>>, ) { + let span = tracing::Span::current(); let on_message_callback = |stream_token: StreamResponse| { let event = Event::default(); event.json_data(stream_token).unwrap() }; let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await; + generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); (headers, sse) } @@ -377,8 +387,8 @@ async fn generate_stream_internal( ComputeType(compute_type): ComputeType, Json(req): Json, on_message_callback: impl Fn(StreamResponse) -> Event, + span: tracing::Span, ) -> (HeaderMap, impl Stream>) { - let span = tracing::Span::current(); let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -584,6 +594,7 @@ async fn completions( Extension(info): Extension, Json(req): Json, ) -> Result)> { + let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); let stream = req.stream; @@ -657,7 +668,7 @@ async fn completions( format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); let infer_clone = infer.clone(); let compute_type_clone = compute_type.clone(); - let params_clone = generate_request.parameters.clone(); + let span_clone = span.clone(); // Create a future for each generate_stream_internal call. let generate_future = async move { @@ -691,32 +702,28 @@ async fn completions( let (header_tx, header_rx) = oneshot::channel(); let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel(); - tokio::spawn( - async move { - let (header_map, sse) = generate_stream_internal( - infer_clone.clone(), - compute_type_clone.clone(), - Json(generate_request), - on_message_callback, - ) - .await; + tokio::spawn(async move { + let (header_map, sse) = generate_stream_internal( + infer_clone.clone(), + compute_type_clone.clone(), + Json(generate_request), + on_message_callback, + span_clone.clone(), + ) + .await; - // send and dont wait for response - let _ = header_tx.send(header_map); + // send and dont wait for response + let _ = header_tx.send(header_map); - // pin an emit messages to the sse_tx - let mut sse = Box::pin(sse); - while let Some(event) = sse.next().await { - if sse_tx.send(event).is_err() { - tracing::error!("Failed to send event. Receiver dropped."); - break; - } + // pin an emit messages to the sse_tx + let mut sse = Box::pin(sse); + while let Some(event) = sse.next().await { + if sse_tx.send(event).is_err() { + tracing::error!("Failed to send event. Receiver dropped."); + break; } } - .instrument( - tracing::info_span!("request", index = %index, parameters = ?params_clone), - ), - ); + }); (header_rx, sse_rx) }; @@ -802,17 +809,17 @@ async fn completions( for (index, generate_request) in generate_requests.into_iter().enumerate() { let infer_clone = infer.clone(); let compute_type_clone = compute_type.clone(); - let params_clone = generate_request.parameters.clone(); + let span_clone = span.clone(); let response_future = async move { - let result = generate( + let result = generate_internal( Extension(infer_clone), - Extension(compute_type_clone), + compute_type_clone, Json(generate_request), + span_clone, ) .await; result.map(|(headers, generation)| (index, headers, generation)) - } - .instrument(tracing::info_span!("request", index = %index, parameters = ?params_clone)); + }; responses.push(response_future); } let generate_responses = responses.try_collect::>().await?; @@ -979,6 +986,7 @@ async fn chat_completions( Extension(info): Extension, Json(req): Json, ) -> Result)> { + let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); let ChatRequest { @@ -1116,17 +1124,14 @@ async fn chat_completions( compute_type, Json(generate_request), on_message_callback, + span, ) .await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, Json(generation)) = generate( - Extension(infer), - Extension(compute_type), - Json(generate_request), - ) - .await?; + let (headers, Json(generation)) = + generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?; let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -1223,6 +1228,7 @@ async fn vertex_compatibility( Extension(compute_type): Extension, Json(req): Json, ) -> Result)> { + let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); // check that theres at least one instance @@ -1254,10 +1260,11 @@ async fn vertex_compatibility( }; async { - generate( + generate_internal( Extension(infer.clone()), - Extension(compute_type.clone()), + compute_type.clone(), Json(generate_request), + span.clone(), ) .await .map(|(_, Json(generation))| generation.generated_text)