feat: improve logs by passing span to internal functions

drbh 2024-04-17 02:44:47 +00:00
parent bd28c36815
commit a7bf3196d4

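For context, here is a minimal sketch of the pattern this commit applies (hypothetical handler names, not code from this diff): the axum handler captures tracing::Span::current() and passes it into an *_internal function, which can then instrument spawned work or enter the span so nested logs stay attached to the original request span.

use tracing::Instrument;

// Hypothetical handler: capture the request span created by the tracing layer.
async fn handle() {
    let span = tracing::Span::current();
    handle_internal(span).await;
}

// Hypothetical internal function: the caller's span is passed in explicitly
// instead of being re-created here.
async fn handle_internal(span: tracing::Span) {
    // Spawned work is instrumented with a clone of the span, so its logs
    // remain correlated with the originating request.
    let task_span = span.clone();
    tokio::spawn(
        async move {
            tracing::info!("processing inside the request span");
        }
        .instrument(task_span),
    );

    // Synchronous sections can enter the span directly.
    let _guard = span.enter();
    tracing::info!("still attributed to the request span");
}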

@@ -164,6 +164,15 @@ async fn generate(
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    generate_internal(infer, ComputeType(compute_type), Json(req), span).await
+}
+
+async fn generate_internal(
+    infer: Extension<Infer>,
+    ComputeType(compute_type): ComputeType,
+    Json(req): Json<GenerateRequest>,
+    span: tracing::Span,
+) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
@@ -362,12 +371,13 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
+    let span = tracing::Span::current();
     let on_message_callback = |stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
+        generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
@@ -377,8 +387,8 @@ async fn generate_stream_internal(
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     on_message_callback: impl Fn(StreamResponse) -> Event,
+    span: tracing::Span,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
-    let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
@@ -584,6 +594,7 @@ async fn completions(
     Extension(info): Extension<Info>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
     let stream = req.stream;
@@ -657,7 +668,7 @@ async fn completions(
                 format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
             let infer_clone = infer.clone();
             let compute_type_clone = compute_type.clone();
-            let params_clone = generate_request.parameters.clone();
+            let span_clone = span.clone();
 
             // Create a future for each generate_stream_internal call.
             let generate_future = async move {
@@ -691,32 +702,28 @@ async fn completions(
                 let (header_tx, header_rx) = oneshot::channel();
                 let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
 
-                tokio::spawn(
-                    async move {
-                        let (header_map, sse) = generate_stream_internal(
-                            infer_clone.clone(),
-                            compute_type_clone.clone(),
-                            Json(generate_request),
-                            on_message_callback,
-                        )
-                        .await;
-
-                        // send and dont wait for response
-                        let _ = header_tx.send(header_map);
-
-                        // pin an emit messages to the sse_tx
-                        let mut sse = Box::pin(sse);
-                        while let Some(event) = sse.next().await {
-                            if sse_tx.send(event).is_err() {
-                                tracing::error!("Failed to send event. Receiver dropped.");
-                                break;
-                            }
-                        }
-                    }
-                    .instrument(
-                        tracing::info_span!("request", index = %index, parameters = ?params_clone),
-                    ),
-                );
+                tokio::spawn(async move {
+                    let (header_map, sse) = generate_stream_internal(
+                        infer_clone.clone(),
+                        compute_type_clone.clone(),
+                        Json(generate_request),
+                        on_message_callback,
+                        span_clone.clone(),
+                    )
+                    .await;
+
+                    // send and dont wait for response
+                    let _ = header_tx.send(header_map);
+
+                    // pin an emit messages to the sse_tx
+                    let mut sse = Box::pin(sse);
+                    while let Some(event) = sse.next().await {
+                        if sse_tx.send(event).is_err() {
+                            tracing::error!("Failed to send event. Receiver dropped.");
+                            break;
+                        }
+                    }
+                });
 
                 (header_rx, sse_rx)
             };
@@ -802,17 +809,17 @@ async fn completions(
         for (index, generate_request) in generate_requests.into_iter().enumerate() {
             let infer_clone = infer.clone();
             let compute_type_clone = compute_type.clone();
-            let params_clone = generate_request.parameters.clone();
+            let span_clone = span.clone();
             let response_future = async move {
-                let result = generate(
+                let result = generate_internal(
                     Extension(infer_clone),
-                    Extension(compute_type_clone),
+                    compute_type_clone,
                     Json(generate_request),
+                    span_clone,
                 )
                 .await;
                 result.map(|(headers, generation)| (index, headers, generation))
-            }
-            .instrument(tracing::info_span!("request", index = %index, parameters = ?params_clone));
+            };
             responses.push(response_future);
         }
         let generate_responses = responses.try_collect::<Vec<_>>().await?;
@@ -979,6 +986,7 @@ async fn chat_completions(
     Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
     let ChatRequest {
@@ -1116,17 +1124,14 @@ async fn chat_completions(
             compute_type,
             Json(generate_request),
             on_message_callback,
+            span,
         )
         .await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) = generate(
-            Extension(infer),
-            Extension(compute_type),
-            Json(generate_request),
-        )
-        .await?;
+        let (headers, Json(generation)) =
+            generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
 
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -1223,6 +1228,7 @@ async fn vertex_compatibility(
     Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
     // check that theres at least one instance
@@ -1254,10 +1260,11 @@ async fn vertex_compatibility(
             };
 
             async {
-                generate(
+                generate_internal(
                     Extension(infer.clone()),
-                    Extension(compute_type.clone()),
+                    compute_type.clone(),
                     Json(generate_request),
+                    span.clone(),
                 )
                 .await
                 .map(|(_, Json(generation))| generation.generated_text)