ai-gateway/relay/channel/tencent/main.go

235 lines
7.0 KiB
Go
Raw Normal View History

2024-01-14 11:21:03 +00:00
package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"io"
"net/http"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
2024-01-14 11:21:03 +00:00
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
continue
}
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
2024-01-14 11:21:03 +00:00
return &ChatRequest{
Timestamp: helper.GetTimestamp(),
Expired: helper.GetTimestamp() + 24*60*60,
QueryID: helper.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
2024-01-14 11:21:03 +00:00
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
2024-01-14 11:21:03 +00:00
choice := openai.TextResponseChoice{
Index: 0,
2024-01-14 11:21:03 +00:00
Message: openai.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
2024-01-14 11:21:03 +00:00
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
2024-01-14 11:21:03 +00:00
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
2024-01-14 11:21:03 +00:00
choice.FinishReason = &constant.StopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
2024-01-14 11:21:03 +00:00
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
2024-01-14 11:21:03 +00:00
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
2024-01-14 11:21:03 +00:00
var TencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
2024-01-14 11:21:03 +00:00
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var TencentResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
2024-01-14 11:21:03 +00:00
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse.Model = "hunyuan"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
2024-01-14 11:21:03 +00:00
func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
2024-01-14 11:21:03 +00:00
func GetSign(req ChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}