From ea62d3e3be970ca8e19e46ce0e9ed802e74f4a1a Mon Sep 17 00:00:00 2001 From: xsl <1872744675@qq.com> Date: Thu, 25 Apr 2024 13:41:27 +0800 Subject: [PATCH] feat: vertex --- relay/adaptor.go | 3 + relay/adaptor/vertex/adaptor.go | 78 +++++++++ relay/adaptor/vertex/constants.go | 8 + relay/adaptor/vertex/main.go | 260 ++++++++++++++++++++++++++++++ relay/adaptor/vertex/model.go | 63 ++++++++ relay/adaptor/vertex/token.go | 164 +++++++++++++++++++ relay/apitype/define.go | 1 + relay/channeltype/define.go | 1 + relay/channeltype/helper.go | 2 + relay/channeltype/url.go | 1 + 10 files changed, 581 insertions(+) create mode 100644 relay/adaptor/vertex/adaptor.go create mode 100644 relay/adaptor/vertex/constants.go create mode 100644 relay/adaptor/vertex/main.go create mode 100644 relay/adaptor/vertex/model.go create mode 100644 relay/adaptor/vertex/token.go diff --git a/relay/adaptor.go b/relay/adaptor.go index 293b6d79..e321bed3 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -14,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/tencent" + "github.com/songquanpeng/one-api/relay/adaptor/vertex" "github.com/songquanpeng/one-api/relay/adaptor/xunfei" "github.com/songquanpeng/one-api/relay/adaptor/zhipu" "github.com/songquanpeng/one-api/relay/apitype" @@ -49,6 +50,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &coze.Adaptor{} case apitype.Cohere: return &cohere.Adaptor{} + case apitype.Vertex: + return &vertex.Adaptor{} } return nil } diff --git a/relay/adaptor/vertex/adaptor.go b/relay/adaptor/vertex/adaptor.go new file mode 100644 index 00000000..5a5e0bc4 --- /dev/null +++ b/relay/adaptor/vertex/adaptor.go @@ -0,0 +1,78 @@ +package vertex + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +// https://$LOCATION-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/$LOCATION/publishers/anthropic/models/$MODEL:streamRawPredict +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + _ = meta + // todo 需要修改为配置 + location := "" + projectId := "" + models := "" + + return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", + location, projectId, location, models), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + channelhelper.SetupCommonRequestHeader(c, req, meta) + token, err := getToken(c, meta) + if err != nil { + return err + } + // token可以设置到token表的key字段,SetupContextForSelectedChannel会设置该header + req.Header.Set("Authorization", "Bearer "+token) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + _, _ = c, relayMode + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return channelhelper.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "vertex" +} diff --git a/relay/adaptor/vertex/constants.go b/relay/adaptor/vertex/constants.go new file mode 100644 index 00000000..b03533c2 --- /dev/null +++ b/relay/adaptor/vertex/constants.go @@ -0,0 +1,8 @@ +package vertex + +// https://ai.google.dev/models/gemini + +var ModelList = []string{ + "claude-3-opus-20240229", + "claude-3-opus", +} diff --git a/relay/adaptor/vertex/main.go b/relay/adaptor/vertex/main.go new file mode 100644 index 00000000..3be3b8f9 --- /dev/null +++ b/relay/adaptor/vertex/main.go @@ -0,0 +1,260 @@ +package vertex + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/image" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + + messages := make([]Message, 0) + for _, message := range textRequest.Messages { + var content Content + if message.IsStringContent() { + content.Type = "text" + content.Text = message.StringContent() + messages = append(messages, Message{ + Role: message.Role, + Content: []Content{content}, + }) + continue + } + var contents []Content + openaiContent := message.ParseContent() + for _, part := range openaiContent { + var content Content + if part.Type == model.ContentTypeText { + content.Type = "text" + content.Text = part.Text + } else if part.Type == model.ContentTypeImageURL { + content.Type = "image" + content.Source = &Source{ + Type: "base64", + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + content.Source.MediaType = mimeType + content.Source.Data = data + } + contents = append(contents, content) + } + messages = append(messages, Message{ + Role: message.Role, + Content: contents, + }) + } + + return &Request{ + AnthropicVersion: "vertex-2023-10-16", + Messages: messages, + MaxTokens: textRequest.MaxTokens, + Stream: textRequest.Stream, + } +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var claudeResponse Response + err = json.Unmarshal(responseBody, &claudeResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if claudeResponse.Error.Type != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: claudeResponse.Error.Message, + Type: claudeResponse.Error.Type, + Param: "", + Code: claudeResponse.Error.Type, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseClaude2OpenAI(&claudeResponse) + fullTextResponse.Model = modelName + usage := model.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} + +func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), + Model: claudeResponse.Model, + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { + continue + } + if !strings.HasPrefix(data, "data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + var usage model.Usage + var modelName string + var id string + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var claudeResponse StreamResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, meta := StreamResponseClaude2OpenAI(&claudeResponse) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } + if response == nil { + return true + } + response.Id = id + response.Model = modelName + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + _ = resp.Body.Close() + return nil, &usage +} + +func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var stopReason string + switch claudeResponse.Type { + case "message_start": + return nil, claudeResponse.Message + case "content_block_start": + if claudeResponse.ContentBlock != nil { + responseText = claudeResponse.ContentBlock.Text + } + case "content_block_delta": + if claudeResponse.Delta != nil { + responseText = claudeResponse.Delta.Text + } + case "message_delta": + if claudeResponse.Usage != nil { + response = &Response{ + Usage: *claudeResponse.Usage, + } + } + if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { + stopReason = *claudeResponse.Delta.StopReason + } + } + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = responseText + choice.Delta.Role = "assistant" + finishReason := stopReasonClaude2OpenAI(&stopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response +} + +func stopReasonClaude2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "end_turn": + return "stop" + case "stop_sequence": + return "stop" + case "max_tokens": + return "length" + default: + return *reason + } +} diff --git a/relay/adaptor/vertex/model.go b/relay/adaptor/vertex/model.go new file mode 100644 index 00000000..4452cd4c --- /dev/null +++ b/relay/adaptor/vertex/model.go @@ -0,0 +1,63 @@ +package vertex + +type Request struct { + AnthropicVersion string `json:"anthropic_version"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` +} + +type Message struct { + Role string `json:"role"` + Content []Content `json:"content"` +} + +type Content struct { + Type string `json:"type"` + Source *Source `json:"source,omitempty"` + Text string `json:"text,omitempty"` +} + +type Source struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type Error struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type Response struct { + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []Content `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage Usage `json:"usage"` + Error Error `json:"error"` +} + +type Delta struct { + Type string `json:"type"` + Text string `json:"text"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` +} + +type StreamResponse struct { + Type string `json:"type"` + Message *Response `json:"message"` + Index int `json:"index"` + ContentBlock *Content `json:"content_block"` + Delta *Delta `json:"delta"` + Usage *Usage `json:"usage"` +} diff --git a/relay/adaptor/vertex/token.go b/relay/adaptor/vertex/token.go new file mode 100644 index 00000000..517e45a1 --- /dev/null +++ b/relay/adaptor/vertex/token.go @@ -0,0 +1,164 @@ +package vertex + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "github.com/golang-jwt/jwt" + "github.com/songquanpeng/one-api/relay/meta" + "io" + "net/http" + "time" +) + +type Credentials struct { + PrivateKey string + PrivateKeyID string + ClientEmail string +} + +// ServiceAccount holds the credentials and scopes required for token generation +type ServiceAccount struct { + Cred *Credentials + Scopes string +} + +var scopes = "https://www.googleapis.com/auth/cloud-platform" + +// createSignedJWT creates a Signed JWT from service account credentials +func (sa *ServiceAccount) createSignedJWT() (string, error) { + if sa.Cred == nil { + return "", fmt.Errorf("credentials are nil") + } + + issuedAt := time.Now() + expiresAt := issuedAt.Add(time.Hour) + + claims := &jwt.MapClaims{ + "iss": sa.Cred.ClientEmail, + "sub": sa.Cred.ClientEmail, + "aud": "https://www.googleapis.com/oauth2/v4/token", + "iat": issuedAt.Unix(), + "exp": expiresAt.Unix(), + "scope": scopes, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = sa.Cred.PrivateKeyID + token.Header["alg"] = "RS256" + token.Header["typ"] = "JWT" + + // 解析 PEM 编码的私钥 + block, _ := pem.Decode([]byte(sa.Cred.PrivateKey)) + if block == nil { + return "", errors.New("failed to decode PEM block containing private key") + } + + // 解析 RSA 私钥 + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return "", err + } + + rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return "", errors.New("private key is not of type RSA") + } + + signedToken, err := token.SignedString(rsaPrivateKey) + if err != nil { + return "", err + } + + return signedToken, nil +} + +// getToken uses the signed JWT to obtain an access token +func (sa *ServiceAccount) getToken(ctx context.Context) (string, error) { + signedJWT, err := sa.createSignedJWT() + if err != nil { + return "", err + } + + return exchangeJwtForAccessToken(ctx, signedJWT) +} + +// exchangeJwtForAccessToken exchanges a Signed JWT for a Google OAuth Access Token. +func exchangeJwtForAccessToken(ctx context.Context, signedJWT string) (string, error) { + authURL := "https://www.googleapis.com/oauth2/v4/token" + params := map[string]string{ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": signedJWT, + } + + jsonData, err := json.Marshal(params) + if err != nil { + return "", err + } + + // Create a new HTTP client with a timeout + client := &http.Client{ + Timeout: time.Second * 5, + } + + req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var data map[string]interface{} + err = json.Unmarshal(body, &data) + if err != nil { + return "", err + } + + // Extract the access token from the response + accessToken, ok := data["access_token"].(string) + if !ok { + return "", err // You might want to return a more specific error here + } + + return accessToken, nil +} + +func getToken(ctx context.Context, meta *meta.Meta) (string, error) { + // todo 每次请求都要换次token??? + encodedString := "" + decodedBytes, err := base64.StdEncoding.DecodeString(encodedString) + if err != nil { + return "", err + } + m := make(map[string]string) + err = json.Unmarshal(decodedBytes, &m) + if err != nil { + return "", err + } + + sa := &ServiceAccount{ + Cred: &Credentials{ + PrivateKey: m["private_key"], + PrivateKeyID: m["private_key_id"], + ClientEmail: m["client_email"], + }, + Scopes: scopes, + } + return sa.getToken(ctx) +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go index a1c8e6e1..f851b9e2 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -15,6 +15,7 @@ const ( AwsClaude Coze Cohere + Vertex Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index 4b37e566..bf7e24e4 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -37,6 +37,7 @@ const ( AwsClaude Coze Cohere + Vertex Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index 42b77891..bc150379 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -31,6 +31,8 @@ func ToAPIType(channelType int) int { apiType = apitype.Coze case Cohere: apiType = apitype.Cohere + case Vertex: + apiType = apitype.Vertex } return apiType diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 64fdcd0a..d5b9bef3 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -37,6 +37,7 @@ var ChannelBaseURLs = []string{ "", // 33 "https://api.coze.com", // 34 "https://api.cohere.ai", //35 + "", //36 } func init() {