添加对智谱V4 API的支持

This commit is contained in:
hongsheng 2024-01-25 04:21:22 +08:00
parent 4f214c48c6
commit 31b85ded54
12 changed files with 380 additions and 4 deletions

View File

@ -63,6 +63,7 @@ const (
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeZhipu_v4 = 25
)
var ChannelBaseURLs = []string{
@ -91,4 +92,5 @@ var ChannelBaseURLs = []string{
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://open.bigmodel.cn", // 25
}

View File

@ -92,6 +92,8 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572, // ¥0.005 / 1k tokens
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens

View File

@ -32,6 +32,8 @@ func testChannel(channel *model.Channel, request openai.ChatRequest) (err error,
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeZhipu_v4:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:

View File

@ -495,6 +495,24 @@ func init() {
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "glm-4",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu_v4",
Permission: permission,
Root: "glm-4",
Parent: nil,
},
{
Id: "glm-3-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu_v4",
Permission: permission,
Root: "glm-3-turbo",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",

View File

@ -1,9 +1,11 @@
package openai
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId any `json:"tool_call_id,omitempty"`
}
type ImageURL struct {
@ -109,6 +111,7 @@ type GeneralOpenAIRequest struct {
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
@ -267,9 +270,12 @@ type ImageResponse struct {
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
Content string `json:"content"`
Role string `json:"role,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
}
type ChatCompletionsStreamResponse struct {

View File

@ -0,0 +1,22 @@
package zhipu_v4
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/relay/channel/openai"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@ -0,0 +1,234 @@
package zhipu_v4
import (
"bufio"
"bytes"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/logger"
"one-api/relay/channel/openai"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
// https://open.bigmodel.cn/dev/api
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
func GetToken(apikey string) string {
data, ok := zhipuTokens.Load(apikey)
if ok {
tokenData := data.(tokenData)
if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token
}
}
split := strings.Split(apikey, ".")
if len(split) != 2 {
logger.SysError("invalid zhipu key: " + apikey)
return ""
}
id := split[0]
secret := split[1]
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
timestamp := time.Now().UnixNano() / 1e6
payload := jwt.MapClaims{
"api_key": id,
"exp": expMillis,
"timestamp": timestamp,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
return ""
}
zhipuTokens.Store(apikey, tokenData{
Token: tokenString,
ExpiryTime: expiryTime,
})
return tokenString
}
func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
var Stop []string
if ok {
Stop = []string{str}
} else {
Stop, _ = request.Stop.([]string)
}
return &Request{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
Stop: Stop,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
}
func StreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse) {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
choice.Index = zhipuResponse.Choices[0].Index
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
response := openai.ChatCompletionsStreamResponse{
Id: zhipuResponse.Id,
Object: "chat.completion.chunk",
Created: zhipuResponse.Created,
Model: "glm-4",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func LastStreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
response := StreamResponseZhipuV42OpenAI(zhipuResponse)
return response, &zhipuResponse.Usage
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage *openai.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var streamResponse StreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
}
var response *openai.ChatCompletionsStreamResponse
if strings.Contains(data, "prompt_tokens") {
response, usage = LastStreamResponseZhipuV42OpenAI(&streamResponse)
} else {
response = StreamResponseZhipuV42OpenAI(&streamResponse)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var textResponse Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &openai.ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &textResponse.Usage
}

View File

@ -0,0 +1,59 @@
package zhipu_v4
import (
"one-api/relay/channel/openai"
"time"
)
type Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId any `json:"tool_call_id,omitempty"`
}
type Request struct {
Model string `json:"model"`
Stream bool `json:"stream,omitempty"`
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stop []string `json:"stop,omitempty"`
RequestId string `json:"request_id,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}
type TextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type Response struct {
Id string `json:"id"`
Created int64 `json:"created"`
Model string `json:"model"`
TextResponseChoices []TextResponseChoice `json:"choices"`
openai.Usage `json:"usage"`
openai.Error `json:"error"`
}
type StreamResponseChoice struct {
Index int `json:"index,omitempty"`
Delta Message `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type StreamResponse struct {
Id string `json:"id"`
Created int64 `json:"created"`
Choices []StreamResponseChoice `json:"choices"`
openai.Usage `json:"usage"`
}
type tokenData struct {
Token string
ExpiryTime time.Time
}

View File

@ -15,6 +15,7 @@ const (
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeZhipu_v4
)
func ChannelType2APIType(channelType int) int {
@ -38,6 +39,8 @@ func ChannelType2APIType(channelType int) int {
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeZhipu_v4:
apiType = APITypeZhipu_v4
}
return apiType
}

View File

@ -19,6 +19,7 @@ import (
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_v4"
"one-api/relay/constant"
"one-api/relay/util"
"strings"
@ -79,6 +80,8 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.Rel
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case constant.APITypeZhipu_v4:
fullRequestURL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
case constant.APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == constant.RelayModeEmbeddings {
@ -147,6 +150,13 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
case constant.APITypeZhipu_v4:
zhipuRequest := zhipu_v4.ConvertRequest(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
case constant.APITypeAli:
var jsonStr []byte
var err error
@ -223,6 +233,9 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util
case constant.APITypeZhipu:
token := zhipu.GetToken(apiKey)
req.Header.Set("Authorization", token)
case constant.APITypeZhipu_v4:
token := zhipu_v4.GetToken(apiKey)
req.Header.Set("Authorization", token)
case constant.APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if isStream {
@ -286,6 +299,12 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
} else {
err, usage = zhipu.Handler(c, resp)
}
case constant.APITypeZhipu_v4:
if isStream {
err, usage = zhipu_v4.StreamHandler(c, resp)
} else {
err, usage = zhipu_v4.Handler(c, resp)
}
case constant.APITypeAli:
if isStream {
err, usage = ali.StreamHandler(c, resp)

View File

@ -139,6 +139,12 @@ const typeConfig = {
},
modelGroup: "google gemini",
},
25: {
input: {
models: ["glm-4", "glm-3-turbo"],
},
modelGroup: "zhipu_v4",
},
};
export { defaultConfig, typeConfig };

View File

@ -93,6 +93,9 @@ const EditChannel = () => {
case 24:
localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
case 24:
localModels = ['glm-4', 'glm-3-turbo'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}