281 lines
8.4 KiB
Go
281 lines
8.4 KiB
Go
package controller
|
||
|
||
import (
|
||
"crypto/hmac"
|
||
"crypto/sha256"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/gorilla/websocket"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"one-api/common"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// Xunfei console: https://console.xfyun.cn/services/cbm
// Spark Web API documentation: https://www.xfyun.cn/doc/spark/Web.html
|
||
|
||
// Wire types exchanged with the Xunfei Spark chat API
// (https://www.xfyun.cn/doc/spark/Web.html). Field order and JSON tags
// define the encoded frame, so keep them as they are.
type (
	// XunfeiMessage is a single chat turn sent to Spark.
	XunfeiMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// XunfeiChatRequest is the request frame written to the Spark websocket.
	// Parameters tagged omitempty are dropped from the JSON when zero.
	XunfeiChatRequest struct {
		Header struct {
			AppId string `json:"app_id"`
		} `json:"header"`
		Parameter struct {
			Chat struct {
				Domain      string  `json:"domain,omitempty"`
				Temperature float64 `json:"temperature,omitempty"`
				TopK        int     `json:"top_k,omitempty"`
				MaxTokens   int     `json:"max_tokens,omitempty"`
				Auditing    bool    `json:"auditing,omitempty"`
			} `json:"chat"`
		} `json:"parameter"`
		Payload struct {
			Message struct {
				Text []XunfeiMessage `json:"text"`
			} `json:"message"`
		} `json:"payload"`
	}

	// XunfeiChatResponseTextItem is one text fragment inside a Spark
	// response frame's choices.
	XunfeiChatResponseTextItem struct {
		Content string `json:"content"`
		Role    string `json:"role"`
		Index   int    `json:"index"`
	}
)
|
||
|
||
type XunfeiChatResponse struct {
|
||
Header struct {
|
||
Code int `json:"code"`
|
||
Message string `json:"message"`
|
||
Sid string `json:"sid"`
|
||
Status int `json:"status"`
|
||
} `json:"header"`
|
||
Payload struct {
|
||
Choices struct {
|
||
Status int `json:"status"`
|
||
Seq int `json:"seq"`
|
||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||
} `json:"choices"`
|
||
Usage struct {
|
||
//Text struct {
|
||
// QuestionTokens string `json:"question_tokens"`
|
||
// PromptTokens string `json:"prompt_tokens"`
|
||
// CompletionTokens string `json:"completion_tokens"`
|
||
// TotalTokens string `json:"total_tokens"`
|
||
//} `json:"text"`
|
||
Text Usage `json:"text"`
|
||
} `json:"usage"`
|
||
} `json:"payload"`
|
||
}
|
||
|
||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
|
||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||
for _, message := range request.Messages {
|
||
if message.Role == "system" {
|
||
messages = append(messages, XunfeiMessage{
|
||
Role: "user",
|
||
Content: message.Content,
|
||
})
|
||
messages = append(messages, XunfeiMessage{
|
||
Role: "assistant",
|
||
Content: "Okay",
|
||
})
|
||
} else {
|
||
messages = append(messages, XunfeiMessage{
|
||
Role: message.Role,
|
||
Content: message.Content,
|
||
})
|
||
}
|
||
}
|
||
xunfeiRequest := XunfeiChatRequest{}
|
||
xunfeiRequest.Header.AppId = xunfeiAppId
|
||
xunfeiRequest.Parameter.Chat.Domain = "general"
|
||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||
xunfeiRequest.Payload.Message.Text = messages
|
||
return &xunfeiRequest
|
||
}
|
||
|
||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
||
if len(response.Payload.Choices.Text) == 0 {
|
||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||
{
|
||
Content: "",
|
||
},
|
||
}
|
||
}
|
||
choice := OpenAITextResponseChoice{
|
||
Index: 0,
|
||
Message: Message{
|
||
Role: "assistant",
|
||
Content: response.Payload.Choices.Text[0].Content,
|
||
},
|
||
}
|
||
fullTextResponse := OpenAITextResponse{
|
||
Object: "chat.completion",
|
||
Created: common.GetTimestamp(),
|
||
Choices: []OpenAITextResponseChoice{choice},
|
||
Usage: response.Payload.Usage.Text,
|
||
}
|
||
return &fullTextResponse
|
||
}
|
||
|
||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
|
||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||
{
|
||
Content: "",
|
||
},
|
||
}
|
||
}
|
||
var choice ChatCompletionsStreamResponseChoice
|
||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||
choice.FinishReason = &stopFinishReason
|
||
}
|
||
response := ChatCompletionsStreamResponse{
|
||
Object: "chat.completion.chunk",
|
||
Created: common.GetTimestamp(),
|
||
Model: "SparkDesk",
|
||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||
}
|
||
return &response
|
||
}
|
||
|
||
// buildXunfeiAuthUrl returns hostUrl with the signed query parameters Xunfei
// expects on its websocket endpoint: host, date and authorization.
//
// The signature is an HMAC-SHA256, keyed with apiSecret, over the lines
// "host: <host>", "date: <RFC1123 UTC date>" and "GET <path> HTTP/1.1"
// joined by "\n". It is base64-encoded and embedded, together with apiKey as
// the username, in an authorization string that is itself base64-encoded.
//
// If hostUrl cannot be parsed, the error is printed and "" is returned, so the
// caller's dial fails instead of the process panicking on a nil URL.
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
	ul, err := url.Parse(hostUrl)
	if err != nil {
		fmt.Println(err)
		return ""
	}

	date := time.Now().UTC().Format(time.RFC1123)
	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
	sign := strings.Join(signString, "\n")

	mac := hmac.New(sha256.New, []byte(apiSecret))
	mac.Write([]byte(sign)) // hash.Hash.Write never returns an error
	sha := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
		"hmac-sha256", "host date request-line", sha)
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))

	v := url.Values{}
	v.Add("host", ul.Host)
	v.Add("date", date)
	v.Add("authorization", authorization)
	return hostUrl + "?" + v.Encode()
}
|
||
|
||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string, version string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||
var usage Usage
|
||
d := websocket.Dialer{
|
||
HandshakeTimeout: 5 * time.Second,
|
||
}
|
||
hostUrl := "wss://aichat.xf-yun.com/v1/chat"
|
||
if version != "" { //换成新版的,支持v2
|
||
hostUrl = "wss://spark-api.xf-yun.com/" + version + "/chat"
|
||
}
|
||
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
||
if err != nil || resp.StatusCode != 101 {
|
||
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
||
}
|
||
data := requestOpenAI2Xunfei(textRequest, appId)
|
||
err = conn.WriteJSON(data)
|
||
if err != nil {
|
||
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
||
}
|
||
dataChan := make(chan XunfeiChatResponse)
|
||
stopChan := make(chan bool)
|
||
go func() {
|
||
for {
|
||
_, msg, err := conn.ReadMessage()
|
||
if err != nil {
|
||
common.SysError("error reading stream response: " + err.Error())
|
||
break
|
||
}
|
||
var response XunfeiChatResponse
|
||
err = json.Unmarshal(msg, &response)
|
||
if err != nil {
|
||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||
break
|
||
}
|
||
dataChan <- response
|
||
if response.Payload.Choices.Status == 2 {
|
||
err := conn.Close()
|
||
if err != nil {
|
||
common.SysError("error closing websocket connection: " + err.Error())
|
||
}
|
||
break
|
||
}
|
||
}
|
||
stopChan <- true
|
||
}()
|
||
setEventStreamHeaders(c)
|
||
c.Stream(func(w io.Writer) bool {
|
||
select {
|
||
case xunfeiResponse := <-dataChan:
|
||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||
jsonResponse, err := json.Marshal(response)
|
||
if err != nil {
|
||
common.SysError("error marshalling stream response: " + err.Error())
|
||
return true
|
||
}
|
||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||
return true
|
||
case <-stopChan:
|
||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||
return false
|
||
}
|
||
})
|
||
return nil, &usage
|
||
}
|
||
|
||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||
var xunfeiResponse XunfeiChatResponse
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
err = resp.Body.Close()
|
||
if err != nil {
|
||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
err = json.Unmarshal(responseBody, &xunfeiResponse)
|
||
if err != nil {
|
||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
if xunfeiResponse.Header.Code != 0 {
|
||
return &OpenAIErrorWithStatusCode{
|
||
OpenAIError: OpenAIError{
|
||
Message: xunfeiResponse.Header.Message,
|
||
Type: "xunfei_error",
|
||
Param: "",
|
||
Code: xunfeiResponse.Header.Code,
|
||
},
|
||
StatusCode: resp.StatusCode,
|
||
}, nil
|
||
}
|
||
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||
if err != nil {
|
||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
c.Writer.Header().Set("Content-Type", "application/json")
|
||
c.Writer.WriteHeader(resp.StatusCode)
|
||
_, err = c.Writer.Write(jsonResponse)
|
||
return nil, &fullTextResponse.Usage
|
||
}
|