// ai-gateway/relay/channel/ali/main.go

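// Package ali adapts OpenAI-compatible chat and embedding requests to the
// Alibaba Cloud DashScope (Qwen) API and converts the responses back.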
package ali
import (
	"bufio"
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/model"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
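
// EnableSearchModelSuffix marks a model name that should have Ali's web
// search (enable_search) turned on; the suffix is stripped before the
// request is forwarded upstream.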
const EnableSearchModelSuffix = "-internet"
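
// ConvertRequest converts an OpenAI-style chat completion request into the
// Ali (DashScope) ChatRequest format, handling the "-internet" search suffix.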
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
	messages := make([]Message, 0, len(request.Messages))
	for _, message := range request.Messages {
		messages = append(messages, Message{
			Content: message.StringContent(),
			Role:    strings.ToLower(message.Role),
		})
	}
	enableSearch := false
	aliModel := request.Model
	if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
		enableSearch = true
		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
	}
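	// Ali requires top_p to be strictly less than 1, so clamp just below it.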
	if request.TopP >= 1 {
		request.TopP = 0.9999
	}
	return &ChatRequest{
		Model: aliModel,
		Input: Input{
			Messages: messages,
		},
		Parameters: Parameters{
			EnableSearch:      enableSearch,
			IncrementalOutput: request.Stream,
			Seed:              uint64(request.Seed),
			MaxTokens:         request.MaxTokens,
			Temperature:       request.Temperature,
			TopP:              request.TopP,
			TopK:              request.TopK,
			ResultFormat:      "message",
			Tools:             request.Tools,
		},
	}
}
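
// ConvertEmbeddingRequest converts an OpenAI-style embedding request into the
// Ali EmbeddingRequest format, pinning the model to text-embedding-v1.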
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
	return &EmbeddingRequest{
		Model: "text-embedding-v1",
		Input: struct {
			Texts []string `json:"texts"`
		}{
			Texts: request.ParseInput(),
		},
	}
}
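
// EmbeddingHandler decodes an Ali embedding response, converts it to the
// OpenAI format, and writes it to the client, returning the token usage.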
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var aliResponse EmbeddingResponse
	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	// A non-empty code means the upstream request failed; surface it as-is.
	if aliResponse.Code != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: aliResponse.Message,
				Type:    aliResponse.Code,
				Param:   aliResponse.RequestId,
				Code:    aliResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		logger.SysError("error writing embedding response: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}
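
// embeddingResponseAli2OpenAI maps an Ali embedding response onto the OpenAI
// response shape, carrying over the total token usage.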
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
		Model:  "text-embedding-v1",
		Usage:  model.Usage{TotalTokens: response.Usage.TotalTokens},
	}
	for _, item := range response.Output.Embeddings {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
			Object:    "embedding",
			Index:     item.TextIndex,
			Embedding: item.Embedding,
		})
	}
	return &openAIEmbeddingResponse
}
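
// responseAli2OpenAI maps a non-streaming Ali chat response onto the OpenAI
// text response shape, including prompt and completion token usage.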
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
	fullTextResponse := openai.TextResponse{
		Id:      response.RequestId,
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: response.Output.Choices,
		Usage: model.Usage{
			PromptTokens:     response.Usage.InputTokens,
			CompletionTokens: response.Usage.OutputTokens,
			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
		},
	}
	return &fullTextResponse
}
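
// streamResponseAli2OpenAI converts a single Ali stream chunk into an OpenAI
// chat.completion.chunk, or returns nil when the chunk carries no choices.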
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	if len(aliResponse.Output.Choices) == 0 {
		return nil
	}
	aliChoice := aliResponse.Output.Choices[0]
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta = aliChoice.Message
	// Ali reports the literal string "null" while generation is still
	// running; only forward a real finish reason.
	if aliChoice.FinishReason != "null" {
		finishReason := aliChoice.FinishReason
		choice.FinishReason = &finishReason
	}
	response := openai.ChatCompletionsStreamResponse{
		Id:      aliResponse.RequestId,
		Object:  "chat.completion.chunk",
		Created: helper.GetTimestamp(),
		Model:   "qwen",
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
	}
	return &response
}
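
// StreamHandler relays an Ali SSE stream to the client, converting each chunk
// to the OpenAI streaming format and accumulating token usage along the way.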
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var usage model.Usage
	scanner := bufio.NewScanner(resp.Body)
	// Split on single newlines so each SSE field line is scanned separately.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
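	// Read upstream lines in a goroutine, forwarding the payload of every
	// "data:" field and signalling completion once the scanner is done.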
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if len(data) < 5 { // ignore blank lines and malformed input
				continue
			}
			if data[:5] != "data:" {
				continue
			}
			data = data[5:]
			dataChan <- data
		}
		stopChan <- true
	}()
	common.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var aliResponse ChatResponse
			err := json.Unmarshal([]byte(data), &aliResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if aliResponse.Usage.OutputTokens != 0 {
				// Usage on stream chunks carries running totals, so overwrite
				// rather than accumulate.
				usage.PromptTokens = aliResponse.Usage.InputTokens
				usage.CompletionTokens = aliResponse.Usage.OutputTokens
				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
			}
			response := streamResponseAli2OpenAI(&aliResponse)
			if response == nil {
				return true
			}
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &usage
}
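
// Handler reads a non-streaming Ali chat response, converts it to the OpenAI
// format, and writes it to the client, returning the token usage.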
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	ctx := c.Request.Context()
	var aliResponse ChatResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	logger.Debugf(ctx, "response body: %s\n", responseBody)
	err = json.Unmarshal(responseBody, &aliResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	// A non-empty code means the upstream request failed; surface it as-is.
	if aliResponse.Code != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: aliResponse.Message,
				Type:    aliResponse.Code,
				Param:   aliResponse.RequestId,
				Code:    aliResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseAli2OpenAI(&aliResponse)
	fullTextResponse.Model = "qwen"
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		logger.SysError("error writing response: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}