// File: ai-gateway/controller/relay.go

package controller
import (
"fmt"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// Message is one chat message in an OpenAI-style request or response.
//
// Content is deliberately untyped: StringContent accepts either a plain string
// or a list of map-shaped parts decoded from JSON.
type Message struct {
	// Role is serialized as "role".
	Role string `json:"role"`
	// Content is serialized as "content"; see StringContent for the accepted shapes.
	Content any `json:"content"`
	// Name is optional and omitted from JSON when nil.
	Name *string `json:"name,omitempty"`
}
// ImageURL is the "image_url" payload of an image content part.
// Both fields are omitted from JSON when empty.
type ImageURL struct {
	Url    string `json:"url,omitempty"`
	Detail string `json:"detail,omitempty"`
}
// TextContent is a typed text part of a multi-part message content.
// Both fields are omitted from JSON when empty.
type TextContent struct {
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`
}
// ImageContent is a typed image part of a multi-part message content.
// ImageURL is a pointer so that it can be left out of the JSON entirely.
type ImageContent struct {
	Type     string    `json:"type,omitempty"`
	ImageURL *ImageURL `json:"image_url,omitempty"`
}
// StringContent flattens the message content into a single string.
//
// A string Content is returned unchanged. A []any Content is scanned for
// map[string]any items whose "type" equals "text"; their string "text" values
// are concatenated in order and every other item is skipped. Any other Content
// shape (including nil) yields "".
func (m Message) StringContent() string {
	switch content := m.Content.(type) {
	case string:
		return content
	case []any:
		// Builder avoids re-allocating the accumulated string for every part.
		var sb strings.Builder
		for _, item := range content {
			part, ok := item.(map[string]any)
			if !ok || part["type"] != "text" {
				continue
			}
			if text, ok := part["text"].(string); ok {
				sb.WriteString(text)
			}
		}
		return sb.String()
	default:
		return ""
	}
}
// Relay modes identify which endpoint family a request targets. Relay derives
// the mode from the request path; the values are consecutive, starting at 0.
const (
	RelayModeUnknown            = iota // no known path matched
	RelayModeChatCompletions           // /v1/chat/completions
	RelayModeCompletions               // /v1/completions
	RelayModeEmbeddings                // /v1/embeddings, or any path ending in "embeddings"
	RelayModeModerations               // /v1/moderations
	RelayModeImagesGenerations         // /v1/images/generations
	RelayModeEdits                     // /v1/edits
	RelayModeAudioSpeech               // /v1/audio/speech
	RelayModeAudioTranscription        // /v1/audio/transcriptions
	RelayModeAudioTranslation          // /v1/audio/translations
)
// ResponseFormat is the object carried in the "response_format" field of a
// request; Type is omitted from JSON when empty.
//
// See https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
	Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
2023-05-21 06:26:59 +00:00
}
// ParseInput normalizes the untyped Input field into a slice of strings.
//
// A string becomes a one-element slice. A []any yields its string elements in
// order, silently skipping non-string elements (the result is non-nil even when
// empty). A nil Input or any other type yields nil.
func (r GeneralOpenAIRequest) ParseInput() []string {
	switch in := r.Input.(type) {
	case string:
		return []string{in}
	case []any:
		out := make([]string, 0, len(in))
		for _, elem := range in {
			if s, isString := elem.(string); isString {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
2023-05-15 02:48:52 +00:00
}
type TextRequest struct {
2023-05-16 08:18:35 +00:00
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
// ImageRequest is the body of an image generation request. Prompt is the only
// field marked as required for binding.
//
// Docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt" binding:"required"`
	N              int    `json:"n"`
	Size           string `json:"size"`
	Quality        string `json:"quality"`
	ResponseFormat string `json:"response_format"`
	Style          string `json:"style"`
	User           string `json:"user"`
}
// WhisperResponse is a response body that carries only a "text" field, which
// is omitted from JSON when empty.
type WhisperResponse struct {
	Text string `json:"text,omitempty"`
}
// TextToSpeechRequest is the body of a text-to-speech request. Model, Input
// and Voice are marked as required for binding; Speed and ResponseFormat are
// optional.
type TextToSpeechRequest struct {
	Model          string  `json:"model" binding:"required"`
	Input          string  `json:"input" binding:"required"`
	Voice          string  `json:"voice" binding:"required"`
	Speed          float64 `json:"speed"`
	ResponseFormat string  `json:"response_format"`
}
// Usage holds the token counts of a request: prompt tokens, completion tokens
// and their reported total.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// OpenAIError is the error object placed under the "error" key of an error
// response. Code is untyped, so a string, a number or null all decode into it.
type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"`
}
// OpenAIErrorWithStatusCode pairs an OpenAIError with the HTTP status code to
// answer with. The error is embedded, so its fields (for example Message) are
// promoted onto this type.
type OpenAIErrorWithStatusCode struct {
	OpenAIError
	StatusCode int `json:"status_code"`
}
// TextResponse is a text response body that may carry choices, usage and an
// error object. Usage is embedded but tagged, so it is read from and written
// to the "usage" key while its fields stay promoted in Go.
type TextResponse struct {
	Choices []OpenAITextResponseChoice `json:"choices"`
	Usage   `json:"usage"`
	Error   OpenAIError `json:"error"`
}
// OpenAITextResponseChoice is one entry of a response's "choices" array.
// Message is embedded but tagged, so it is serialized under the "message" key.
type OpenAITextResponseChoice struct {
	Index        int `json:"index"`
	Message      `json:"message"`
	FinishReason string `json:"finish_reason"`
}
// OpenAITextResponse is a complete (non-streamed) text response: identifier,
// object kind, creation timestamp, choices and token usage. Usage is embedded
// but tagged, so it is serialized under the "usage" key.
type OpenAITextResponse struct {
	Id      string                     `json:"id"`
	Object  string                     `json:"object"`
	Created int64                      `json:"created"`
	Choices []OpenAITextResponseChoice `json:"choices"`
	Usage   `json:"usage"`
}
// OpenAIEmbeddingResponseItem is one entry of an embedding response's "data"
// array: the embedding vector together with its index.
type OpenAIEmbeddingResponseItem struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}
// OpenAIEmbeddingResponse is the body of an embedding response: the list of
// embedding items, the model name and token usage. Usage is embedded but
// tagged, so it is serialized under the "usage" key.
type OpenAIEmbeddingResponse struct {
	Object string                        `json:"object"`
	Data   []OpenAIEmbeddingResponseItem `json:"data"`
	Model  string                        `json:"model"`
	Usage  `json:"usage"`
}
// ImageResponse is the body of an image generation response: a creation
// timestamp and the list of generated image URLs.
//
// Data carries an explicit "data" tag like every other field; without it the
// field would be marshalled under the capitalized key "Data". Decoding is
// unaffected because encoding/json matches keys case-insensitively.
type ImageResponse struct {
	Created int `json:"created"`
	Data    []struct {
		Url string `json:"url"`
	} `json:"data"`
}
// ChatCompletionsStreamResponseChoice is one choice of a streamed chat
// completion chunk. Delta holds the incremental content; FinishReason is a
// pointer so that a JSON null decodes to nil.
type ChatCompletionsStreamResponseChoice struct {
	Delta struct {
		Content string `json:"content"`
	} `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}
type ChatCompletionsStreamResponse struct {
2023-07-22 08:18:03 +00:00
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
// CompletionsStreamResponse is one chunk of a streamed text completion; each
// choice carries the generated text and a finish reason.
type CompletionsStreamResponse struct {
	Choices []struct {
		Text         string `json:"text"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
2023-06-19 02:28:55 +00:00
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
2023-06-19 02:28:55 +00:00
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
2023-08-27 07:28:23 +00:00
err = relayAudioHelper(c, relayMode)
2023-06-19 02:28:55 +00:00
default:
2023-06-19 07:00:22 +00:00
err = relayTextHelper(c, relayMode)
}
if err != nil {
2023-09-17 07:39:46 +00:00
requestId := c.GetString(common.RequestIdKey)
2023-07-15 11:06:51 +00:00
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
2023-08-12 02:20:54 +00:00
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
2023-07-15 11:06:51 +00:00
}
2023-09-17 07:39:46 +00:00
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
2023-07-15 11:06:51 +00:00
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
2023-09-17 07:39:46 +00:00
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
2023-05-21 12:58:00 +00:00
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
}
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
Code: "api_not_implemented",
}
2023-06-23 14:59:44 +00:00
c.JSON(http.StatusNotImplemented, gin.H{
"error": err,
})
}
func RelayNotFound(c *gin.Context) {
err := OpenAIError{
2023-08-11 11:53:01 +00:00
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
2023-08-11 11:53:01 +00:00
Code: "",
}
2023-06-23 14:59:44 +00:00
c.JSON(http.StatusNotFound, gin.H{
"error": err,
})
}