From 31b85ded542db4072032360a181d5347e104fb64 Mon Sep 17 00:00:00 2001 From: hongsheng Date: Thu, 25 Jan 2024 04:21:22 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9=E6=99=BA?= =?UTF-8?q?=E8=B0=B1V4=20API=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 2 + common/model-ratio.go | 2 + controller/channel-test.go | 2 + controller/model.go | 18 ++ relay/channel/openai/model.go | 14 +- relay/channel/zhipu_v4/adaptor.go | 22 ++ relay/channel/zhipu_v4/main.go | 234 +++++++++++++++++++ relay/channel/zhipu_v4/model.go | 59 +++++ relay/constant/api_type.go | 3 + relay/controller/util.go | 19 ++ web/berry/src/views/Channel/type/Config.js | 6 + web/default/src/pages/Channel/EditChannel.js | 3 + 12 files changed, 380 insertions(+), 4 deletions(-) create mode 100644 relay/channel/zhipu_v4/adaptor.go create mode 100644 relay/channel/zhipu_v4/main.go create mode 100644 relay/channel/zhipu_v4/model.go diff --git a/common/constants.go b/common/constants.go index 325454d4..d96cf6d7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -63,6 +63,7 @@ const ( ChannelTypeFastGPT = 22 ChannelTypeTencent = 23 ChannelTypeGemini = 24 + ChannelTypeZhipu_v4 = 25 ) var ChannelBaseURLs = []string{ @@ -91,4 +92,5 @@ var ChannelBaseURLs = []string{ "https://fastgpt.run/api/openapi", // 22 "https://hunyuan.cloud.tencent.com", // 23 "https://generativelanguage.googleapis.com", // 24 + "https://open.bigmodel.cn", // 25 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 9f31e0d7..bb275801 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -92,6 +92,8 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "glm-4": 7.143, // ¥0.1 / 1k tokens + "glm-3-turbo": 0.3572, // ¥0.005 / 1k tokens "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens 
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing "qwen-plus": 1.4286, // ¥0.02 / 1k tokens "qwen-max": 1.4286, // ¥0.02 / 1k tokens diff --git a/controller/channel-test.go b/controller/channel-test.go index 88d6e3f2..82116de8 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -32,6 +32,8 @@ func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, fallthrough case common.ChannelTypeZhipu: fallthrough + case common.ChannelTypeZhipu_v4: + fallthrough case common.ChannelTypeAli: fallthrough case common.ChannelType360: diff --git a/controller/model.go b/controller/model.go index b7ec1b6a..87eb1a4e 100644 --- a/controller/model.go +++ b/controller/model.go @@ -495,6 +495,24 @@ func init() { Root: "chatglm_lite", Parent: nil, }, + { + Id: "glm-4", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu_v4", + Permission: permission, + Root: "glm-4", + Parent: nil, + }, + { + Id: "glm-3-turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu_v4", + Permission: permission, + Root: "glm-3-turbo", + Parent: nil, + }, { Id: "qwen-turbo", Object: "model", diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index 937fb424..3170c940 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -1,9 +1,11 @@ package openai type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role"` + Content any `json:"content"` + Name *string `json:"name,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` + ToolCallId any `json:"tool_call_id,omitempty"` } type ImageURL struct { @@ -109,6 +111,7 @@ type GeneralOpenAIRequest struct { MaxTokens int `json:"max_tokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` + Stop any `json:"stop,omitempty"` N int `json:"n,omitempty"` Input any 
`json:"input,omitempty"` Instruction string `json:"instruction,omitempty"` @@ -267,9 +270,12 @@ type ImageResponse struct { type ChatCompletionsStreamResponseChoice struct { Delta struct { - Content string `json:"content"` + Content string `json:"content"` + Role string `json:"role,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` } `json:"delta"` FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/relay/channel/zhipu_v4/adaptor.go b/relay/channel/zhipu_v4/adaptor.go new file mode 100644 index 00000000..a80b57ce --- /dev/null +++ b/relay/channel/zhipu_v4/adaptor.go @@ -0,0 +1,22 @@ +package zhipu_v4 + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/zhipu_v4/main.go b/relay/channel/zhipu_v4/main.go new file mode 100644 index 00000000..46e766b8 --- /dev/null +++ b/relay/channel/zhipu_v4/main.go @@ -0,0 +1,234 @@ +package zhipu_v4 + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/common/logger" + "one-api/relay/channel/openai" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" +) + +// https://open.bigmodel.cn/dev/api + +var zhipuTokens sync.Map +var expSeconds int64 = 24 * 3600 + +func GetToken(apikey string) string { + data, ok := zhipuTokens.Load(apikey) + if ok { + tokenData := data.(tokenData) + if time.Now().Before(tokenData.ExpiryTime) { + return tokenData.Token + } + } + + split := strings.Split(apikey, ".") + if len(split) != 2 { + 
logger.SysError("invalid zhipu key: " + apikey) + return "" + } + + id := split[0] + secret := split[1] + + expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 + expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) + + timestamp := time.Now().UnixNano() / 1e6 + + payload := jwt.MapClaims{ + "api_key": id, + "exp": expMillis, + "timestamp": timestamp, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + + token.Header["alg"] = "HS256" + token.Header["sign_type"] = "SIGN" + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return "" + } + + zhipuTokens.Store(apikey, tokenData{ + Token: tokenString, + ExpiryTime: expiryTime, + }) + + return tokenString +} + +func ConvertRequest(request openai.GeneralOpenAIRequest) *Request { + messages := make([]Message, 0, len(request.Messages)) + for _, message := range request.Messages { + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + ToolCalls: message.ToolCalls, + ToolCallId: message.ToolCallId, + }) + } + str, ok := request.Stop.(string) + var Stop []string + if ok { + Stop = []string{str} + } else { + Stop, _ = request.Stop.([]string) + } + return &Request{ + Model: request.Model, + Stream: request.Stream, + Messages: messages, + Temperature: request.Temperature, + TopP: request.TopP, + MaxTokens: request.MaxTokens, + Stop: Stop, + Tools: request.Tools, + ToolChoice: request.ToolChoice, + } +} + +func StreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse) { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content + choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role + choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls + choice.Index = zhipuResponse.Choices[0].Index + choice.FinishReason = zhipuResponse.Choices[0].FinishReason + response := 
openai.ChatCompletionsStreamResponse{ + Id: zhipuResponse.Id, + Object: "chat.completion.chunk", + Created: zhipuResponse.Created, + Model: "glm-4", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func LastStreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { + response := StreamResponseZhipuV42OpenAI(zhipuResponse) + return response, &zhipuResponse.Usage +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage *openai.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + if data[:6] != "data: " && data[:6] != "[DONE]" { + continue + } + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + if strings.HasPrefix(data, "data: [DONE]") { + return true + } + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + + var streamResponse StreamResponse + err := json.Unmarshal([]byte(strings.TrimPrefix(data, "data: ")), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + } + var response *openai.ChatCompletionsStreamResponse + if strings.Contains(data, "prompt_tokens") { + response, usage = LastStreamResponseZhipuV42OpenAI(&streamResponse) + } else { + response = StreamResponseZhipuV42OpenAI(&streamResponse) + } + jsonResponse, err := 
json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, usage +} + +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var textResponse Response + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &openai.ErrorWithStatusCode{ + Error: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the HTTPClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. 
+ for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + return nil, &textResponse.Usage +} diff --git a/relay/channel/zhipu_v4/model.go b/relay/channel/zhipu_v4/model.go new file mode 100644 index 00000000..5c2b267e --- /dev/null +++ b/relay/channel/zhipu_v4/model.go @@ -0,0 +1,59 @@ +package zhipu_v4 + +import ( + "one-api/relay/channel/openai" + "time" +) + +type Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` + ToolCallId any `json:"tool_call_id,omitempty"` +} + +type Request struct { + Model string `json:"model"` + Stream bool `json:"stream,omitempty"` + Messages []Message `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stop []string `json:"stop,omitempty"` + RequestId string `json:"request_id,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} + +type TextResponseChoice struct { + Index int `json:"index"` + Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type Response struct { + Id string `json:"id"` + Created int64 `json:"created"` + Model string `json:"model"` + TextResponseChoices []TextResponseChoice `json:"choices"` + openai.Usage `json:"usage"` + openai.Error `json:"error"` +} + +type StreamResponseChoice struct { + Index int `json:"index,omitempty"` + Delta Message `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +type StreamResponse struct { + Id string `json:"id"` + Created int64 
`json:"created"` + Choices []StreamResponseChoice `json:"choices"` + openai.Usage `json:"usage"` +} + +type tokenData struct { + Token string + ExpiryTime time.Time +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 658bfb90..5748bd2d 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -15,6 +15,7 @@ const ( APITypeAIProxyLibrary APITypeTencent APITypeGemini + APITypeZhipu_v4 ) func ChannelType2APIType(channelType int) int { @@ -38,6 +39,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeTencent case common.ChannelTypeGemini: apiType = APITypeGemini + case common.ChannelTypeZhipu_v4: + apiType = APITypeZhipu_v4 } return apiType } diff --git a/relay/controller/util.go b/relay/controller/util.go index 02f1b30f..618c2f87 100644 --- a/relay/controller/util.go +++ b/relay/controller/util.go @@ -19,6 +19,7 @@ import ( "one-api/relay/channel/tencent" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" + "one-api/relay/channel/zhipu_v4" "one-api/relay/constant" "one-api/relay/util" "strings" @@ -79,6 +80,8 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.Rel method = "sse-invoke" } fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) + case constant.APITypeZhipu_v4: + fullRequestURL = "https://open.bigmodel.cn/api/paas/v4/chat/completions" case constant.APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" if relayMode == constant.RelayModeEmbeddings { @@ -147,6 +150,13 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM return nil, err } requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeZhipu_v4: + zhipuRequest := zhipu_v4.ConvertRequest(textRequest) + jsonStr, err := json.Marshal(zhipuRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) case constant.APITypeAli: var 
jsonStr []byte var err error @@ -223,6 +233,9 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util case constant.APITypeZhipu: token := zhipu.GetToken(apiKey) req.Header.Set("Authorization", token) + case constant.APITypeZhipu_v4: + token := zhipu_v4.GetToken(apiKey) + req.Header.Set("Authorization", token) case constant.APITypeAli: req.Header.Set("Authorization", "Bearer "+apiKey) if isStream { @@ -286,6 +299,12 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp * } else { err, usage = zhipu.Handler(c, resp) } + case constant.APITypeZhipu_v4: + if isStream { + err, usage = zhipu_v4.StreamHandler(c, resp) + } else { + err, usage = zhipu_v4.Handler(c, resp) + } case constant.APITypeAli: if isStream { err, usage = ali.StreamHandler(c, resp) diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index d270f527..443e96e9 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -139,6 +139,12 @@ const typeConfig = { }, modelGroup: "google gemini", }, + 25: { + input: { + models: ["glm-4", "glm-3-turbo"], + }, + modelGroup: "zhipu_v4", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 0d4e114d..df2dc161 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -93,6 +93,9 @@ const EditChannel = () => { case 24: localModels = ['gemini-pro', 'gemini-pro-vision']; break; + case 25: + localModels = ['glm-4', 'glm-3-turbo']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } From d72ebbda0ba014a3d633ff418a99b10c859f6d9d Mon Sep 17 00:00:00 2001 From: hongsheng Date: Thu, 25 Jan 2024 04:54:14 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=99=BA=E8=B0=B1V4=20stream=E5=9B=9E?=
=?UTF-8?q?=E5=A4=8D=E5=B8=A6=E4=B8=8A=E5=87=86=E7=A1=AE=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/zhipu_v4/main.go | 14 +++++++------- relay/controller/util.go | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/relay/channel/zhipu_v4/main.go b/relay/channel/zhipu_v4/main.go index 46e766b8..baa7f77b 100644 --- a/relay/channel/zhipu_v4/main.go +++ b/relay/channel/zhipu_v4/main.go @@ -99,7 +99,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *Request { } } -func StreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse) { +func StreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse, reqModel string) *openai.ChatCompletionsStreamResponse { var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role @@ -110,18 +110,18 @@ func StreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCo Id: zhipuResponse.Id, Object: "chat.completion.chunk", Created: zhipuResponse.Created, - Model: "glm-4", + Model: reqModel, Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func LastStreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { - response := StreamResponseZhipuV42OpenAI(zhipuResponse) +func LastStreamResponseZhipuV42OpenAI(zhipuResponse *StreamResponse, reqModel string) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { + response := StreamResponseZhipuV42OpenAI(zhipuResponse, reqModel) return response, &zhipuResponse.Usage } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func StreamHandler(c *gin.Context, resp *http.Response, reqModel string) (*openai.ErrorWithStatusCode, *openai.Usage) { var usage *openai.Usage scanner := 
bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -168,9 +168,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus } var response *openai.ChatCompletionsStreamResponse if strings.Contains(data, "prompt_tokens") { - response, usage = LastStreamResponseZhipuV42OpenAI(&streamResponse) + response, usage = LastStreamResponseZhipuV42OpenAI(&streamResponse, reqModel) } else { - response = StreamResponseZhipuV42OpenAI(&streamResponse) + response = StreamResponseZhipuV42OpenAI(&streamResponse, reqModel) } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/controller/util.go b/relay/controller/util.go index 618c2f87..f8789a30 100644 --- a/relay/controller/util.go +++ b/relay/controller/util.go @@ -301,7 +301,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp * } case constant.APITypeZhipu_v4: if isStream { - err, usage = zhipu_v4.StreamHandler(c, resp) + err, usage = zhipu_v4.StreamHandler(c, resp, textRequest.Model) } else { err, usage = zhipu_v4.Handler(c, resp) }