feat: vertex

This commit is contained in:
xsl 2024-04-25 13:41:27 +08:00
parent da0842272c
commit ea62d3e3be
10 changed files with 581 additions and 0 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertex"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
"github.com/songquanpeng/one-api/relay/adaptor/zhipu"
"github.com/songquanpeng/one-api/relay/apitype"
@ -49,6 +50,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &coze.Adaptor{}
case apitype.Cohere:
return &cohere.Adaptor{}
case apitype.Vertex:
return &vertex.Adaptor{}
}
return nil
}

View File

@ -0,0 +1,78 @@
package vertex
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *meta.Meta) {
}
// https://$LOCATION-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/$LOCATION/publishers/anthropic/models/$MODEL:streamRawPredict
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
_ = meta
// todo 需要修改为配置
location := ""
projectId := ""
models := ""
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict",
location, projectId, location, models), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
token, err := getToken(c, meta)
if err != nil {
return err
}
// token可以设置到token表的key字段SetupContextForSelectedChannel会设置该header
req.Header.Set("Authorization", "Bearer "+token)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
_, _ = c, relayMode
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "vertex"
}

View File

@ -0,0 +1,8 @@
package vertex
// https://ai.google.dev/models/gemini
var ModelList = []string{
"claude-3-opus-20240229",
"claude-3-opus",
}

View File

@ -0,0 +1,260 @@
package vertex
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
messages := make([]Message, 0)
for _, message := range textRequest.Messages {
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
messages = append(messages, Message{
Role: message.Role,
Content: []Content{content},
})
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &Source{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
}
contents = append(contents, content)
}
messages = append(messages, Message{
Role: message.Role,
Content: contents,
})
}
return &Request{
AnthropicVersion: "vertex-2023-10-16",
Messages: messages,
MaxTokens: textRequest.MaxTokens,
Stream: textRequest.Stream,
}
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
}
func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func stopReasonClaude2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return *reason
}
}

View File

@ -0,0 +1,63 @@
package vertex
type Request struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Content struct {
Type string `json:"type"`
Source *Source `json:"source,omitempty"`
Text string `json:"text,omitempty"`
}
type Source struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []Content `json:"content"`
Model string `json:"model"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Usage Usage `json:"usage"`
Error Error `json:"error"`
}
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
type StreamResponse struct {
Type string `json:"type"`
Message *Response `json:"message"`
Index int `json:"index"`
ContentBlock *Content `json:"content_block"`
Delta *Delta `json:"delta"`
Usage *Usage `json:"usage"`
}

View File

@ -0,0 +1,164 @@
package vertex
import (
"bytes"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
"time"
)
type Credentials struct {
PrivateKey string
PrivateKeyID string
ClientEmail string
}
// ServiceAccount holds the credentials and scopes required for token generation
type ServiceAccount struct {
Cred *Credentials
Scopes string
}
var scopes = "https://www.googleapis.com/auth/cloud-platform"
// createSignedJWT creates a Signed JWT from service account credentials
func (sa *ServiceAccount) createSignedJWT() (string, error) {
if sa.Cred == nil {
return "", fmt.Errorf("credentials are nil")
}
issuedAt := time.Now()
expiresAt := issuedAt.Add(time.Hour)
claims := &jwt.MapClaims{
"iss": sa.Cred.ClientEmail,
"sub": sa.Cred.ClientEmail,
"aud": "https://www.googleapis.com/oauth2/v4/token",
"iat": issuedAt.Unix(),
"exp": expiresAt.Unix(),
"scope": scopes,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = sa.Cred.PrivateKeyID
token.Header["alg"] = "RS256"
token.Header["typ"] = "JWT"
// 解析 PEM 编码的私钥
block, _ := pem.Decode([]byte(sa.Cred.PrivateKey))
if block == nil {
return "", errors.New("failed to decode PEM block containing private key")
}
// 解析 RSA 私钥
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", errors.New("private key is not of type RSA")
}
signedToken, err := token.SignedString(rsaPrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
// getToken uses the signed JWT to obtain an access token
func (sa *ServiceAccount) getToken(ctx context.Context) (string, error) {
signedJWT, err := sa.createSignedJWT()
if err != nil {
return "", err
}
return exchangeJwtForAccessToken(ctx, signedJWT)
}
// exchangeJwtForAccessToken exchanges a Signed JWT for a Google OAuth Access Token.
func exchangeJwtForAccessToken(ctx context.Context, signedJWT string) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
params := map[string]string{
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": signedJWT,
}
jsonData, err := json.Marshal(params)
if err != nil {
return "", err
}
// Create a new HTTP client with a timeout
client := &http.Client{
Timeout: time.Second * 5,
}
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
var data map[string]interface{}
err = json.Unmarshal(body, &data)
if err != nil {
return "", err
}
// Extract the access token from the response
accessToken, ok := data["access_token"].(string)
if !ok {
return "", err // You might want to return a more specific error here
}
return accessToken, nil
}
func getToken(ctx context.Context, meta *meta.Meta) (string, error) {
// todo 每次请求都要换次token
encodedString := ""
decodedBytes, err := base64.StdEncoding.DecodeString(encodedString)
if err != nil {
return "", err
}
m := make(map[string]string)
err = json.Unmarshal(decodedBytes, &m)
if err != nil {
return "", err
}
sa := &ServiceAccount{
Cred: &Credentials{
PrivateKey: m["private_key"],
PrivateKeyID: m["private_key_id"],
ClientEmail: m["client_email"],
},
Scopes: scopes,
}
return sa.getToken(ctx)
}

View File

@ -15,6 +15,7 @@ const (
AwsClaude
Coze
Cohere
Vertex
Dummy // this one is only for count, do not add any channel after this
)

View File

@ -37,6 +37,7 @@ const (
AwsClaude
Coze
Cohere
Vertex
Dummy
)

View File

@ -31,6 +31,8 @@ func ToAPIType(channelType int) int {
apiType = apitype.Coze
case Cohere:
apiType = apitype.Cohere
case Vertex:
apiType = apitype.Vertex
}
return apiType

View File

@ -37,6 +37,7 @@ var ChannelBaseURLs = []string{
"", // 33
"https://api.coze.com", // 34
"https://api.cohere.ai", //35
"", //36
}
func init() {