diff --git a/providers/cohere/chat.go b/providers/cohere/chat.go
index 14eb8bfe..fd845043 100644
--- a/providers/cohere/chat.go
+++ b/providers/cohere/chat.go
@@ -144,13 +144,9 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereRespon
 		Created: common.GetTimestamp(),
 		Choices: []types.ChatCompletionChoice{choice},
 		Model:   request.Model,
-		Usage: &types.Usage{
-			PromptTokens: response.Meta.BilledUnits.InputTokens,
-		},
+		Usage: &types.Usage{},
 	}
-
-	openaiResponse.Usage.CompletionTokens = response.Meta.BilledUnits.OutputTokens + response.Meta.Tokens.SearchUnits + response.Meta.Tokens.Classifications
-	openaiResponse.Usage.TotalTokens = openaiResponse.Usage.PromptTokens + openaiResponse.Usage.CompletionTokens
+	*openaiResponse.Usage = handleUsage(&response.Meta.BilledUnits)
 
 	usage := provider.GetUsage()
 	*usage = *openaiResponse.Usage
@@ -188,6 +184,7 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream
 
 	if cohereResponse.EventType == "stream-end" {
 		choice.FinishReason = types.FinishReasonStop
+		*h.Usage = handleUsage(&cohereResponse.Response.Meta.BilledUnits)
 	} else {
 		choice.Delta = types.ChatCompletionStreamChoiceDelta{
 			Role: types.ChatMessageRoleAssistant,
@@ -206,3 +203,19 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream
 	responseBody, _ := json.Marshal(chatCompletion)
 	dataChan <- string(responseBody)
 }
+
+// handleUsage converts Cohere billed-unit counts into an OpenAI-style Usage.
+// SearchUnits and Classifications are folded into CompletionTokens, matching
+// the formula used before this refactor.
+// NOTE(review): the pre-patch code read SearchUnits/Classifications from
+// response.Meta.Tokens, while callers now pass Meta.BilledUnits; confirm
+// this source change is intended.
+func handleUsage(tokens *Tokens) types.Usage {
+	usage := types.Usage{
+		PromptTokens:     tokens.InputTokens,
+		CompletionTokens: tokens.OutputTokens + tokens.SearchUnits + tokens.Classifications,
+	}
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
+	return usage
+}