ai-gateway/relay/channel/zhipu/main.go

268 lines
7.5 KiB
Go
Raw Normal View History

2024-01-14 11:21:03 +00:00
package zhipu
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"sync"
"time"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
// chatglm_std, chatglm_lite
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
var (
	// zhipuTokens caches generated tokens, keyed by the API key they were
	// signed for.
	zhipuTokens sync.Map

	// expSeconds is the lifetime of a generated token, in seconds (24 hours).
	expSeconds int64 = 24 * 3600
)
2024-01-14 11:21:03 +00:00
func GetToken(apikey string) string {
data, ok := zhipuTokens.Load(apikey)
if ok {
2024-01-14 11:21:03 +00:00
tokenData := data.(tokenData)
if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token
}
}
split := strings.Split(apikey, ".")
if len(split) != 2 {
logger.SysError("invalid zhipu key: " + apikey)
return ""
}
id := split[0]
secret := split[1]
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
timestamp := time.Now().UnixNano() / 1e6
payload := jwt.MapClaims{
"api_key": id,
"exp": expMillis,
"timestamp": timestamp,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
return ""
}
2024-01-14 11:21:03 +00:00
zhipuTokens.Store(apikey, tokenData{
Token: tokenString,
ExpiryTime: expiryTime,
})
return tokenString
}
func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
2024-01-14 11:21:03 +00:00
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Role: "system",
Content: message.StringContent(),
})
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Role: "user",
Content: "Okay",
})
} else {
2024-01-14 11:21:03 +00:00
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
2024-01-14 11:21:03 +00:00
return &Request{
Prompt: messages,
Temperature: request.Temperature,
TopP: request.TopP,
Incremental: false,
}
}
2024-01-14 11:21:03 +00:00
func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: response.Data.TaskId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
2024-01-14 11:21:03 +00:00
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
2024-01-14 11:21:03 +00:00
openaiChoice := openai.TextResponseChoice{
Index: i,
Message: model.Message{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
FinishReason: "",
}
if i == len(response.Data.Choices)-1 {
openaiChoice.FinishReason = "stop"
}
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
}
return &fullTextResponse
}
2024-01-14 11:21:03 +00:00
func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
2024-01-14 11:21:03 +00:00
response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "chatglm",
2024-01-14 11:21:03 +00:00
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
2024-01-14 11:21:03 +00:00
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
2024-01-14 11:21:03 +00:00
choice.FinishReason = &constant.StopFinishReason
response := openai.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "chatglm",
2024-01-14 11:21:03 +00:00
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response, &zhipuResponse.Usage
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage *model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
lines := strings.Split(data, "\n")
for i, line := range lines {
if len(line) < 5 {
continue
}
if line[:5] == "data:" {
dataChan <- line[5:]
if i != len(lines)-1 {
dataChan <- "\n"
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
}
}
}
stopChan <- true
}()
2024-01-14 11:21:03 +00:00
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case data := <-metaChan:
2024-01-14 11:21:03 +00:00
var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
2024-01-14 11:21:03 +00:00
var zhipuResponse Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
Code: zhipuResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
fullTextResponse.Model = "chatglm"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
2024-01-14 11:21:03 +00:00
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}