package xunfei import ( "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" "net/http" "net/url" "one-api/common" "one-api/relay/channel/openai" "one-api/relay/constant" "strings" "time" ) // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { messages = append(messages, Message{ Role: "user", Content: message.StringContent(), }) messages = append(messages, Message{ Role: "assistant", Content: "Okay", }) } else { messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages return &xunfeiRequest } func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []ChatResponseTextItem{ { Content: "", }, } } choice := openai.TextResponseChoice{ Index: 0, Message: openai.Message{ Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, FinishReason: constant.StopFinishReason, } fullTextResponse := openai.TextResponse{ Object: "chat.completion", Created: common.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, Usage: response.Payload.Usage.Text, } return &fullTextResponse } func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { if len(xunfeiResponse.Payload.Choices.Text) == 0 { xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ { Content: "", }, } } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &constant.StopFinishReason } response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "SparkDesk", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { HmacWithShaToBase64 := func(algorithm, data, key string) string { mac := hmac.New(sha256.New, []byte(key)) mac.Write([]byte(data)) encodeData := mac.Sum(nil) return base64.StdEncoding.EncodeToString(encodeData) } ul, err := url.Parse(hostUrl) if err != nil { fmt.Println(err) } date := time.Now().UTC().Format(time.RFC1123) signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} sign := strings.Join(signString, "\n") sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha) authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) v := url.Values{} v.Add("host", ul.Host) v.Add("date", date) v.Add("authorization", authorization) callUrl := hostUrl + "?" + v.Encode() return callUrl } func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } common.SetEventStreamHeaders(c) var usage openai.Usage c.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } }) return nil, &usage } func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } var usage openai.Usage var content string var xunfeiResponse ChatResponse stop := false for !stop { select { case xunfeiResponse = <-dataChan: if len(xunfeiResponse.Payload.Choices.Text) == 0 { continue } content += xunfeiResponse.Payload.Choices.Text[0].Content usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens case stop = <-stopChan: } } if len(xunfeiResponse.Payload.Choices.Text) == 0 { xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ { Content: "", }, } } xunfeiResponse.Payload.Choices.Text[0].Content = content response := responseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") _, _ = c.Writer.Write(jsonResponse) return nil, &usage } func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } conn, resp, err := d.Dial(authUrl, nil) if err != nil || resp.StatusCode != 101 { return nil, nil, err } data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { return nil, nil, err } dataChan := make(chan ChatResponse) stopChan := make(chan bool) go func() { for { _, msg, err := conn.ReadMessage() if err != nil { common.SysError("error reading stream response: " + err.Error()) break } var response ChatResponse err = json.Unmarshal(msg, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { common.SysError("error closing websocket connection: " + err.Error()) } break } } stopChan <- true }() return dataChan, stopChan, nil } func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { apiVersion = c.GetString("api_version") } if apiVersion == "" { apiVersion = "v1.1" common.SysLog("api_version not found, use default: " + apiVersion) } domain := "general" if apiVersion != "v1.1" { domain += strings.Split(apiVersion, ".")[0] } authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) return domain, authUrl }