ai-gateway/relay/channel/ali/main.go

254 lines
7.6 KiB
Go
Raw Normal View History

2024-01-14 11:21:03 +00:00
package ali
2023-07-28 15:45:08 +00:00
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
2024-01-14 11:21:03 +00:00
"one-api/relay/channel/openai"
2023-07-28 15:45:08 +00:00
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
2024-01-14 11:21:03 +00:00
// EnableSearchModelSuffix is the model-name suffix that ConvertRequest strips
// from the requested model and turns into Parameters.EnableSearch = true.
const EnableSearchModelSuffix = "-internet"
2024-01-14 11:21:03 +00:00
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
2023-07-28 15:45:08 +00:00
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
2023-07-28 15:45:08 +00:00
}
enableSearch := false
aliModel := request.Model
2024-01-14 11:21:03 +00:00
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
2024-01-14 11:21:03 +00:00
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
2024-01-14 11:21:03 +00:00
return &ChatRequest{
Model: aliModel,
2024-01-14 11:21:03 +00:00
Input: Input{
Messages: messages,
2023-07-28 15:45:08 +00:00
},
2024-01-14 11:21:03 +00:00
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
2023-07-28 15:45:08 +00:00
}
}
2024-01-14 11:21:03 +00:00
// ConvertEmbeddingRequest builds an Ali EmbeddingRequest from an OpenAI-style
// request. The texts come from request.ParseInput(); the model is always
// "text-embedding-v1", whatever model the caller asked for.
func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
	embeddingRequest := &EmbeddingRequest{Model: "text-embedding-v1"}
	embeddingRequest.Input.Texts = request.ParseInput()
	return embeddingRequest
}
2024-01-14 11:21:03 +00:00
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
2024-01-14 11:21:03 +00:00
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
2024-01-14 11:21:03 +00:00
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
2024-01-14 11:21:03 +00:00
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
2024-01-14 11:21:03 +00:00
Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
2024-01-14 11:21:03 +00:00
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
2024-01-14 11:21:03 +00:00
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
2023-07-28 15:45:08 +00:00
Index: 0,
2024-01-14 11:21:03 +00:00
Message: openai.Message{
2023-07-28 15:45:08 +00:00
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
2024-01-14 11:21:03 +00:00
fullTextResponse := openai.TextResponse{
2023-07-28 15:45:08 +00:00
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
2024-01-14 11:21:03 +00:00
Choices: []openai.TextResponseChoice{choice},
Usage: openai.Usage{
2023-07-28 15:45:08 +00:00
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
2024-01-14 11:21:03 +00:00
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
2023-07-28 15:45:08 +00:00
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
2024-01-14 11:21:03 +00:00
response := openai.ChatCompletionsStreamResponse{
2023-07-28 15:45:08 +00:00
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "qwen",
2024-01-14 11:21:03 +00:00
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
2023-07-28 15:45:08 +00:00
}
return &response
}
2024-01-14 11:21:03 +00:00
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage openai.Usage
2023-07-28 15:45:08 +00:00
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
2024-01-14 11:21:03 +00:00
common.SetEventStreamHeaders(c)
//lastResponseText := ""
2023-07-28 15:45:08 +00:00
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
2024-01-14 11:21:03 +00:00
var aliResponse ChatResponse
2023-07-28 15:45:08 +00:00
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
2023-07-28 15:45:08 +00:00
response := streamResponseAli2OpenAI(&aliResponse)
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
2023-07-28 15:45:08 +00:00
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
2023-07-28 15:45:08 +00:00
}
return nil, &usage
}
2024-01-14 11:21:03 +00:00
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var aliResponse ChatResponse
2023-07-28 15:45:08 +00:00
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
2023-07-28 15:45:08 +00:00
}
err = resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
2023-07-28 15:45:08 +00:00
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
2023-07-28 15:45:08 +00:00
}
if aliResponse.Code != "" {
2024-01-14 11:21:03 +00:00
return &openai.ErrorWithStatusCode{
Error: openai.Error{
2023-07-28 15:45:08 +00:00
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
2023-07-28 15:45:08 +00:00
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
2023-07-28 15:45:08 +00:00
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}