From 79524108a3b5a06a0ff80ca8fbc3e0d748f58ed9 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Wed, 29 May 2024 00:36:54 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=96=20chore:=20Rename=20relay/util=20t?= =?UTF-8?q?o=20relay/relay=5Futil=20package=20and=20add=20utils=20package?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli/export.go | 4 +- common/config/config.go | 5 +- common/image/image.go | 9 +++ common/logger.go | 6 +- common/requester/http_client.go | 71 ++----------------- common/requester/http_requester.go | 14 +--- common/requester/ws_client.go | 21 +++--- common/stmp/email.go | 3 +- common/storage/storage_test.go | 7 +- common/telegram/command_aff.go | 3 +- common/telegram/common.go | 16 ++--- common/token.go | 2 +- common/{utils.go => utils/helper.go} | 16 ++--- common/utils/proxy.go | 77 +++++++++++++++++++++ controller/channel-test.go | 9 +-- controller/channel.go | 3 +- controller/github.go | 3 +- controller/option.go | 3 +- controller/pricing.go | 20 +++--- controller/redemption.go | 5 +- controller/token.go | 15 ++-- controller/user.go | 5 +- main.go | 6 +- middleware/auth.go | 5 +- middleware/rate-limit.go | 5 +- middleware/request-id.go | 3 +- middleware/utils.go | 6 +- model/balancer.go | 5 +- model/channel.go | 5 +- model/log.go | 5 +- model/main.go | 13 ++-- model/redemption.go | 5 +- model/telegram_menu.go | 4 +- model/token.go | 7 +- model/user.go | 11 +-- providers/ali/chat.go | 5 +- providers/claude/chat.go | 7 +- providers/cloudflareAI/chat.go | 9 +-- providers/cloudflareAI/image_generations.go | 3 +- providers/cohere/chat.go | 7 +- providers/coze/chat.go | 9 +-- providers/gemini/chat.go | 9 +-- providers/gemini/type.go | 3 +- providers/ollama/chat.go | 9 +-- providers/palm/chat.go | 5 +- providers/stabilityAI/image_generations.go | 3 +- providers/tencent/chat.go | 11 +-- providers/xunfei/chat.go | 5 +- relay/base.go | 10 +-- relay/chat.go | 5 +- relay/common.go | 7 +- relay/completions.go | 5 +- relay/main.go | 8 +-- relay/midjourney/relay-mj.go | 6 +- relay/model.go | 14 ++-- relay/{util => relay_util}/cache.go | 9 +-- relay/{util => relay_util}/cache_db.go | 8 +-- relay/{util => relay_util}/cache_redis.go | 7 +- relay/{util => relay_util}/pricing.go | 9 +-- relay/{util => relay_util}/quota.go | 2 +- relay/{util => relay_util}/type.go | 2 +- 61 files changed, 309 insertions(+), 265 deletions(-) rename common/{utils.go => utils/helper.go} (94%) create mode 100644 common/utils/proxy.go rename relay/{util => relay_util}/cache.go (93%) rename relay/{util => relay_util}/cache_db.go (84%) rename relay/{util => relay_util}/cache_redis.go (84%) rename relay/{util => relay_util}/pricing.go (97%) rename relay/{util => relay_util}/quota.go (99%) rename relay/{util => relay_util}/type.go (98%) diff --git a/cli/export.go b/cli/export.go index cda9788b..3cb2b23c 100644 --- a/cli/export.go +++ b/cli/export.go @@ -3,13 +3,13 @@ package cli import ( "encoding/json" "one-api/common" - "one-api/relay/util" + "one-api/relay/relay_util" "os" "sort" ) func ExportPrices() { - prices := util.GetPricesList("default") + prices := relay_util.GetPricesList("default") if len(prices) == 0 { common.SysError("No prices found") diff --git a/common/config/config.go b/common/config/config.go index 0535d1f9..cbe04c08 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -6,6 +6,7 @@ import ( "one-api/cli" "one-api/common" + "one-api/common/utils" "github.com/spf13/viper" ) @@ -22,11 +23,11 @@ func InitConf() { common.IsMasterNode = viper.GetString("node_type") != "slave" common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second - common.SessionSecret = common.GetOrDefault("session_secret", common.SessionSecret) + common.SessionSecret = utils.GetOrDefault("session_secret", common.SessionSecret) } func setConfigFile() { - if !common.IsFileExist(*cli.Config) { + if !utils.IsFileExist(*cli.Config) { return } diff --git a/common/image/image.go b/common/image/image.go index 92176c2f..df5e578b 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -16,6 +16,15 @@ import ( _ "golang.org/x/image/webp" ) +// var ImageHttpClients = &http.Client{ +// Transport: &http.Transport{ +// DialContext: requester.Socks5ProxyFunc, +// Proxy: requester.ProxyFunc, +// }, +// // +// // // Timeout: 30 * time.Second, +// } + func IsImageUrl(url string) (bool, error) { resp, err := http.Head(url) if err != nil { diff --git a/common/logger.go b/common/logger.go index 73a3539c..ba2c452d 100644 --- a/common/logger.go +++ b/common/logger.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "one-api/common/utils" + "github.com/gin-gonic/gin" "github.com/spf13/viper" ) @@ -62,13 +64,13 @@ func getLogDir() string { } var err error - logDir, err = filepath.Abs(viper.GetString("log_dir")) + logDir, err = filepath.Abs(logDir) if err != nil { log.Fatal(err) return "" } - if !IsFileExist(logDir) { + if !utils.IsFileExist(logDir) { err = os.Mkdir(logDir, 0777) if err != nil { log.Fatal(err) diff --git a/common/requester/http_client.go b/common/requester/http_client.go index 17f76ce0..4a351595 100644 --- a/common/requester/http_client.go +++ b/common/requester/http_client.go @@ -1,87 +1,24 @@ package requester import ( - "context" - "fmt" - "net" "net/http" - "net/url" - "one-api/common" + "one-api/common/utils" "time" - - "golang.org/x/net/proxy" ) -type ContextKey string - -const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr" -const ProxySock5AddrKey ContextKey = "proxySock5Addr" - -func proxyFunc(req *http.Request) (*url.URL, error) { - proxyAddr := req.Context().Value(ProxyHTTPAddrKey) - if proxyAddr == nil { - return nil, nil - } - - proxyURL, err := url.Parse(proxyAddr.(string)) - if err != nil { - return nil, fmt.Errorf("error parsing proxy address: %w", err) - } - - switch proxyURL.Scheme { - case "http", "https": - return proxyURL, nil - } - - return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) -} - -func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) { - // 设置TCP超时 - dialer := &net.Dialer{ - Timeout: time.Duration(common.GetOrDefault("connect_timeout", 5)) * time.Second, - KeepAlive: 30 * time.Second, - } - - // 从上下文中获取代理地址 - proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string) - if !ok { - return dialer.DialContext(ctx, network, addr) - } - - proxyURL, err := url.Parse(proxyAddr) - if err != nil { - return nil, fmt.Errorf("error parsing proxy address: %w", err) - } - var auth *proxy.Auth = nil - password, isSetPassword := proxyURL.User.Password() - if isSetPassword { - auth = &proxy.Auth{ - User: proxyURL.User.Username(), - Password: password, - } - } - proxyDialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("error creating socks5 dialer: %w", err) - } - - return proxyDialer.Dial(network, addr) -} - var HTTPClient *http.Client func InitHttpClient() { trans := &http.Transport{ - DialContext: socks5ProxyFunc, - Proxy: proxyFunc, + DialContext: utils.Socks5ProxyFunc, + Proxy: utils.ProxyFunc, } HTTPClient = &http.Client{ Transport: trans, } - relayTimeout := common.GetOrDefault("relay_timeout", 600) + relayTimeout := utils.GetOrDefault("relay_timeout", 600) if relayTimeout != 0 { HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second } diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go index f36601fe..328a778e 100644 --- a/common/requester/http_requester.go +++ b/common/requester/http_requester.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/utils" "one-api/types" "strconv" "strings" @@ -52,18 +53,7 @@ type requestOptions struct { type requestOption func(*requestOptions) func (r *HTTPRequester) setProxy() context.Context { - if r.proxyAddr == "" { - return r.Context - } - - // 如果是以 socks5:// 开头的地址,那么使用 socks5 代理 - if strings.HasPrefix(r.proxyAddr, "socks5://") { - return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr) - } - - // 否则使用 http 代理 - return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr) - + return utils.SetProxy(r.Context, r.proxyAddr) } // 创建请求 diff --git a/common/requester/ws_client.go b/common/requester/ws_client.go index 30000072..7b6cf937 100644 --- a/common/requester/ws_client.go +++ b/common/requester/ws_client.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/utils" "time" "github.com/gorilla/websocket" @@ -14,7 +15,7 @@ import ( func GetWSClient(proxyAddr string) *websocket.Dialer { dialer := &websocket.Dialer{ - HandshakeTimeout: time.Duration(common.GetOrDefault("connect_timeout", 5)) * time.Second, + HandshakeTimeout: time.Duration(utils.GetOrDefault("connect_timeout", 5)) * time.Second, } if proxyAddr != "" { @@ -38,20 +39,16 @@ func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error { case "http", "https": dialer.Proxy = http.ProxyURL(proxyURL) case "socks5": - var auth *proxy.Auth = nil - password, isSetPassword := proxyURL.User.Password() - if isSetPassword { - auth = &proxy.Auth{ - User: proxyURL.User.Username(), - Password: password, - } - } - socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) + proxyDialer, err := proxy.FromURL(proxyURL, proxy.Direct) if err != nil { - return fmt.Errorf("error creating socks5 dialer: %w", err) + return fmt.Errorf("error creating proxy dialer: %w", err) } + originalNetDial := dialer.NetDial dialer.NetDial = func(network, addr string) (net.Conn, error) { - return socks5Proxy.Dial(network, addr) + if originalNetDial != nil { + return originalNetDial(network, addr) + } + return proxyDialer.Dial(network, addr) } default: return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) diff --git a/common/stmp/email.go b/common/stmp/email.go index fd1014ee..1f8b6a03 100644 --- a/common/stmp/email.go +++ b/common/stmp/email.go @@ -3,6 +3,7 @@ package stmp import ( "fmt" "one-api/common" + "one-api/common/utils" "strings" "github.com/wneessen/go-mail" @@ -67,7 +68,7 @@ func (s *StmpConfig) Send(to, subject, body string) error { func (s *StmpConfig) getReferences() string { froms := strings.Split(s.From, "@") - return fmt.Sprintf("<%s.%s@%s>", froms[0], common.GetUUID(), froms[1]) + return fmt.Sprintf("<%s.%s@%s>", froms[0], utils.GetUUID(), froms[1]) } func (s *StmpConfig) Render(to, subject, content string) error { diff --git a/common/storage/storage_test.go b/common/storage/storage_test.go index 80359457..d3a00704 100644 --- a/common/storage/storage_test.go +++ b/common/storage/storage_test.go @@ -5,7 +5,8 @@ import ( "fmt" "testing" - "one-api/common" + "one-api/common/utils" + "one-api/common/requester" "one-api/common/storage/drives" @@ -32,7 +33,7 @@ func TestSMMSUpload(t *testing.T) { fmt.Println(err) } - url, err := smUpload.Upload(image, common.GetUUID()+".png") + url, err := smUpload.Upload(image, utils.GetUUID()+".png") fmt.Println(url) fmt.Println(err) assert.Nil(t, err) @@ -48,7 +49,7 @@ func TestImgurUpload(t *testing.T) { fmt.Println(err) } - url, err := imgurUpload.Upload(image, common.GetUUID()+".png") + url, err := imgurUpload.Upload(image, utils.GetUUID()+".png") fmt.Println(url) fmt.Println(err) assert.Nil(t, err) diff --git a/common/telegram/command_aff.go b/common/telegram/command_aff.go index fd15a2a6..68733169 100644 --- a/common/telegram/command_aff.go +++ b/common/telegram/command_aff.go @@ -2,6 +2,7 @@ package telegram import ( "one-api/common" + "one-api/common/utils" "strings" "github.com/PaulSonOfLars/gotgbot/v2" @@ -15,7 +16,7 @@ func commandAffStart(b *gotgbot.Bot, ctx *ext.Context) error { } if user.AffCode == "" { - user.AffCode = common.GetRandomString(4) + user.AffCode = utils.GetRandomString(4) if err := user.Update(false); err != nil { ctx.EffectiveMessage.Reply(b, "系统错误,请稍后再试", nil) return nil diff --git a/common/telegram/common.go b/common/telegram/common.go index bb3284d1..cd56193f 100644 --- a/common/telegram/common.go +++ b/common/telegram/common.go @@ -1,8 +1,10 @@ package telegram import ( + "context" "errors" "fmt" + "net" "net/http" "net/url" "one-api/common" @@ -243,22 +245,16 @@ func getHttpClient() (httpClient *http.Client) { }, } case "socks5": - var auth *proxy.Auth = nil - password, isSetPassword := proxyURL.User.Password() - if isSetPassword { - auth = &proxy.Auth{ - User: proxyURL.User.Username(), - Password: password, - } - } - dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) + dialer, err := proxy.FromURL(proxyURL, proxy.Direct) if err != nil { common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error()) return } httpClient = &http.Client{ Transport: &http.Transport{ - Dial: dialer.Dial, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr) + }, }, } default: diff --git a/common/token.go b/common/token.go index ec30f47e..65ec7243 100644 --- a/common/token.go +++ b/common/token.go @@ -119,7 +119,7 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model string) in imageTokens, err := countImageTokens(url, detail) if err != nil { //Due to the excessive length of the error information, only extract and record the most critical part. - SysError("error counting image tokens: " + err.Error()) + SysError("error counting image tokens: " + err.Error()) } else { tokenNum += imageTokens } diff --git a/common/utils.go b/common/utils/helper.go similarity index 94% rename from common/utils.go rename to common/utils/helper.go index 8b0bd555..a9e408fd 100644 --- a/common/utils.go +++ b/common/utils/helper.go @@ -1,4 +1,4 @@ -package common +package utils import ( "encoding/json" @@ -109,13 +109,13 @@ func Seconds2Time(num int) (time string) { } func Interface2String(inter interface{}) string { - switch inter.(type) { + switch inter := inter.(type) { case string: - return inter.(string) + return inter case int: - return fmt.Sprintf("%d", inter.(int)) + return fmt.Sprintf("%d", inter) case float64: - return fmt.Sprintf("%f", inter.(float64)) + return fmt.Sprintf("%f", inter) } return "Not Implemented" } @@ -140,12 +140,7 @@ func GetUUID() string { const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -func init() { - rand.Seed(time.Now().UnixNano()) -} - func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) key := make([]byte, 48) for i := 0; i < 16; i++ { key[i] = keyChars[rand.Intn(len(keyChars))] @@ -162,7 +157,6 @@ func GenerateKey() string { } func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) key := make([]byte, length) for i := 0; i < length; i++ { key[i] = keyChars[rand.Intn(len(keyChars))] diff --git a/common/utils/proxy.go b/common/utils/proxy.go new file mode 100644 index 00000000..b0b1cf9d --- /dev/null +++ b/common/utils/proxy.go @@ -0,0 +1,77 @@ +package utils + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/net/proxy" +) + +type ContextKey string + +const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr" +const ProxySock5AddrKey ContextKey = "proxySock5Addr" + +func ProxyFunc(req *http.Request) (*url.URL, error) { + proxyAddr := req.Context().Value(ProxyHTTPAddrKey) + if proxyAddr == nil { + return nil, nil + } + + proxyURL, err := url.Parse(proxyAddr.(string)) + if err != nil { + return nil, fmt.Errorf("error parsing proxy address: %w", err) + } + + switch proxyURL.Scheme { + case "http", "https": + return proxyURL, nil + } + + return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) +} + +func Socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{ + Timeout: time.Duration(GetOrDefault("connect_timeout", 5)) * time.Second, + KeepAlive: 30 * time.Second, + } + + proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string) + if !ok { + return dialer.DialContext(ctx, network, addr) + } + + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return nil, fmt.Errorf("error parsing proxy address: %w", err) + } + + proxyDialer, err := proxy.FromURL(proxyURL, dialer) + if err != nil { + return nil, fmt.Errorf("error creating proxy dialer: %w", err) + } + + return proxyDialer.Dial(network, addr) +} + +func SetProxy(ctx context.Context, proxyAddr string) context.Context { + if proxyAddr == "" { + return ctx + } + + key := ProxyHTTPAddrKey + + // 如果是以 socks5:// 开头的地址,那么使用 socks5 代理 + if strings.HasPrefix(proxyAddr, "socks5://") { + key = ProxySock5AddrKey + } + + // 否则使用 http 代理 + return context.WithValue(ctx, key, proxyAddr) +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 7465320c..b52340d0 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "one-api/common" "one-api/common/notify" + "one-api/common/utils" "one-api/model" "one-api/providers" providers_base "one-api/providers/base" @@ -153,7 +154,7 @@ func testAllChannels(isNotify bool) error { time.Sleep(common.RequestInterval) isChannelEnabled := channel.Status == common.ChannelStatusEnabled - sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", common.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr()) + sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr()) tik := time.Now() err, openaiErr := testChannel(channel, "") tok := time.Now() @@ -161,7 +162,7 @@ func testAllChannels(isNotify bool) error { // 通道为禁用状态,并且还是请求错误 或者 响应时间超过阈值 直接跳过,也不需要更新响应时间。 if !isChannelEnabled { if err != nil { - sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", common.EscapeMarkdownText(err.Error())) + sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", utils.EscapeMarkdownText(err.Error())) continue } if milliseconds > disableThreshold { @@ -187,13 +188,13 @@ func testAllChannels(isNotify bool) error { } if ShouldDisableChannel(openaiErr, -1) { - sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", common.EscapeMarkdownText(err.Error())) + sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", utils.EscapeMarkdownText(err.Error())) DisableChannel(channel.Id, channel.Name, err.Error(), false) continue } if err != nil { - sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", common.EscapeMarkdownText(err.Error())) + sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", utils.EscapeMarkdownText(err.Error())) continue } } diff --git a/controller/channel.go b/controller/channel.go index 8ad724fd..0f401a45 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -4,6 +4,7 @@ import ( "errors" "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strconv" "strings" @@ -64,7 +65,7 @@ func AddChannel(c *gin.Context) { }) return } - channel.CreatedTime = common.GetTimestamp() + channel.CreatedTime = utils.GetTimestamp() keys := strings.Split(channel.Key, "\n") channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { diff --git a/controller/github.go b/controller/github.go index 00ec3a88..a6d923d0 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strconv" "time" @@ -216,7 +217,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := common.GetRandomString(12) + state := utils.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/option.go b/controller/option.go index bbf83578..99f36607 100644 --- a/controller/option.go +++ b/controller/option.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strings" @@ -19,7 +20,7 @@ func GetOptions(c *gin.Context) { } options = append(options, &model.Option{ Key: k, - Value: common.Interface2String(v), + Value: utils.Interface2String(v), }) } common.OptionMapRWMutex.Unlock() diff --git a/controller/pricing.go b/controller/pricing.go index d2d96e81..db059ee2 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -6,7 +6,7 @@ import ( "net/url" "one-api/common" "one-api/model" - "one-api/relay/util" + "one-api/relay/relay_util" "github.com/gin-gonic/gin" ) @@ -14,7 +14,7 @@ import ( func GetPricesList(c *gin.Context) { pricesType := c.DefaultQuery("type", "db") - prices := util.GetPricesList(pricesType) + prices := relay_util.GetPricesList(pricesType) if len(prices) == 0 { common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found")) @@ -33,7 +33,7 @@ func GetPricesList(c *gin.Context) { } func GetAllModelList(c *gin.Context) { - prices := util.PricingInstance.GetAllPrices() + prices := relay_util.PricingInstance.GetAllPrices() channelModel := model.ChannelGroup.Rule modelsMap := make(map[string]bool) @@ -68,7 +68,7 @@ func AddPrice(c *gin.Context) { return } - if err := util.PricingInstance.AddPrice(&price); err != nil { + if err := relay_util.PricingInstance.AddPrice(&price); err != nil { common.APIRespondWithError(c, http.StatusOK, err) return } @@ -94,7 +94,7 @@ func UpdatePrice(c *gin.Context) { return } - if err := util.PricingInstance.UpdatePrice(modelName, &price); err != nil { + if err := relay_util.PricingInstance.UpdatePrice(modelName, &price); err != nil { common.APIRespondWithError(c, http.StatusOK, err) return } @@ -114,7 +114,7 @@ func DeletePrice(c *gin.Context) { modelName = modelName[1:] modelName, _ = url.PathUnescape(modelName) - if err := util.PricingInstance.DeletePrice(modelName); err != nil { + if err := relay_util.PricingInstance.DeletePrice(modelName); err != nil { common.APIRespondWithError(c, http.StatusOK, err) return } @@ -127,7 +127,7 @@ func DeletePrice(c *gin.Context) { type PriceBatchRequest struct { OriginalModels []string `json:"original_models"` - util.BatchPrices + relay_util.BatchPrices } func BatchSetPrices(c *gin.Context) { @@ -137,7 +137,7 @@ func BatchSetPrices(c *gin.Context) { return } - if err := util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil { + if err := relay_util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil { common.APIRespondWithError(c, http.StatusOK, err) return } @@ -159,7 +159,7 @@ func BatchDeletePrices(c *gin.Context) { return } - if err := util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil { + if err := relay_util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil { common.APIRespondWithError(c, http.StatusOK, err) return } @@ -184,7 +184,7 @@ func SyncPricing(c *gin.Context) { return } - err := util.PricingInstance.SyncPricing(prices, overwrite == "true") + err := relay_util.PricingInstance.SyncPricing(prices, overwrite == "true") if err != nil { common.APIRespondWithError(c, http.StatusOK, err) return diff --git a/controller/redemption.go b/controller/redemption.go index ee111f88..5ad481aa 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -3,6 +3,7 @@ package controller import ( "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strconv" @@ -85,12 +86,12 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := common.GetUUID() + key := utils.GetUUID() cleanRedemption := model.Redemption{ UserId: c.GetInt("id"), Name: redemption.Name, Key: key, - CreatedTime: common.GetTimestamp(), + CreatedTime: utils.GetTimestamp(), Quota: redemption.Quota, } err = cleanRedemption.Insert() diff --git a/controller/token.go b/controller/token.go index d87391db..9f43b111 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,6 +3,7 @@ package controller import ( "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strconv" @@ -62,9 +63,9 @@ func GetPlaygroundToken(c *gin.Context) { cleanToken := model.Token{ UserId: userId, Name: tokenName, - Key: common.GenerateKey(), - CreatedTime: common.GetTimestamp(), - AccessedTime: common.GetTimestamp(), + Key: utils.GenerateKey(), + CreatedTime: utils.GetTimestamp(), + AccessedTime: utils.GetTimestamp(), ExpiredTime: 0, RemainQuota: 0, UnlimitedQuota: true, @@ -132,9 +133,9 @@ func AddToken(c *gin.Context) { cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, - Key: common.GenerateKey(), - CreatedTime: common.GetTimestamp(), - AccessedTime: common.GetTimestamp(), + Key: utils.GenerateKey(), + CreatedTime: utils.GetTimestamp(), + AccessedTime: utils.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, @@ -199,7 +200,7 @@ func UpdateToken(c *gin.Context) { return } if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", diff --git a/controller/user.go b/controller/user.go index 95cb4dee..4973314d 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strconv" "time" @@ -261,7 +262,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = common.GetUUID() + user.AccessToken = utils.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -297,7 +298,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = common.GetRandomString(4) + user.AffCode = utils.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/main.go b/main.go index b235d76b..973fe2f1 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,7 @@ import ( "one-api/cron" "one-api/middleware" "one-api/model" - "one-api/relay/util" + "one-api/relay/relay_util" "one-api/router" "time" @@ -40,7 +40,7 @@ func main() { common.InitRedisClient() // Initialize options model.InitOptionMap() - util.NewPricing() + relay_util.NewPricing() initMemoryCache() initSync() @@ -112,6 +112,6 @@ func SyncChannelCache(frequency int) { time.Sleep(time.Duration(frequency) * time.Second) common.SysLog("syncing channels from database") model.ChannelGroup.Load() - util.PricingInstance.Init() + relay_util.PricingInstance.Init() } } diff --git a/middleware/auth.go b/middleware/auth.go index 1eb0c2b4..697acacf 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "one-api/common" + "one-api/common/utils" "one-api/model" "strings" @@ -109,10 +110,10 @@ func tokenAuth(c *gin.Context, key string) { if len(parts) > 1 { if model.IsAdmin(token.UserId) { if strings.HasPrefix(parts[1], "!") { - channelId := common.String2Int(parts[1][1:]) + channelId := utils.String2Int(parts[1][1:]) c.Set("skip_channel_id", channelId) } else { - channelId := common.String2Int(parts[1]) + channelId := utils.String2Int(parts[1]) if channelId == 0 { abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") return diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 41b1aa59..dae6639c 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "one-api/common" + "one-api/common/utils" "time" "github.com/gin-gonic/gin" @@ -99,11 +100,11 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } func GlobalWebRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GetOrDefault("global.web_rate_limit", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW") + return rateLimitFactory(utils.GetOrDefault("global.web_rate_limit", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GetOrDefault("global.api_rate_limit", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA") + return rateLimitFactory(utils.GetOrDefault("global.api_rate_limit", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { diff --git a/middleware/request-id.go b/middleware/request-id.go index d626eeb1..edca8c6f 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -3,6 +3,7 @@ package middleware import ( "context" "one-api/common" + "one-api/common/utils" "time" "github.com/gin-gonic/gin" @@ -10,7 +11,7 @@ import ( func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := common.GetTimeString() + common.GetRandomString(8) + id := utils.GetTimeString() + utils.GetRandomString(8) c.Set(common.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) ctx = context.WithValue(ctx, "requestStartTime", time.Now()) diff --git a/middleware/utils.go b/middleware/utils.go index 536125cc..bfa58881 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,14 +1,16 @@ package middleware import ( - "github.com/gin-gonic/gin" "one-api/common" + "one-api/common/utils" + + "github.com/gin-gonic/gin" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "message": utils.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "type": "one_api_error", }, }) diff --git a/model/balancer.go b/model/balancer.go index c71df3ce..cdcf906f 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -4,6 +4,7 @@ import ( "errors" "math/rand" "one-api/common" + "one-api/common/utils" "strings" "sync" "time" @@ -105,7 +106,7 @@ func (cc *ChannelsChooser) Next(group, modelName string, filters ...ChannelsFilt channelsPriority, ok := cc.Rule[group][modelName] if !ok { - matchModel := common.GetModelsWithMatch(&cc.Match, modelName) + matchModel := utils.GetModelsWithMatch(&cc.Match, modelName) channelsPriority, ok = cc.Rule[group][matchModel] if !ok { return nil, errors.New("model not found") @@ -199,7 +200,7 @@ func (cc *ChannelsChooser) Load() { // 逗号分割 ability.ChannelId channelIds := strings.Split(ability.ChannelIds, ",") for _, channelId := range channelIds { - priorityIds = append(priorityIds, common.String2Int(channelId)) + priorityIds = append(priorityIds, utils.String2Int(channelId)) } newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds) diff --git a/model/channel.go b/model/channel.go index 6469b536..7e6b2481 100644 --- a/model/channel.go +++ b/model/channel.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/utils" "strings" "gorm.io/datatypes" @@ -235,7 +236,7 @@ func (channel *Channel) UpdateRaw(overwrite bool) error { func (channel *Channel) UpdateResponseTime(responseTime int64) { err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ - TestTime: common.GetTimestamp(), + TestTime: utils.GetTimestamp(), ResponseTime: int(responseTime), }).Error if err != nil { @@ -245,7 +246,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { func (channel *Channel) UpdateBalance(balance float64) { err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ - BalanceUpdatedTime: common.GetTimestamp(), + BalanceUpdatedTime: utils.GetTimestamp(), Balance: balance, }).Error if err != nil { diff --git a/model/log.go b/model/log.go index 20a55ac7..b507ed22 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "one-api/common" + "one-api/common/utils" "gorm.io/gorm" ) @@ -41,7 +42,7 @@ func RecordLog(userId int, logType int, content string) { log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: utils.GetTimestamp(), Type: logType, Content: content, } @@ -59,7 +60,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: utils.GetTimestamp(), Type: LogTypeConsume, Content: content, PromptTokens: promptTokens, diff --git a/model/main.go b/model/main.go index 6e5ca224..637d6ba9 100644 --- a/model/main.go +++ b/model/main.go @@ -3,6 +3,7 @@ package model import ( "fmt" "one-api/common" + "one-api/common/utils" "strconv" "strings" "time" @@ -26,7 +27,7 @@ func SetupDB() { if viper.GetBool("batch_update_enabled") { common.BatchUpdateEnabled = true - common.BatchUpdateInterval = common.GetOrDefault("batch_update_interval", 5) + common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5) common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") InitBatchUpdater() } @@ -47,7 +48,7 @@ func createRootAccountIfNeed() error { Role: common.RoleRootUser, Status: common.UserStatusEnabled, DisplayName: "Root User", - AccessToken: common.GetUUID(), + AccessToken: utils.GetUUID(), Quota: 100000000, } DB.Create(&rootUser) @@ -78,7 +79,7 @@ func chooseDB() (*gorm.DB, error) { // Use SQLite common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true - config := fmt.Sprintf("?_busy_timeout=%d", common.GetOrDefault("sqlite_busy_timeout", 3000)) + config := fmt.Sprintf("?_busy_timeout=%d", utils.GetOrDefault("sqlite_busy_timeout", 3000)) return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{ PrepareStmt: true, // precompile SQL }) @@ -96,9 +97,9 @@ func InitDB() (err error) { return err } - sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(utils.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60))) if !common.IsMasterNode { return nil diff --git a/model/redemption.go b/model/redemption.go index 07df8d80..821e244b 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/utils" "gorm.io/gorm" ) @@ -33,7 +34,7 @@ func GetRedemptionsList(params *GenericParams) (*DataResult[Redemption], error) var redemptions []*Redemption db := DB if params.Keyword != "" { - db = db.Where("id = ? or name LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%") + db = db.Where("id = ? or name LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%") } return PaginateAndOrder[Redemption](db, ¶ms.PaginationParams, &redemptions, allowedRedemptionslOrderFields) @@ -75,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return err } - redemption.RedeemedTime = common.GetTimestamp() + redemption.RedeemedTime = utils.GetTimestamp() redemption.Status = common.RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err diff --git a/model/telegram_menu.go b/model/telegram_menu.go index 451fda3e..fe579e67 100644 --- a/model/telegram_menu.go +++ b/model/telegram_menu.go @@ -2,7 +2,7 @@ package model import ( "errors" - "one-api/common" + "one-api/common/utils" ) type TelegramMenu struct { @@ -22,7 +22,7 @@ func GetTelegramMenusList(params *GenericParams) (*DataResult[TelegramMenu], err var menus []*TelegramMenu db := DB if params.Keyword != "" { - db = db.Where("id = ? or command LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%") + db = db.Where("id = ? or command LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%") } return PaginateAndOrder[TelegramMenu](db, ¶ms.PaginationParams, &menus, allowedTelegramMenusOrderFields) diff --git a/model/token.go b/model/token.go index 44f85601..2e89185d 100644 --- a/model/token.go +++ b/model/token.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/common/stmp" + "one-api/common/utils" "gorm.io/gorm" ) @@ -71,7 +72,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } - if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { + if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() { if !common.RedisEnabled { token.Status = common.TokenStatusExpired err := token.SelectUpdate() @@ -188,7 +189,7 @@ func increaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": utils.GetTimestamp(), }, ).Error return err @@ -210,7 +211,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": utils.GetTimestamp(), }, ).Error return err diff --git a/model/user.go b/model/user.go index 27933038..311c0702 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/utils" "strings" "gorm.io/gorm" @@ -55,7 +56,7 @@ func GetUsersList(params *GenericParams) (*DataResult[User], error) { var users []*User db := DB.Omit("password") if params.Keyword != "" { - db = db.Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%", params.Keyword+"%", params.Keyword+"%") + db = db.Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%", params.Keyword+"%", params.Keyword+"%") } return PaginateAndOrder[User](db, ¶ms.PaginationParams, &users, allowedUserOrderFields) @@ -115,9 +116,9 @@ func (user *User) Insert(inviterId int) error { } } user.Quota = common.QuotaForNewUser - user.AccessToken = common.GetUUID() - user.AffCode = common.GetRandomString(4) - user.CreatedTime = common.GetTimestamp() + user.AccessToken = utils.GetUUID() + user.AffCode = utils.GetRandomString(4) + user.CreatedTime = utils.GetTimestamp() result := DB.Create(user) if result.Error != nil { return result.Error @@ -165,7 +166,7 @@ func (user *User) Delete() error { } // 不改变当前数据库索引,通过更改用户名来删除用户 - user.Username = user.Username + "_del_" + common.GetRandomString(6) + user.Username = user.Username + "_del_" + utils.GetRandomString(6) err := user.Update(false) if err != nil { return err diff --git a/providers/ali/chat.go b/providers/ali/chat.go index ed6dac9f..8efc839a 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -92,7 +93,7 @@ func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *ty openaiResponse = &types.ChatCompletionResponse{ ID: response.RequestId, Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: request.Model, Choices: response.Output.ToChatCompletionChoices(), Usage: &types.Usage{ @@ -223,7 +224,7 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, d streamResponse := types.ChatCompletionStreamResponse{ ID: aliResponse.RequestId, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 471f0f4f..7c3d53e9 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/common/image" "one-api/common/requester" + "one-api/common/utils" "one-api/providers/base" "one-api/types" "strings" @@ -172,7 +173,7 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeRespon openaiResponse = &types.ChatCompletionResponse{ ID: response.Id, Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Choices: []types.ChatCompletionChoice{choice}, Model: request.Model, Usage: &types.Usage{ @@ -264,9 +265,9 @@ func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStream choice.FinishReason = &finishReason } chatCompletion := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } diff --git a/providers/cloudflareAI/chat.go b/providers/cloudflareAI/chat.go index 5ba0bb95..873be274 100644 --- a/providers/cloudflareAI/chat.go +++ b/providers/cloudflareAI/chat.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -85,9 +86,9 @@ func (p *CloudflareAIProvider) convertToChatOpenai(response *ChatRespone, reques } openaiResponse = &types.ChatCompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: request.Model, Choices: []types.ChatCompletionChoice{{ Index: 0, @@ -155,9 +156,9 @@ func (h *CloudflareAIStreamHandler) handlerStream(rawLine *[]byte, dataChan chan func (h *CloudflareAIStreamHandler) convertToOpenaiStream(chatResponse *ChatResult, dataChan chan string, isStop bool) { streamResponse := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, } diff --git a/providers/cloudflareAI/image_generations.go b/providers/cloudflareAI/image_generations.go index 91ce7acf..0c3b8068 100644 --- a/providers/cloudflareAI/image_generations.go +++ b/providers/cloudflareAI/image_generations.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/storage" + "one-api/common/utils" "one-api/types" "time" ) @@ -46,7 +47,7 @@ func (p *CloudflareAIProvider) CreateImageGenerations(request *types.ImageReques url := "" if request.ResponseFormat == "" || request.ResponseFormat == "url" { - url = storage.Upload(body, common.GetUUID()+".png") + url = storage.Upload(body, utils.GetUUID()+".png") } openaiResponse := &types.ImageResponse{ diff --git a/providers/cohere/chat.go b/providers/cohere/chat.go index 0f4b5080..ae416ab6 100644 --- a/providers/cohere/chat.go +++ b/providers/cohere/chat.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/providers/base" "one-api/types" "strings" @@ -138,7 +139,7 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereRespon openaiResponse = &types.ChatCompletionResponse{ ID: response.GenerationID, Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Choices: []types.ChatCompletionChoice{choice}, Model: request.Model, Usage: &types.Usage{}, @@ -190,9 +191,9 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream } chatCompletion := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } diff --git a/providers/coze/chat.go b/providers/coze/chat.go index daf52d67..3913e28f 100644 --- a/providers/coze/chat.go +++ b/providers/coze/chat.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -89,9 +90,9 @@ func (p *CozeProvider) convertToChatOpenai(response *CozeResponse, request *type } openaiResponse = &types.ChatCompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: request.Model, Choices: []types.ChatCompletionChoice{{ Index: 0, @@ -168,9 +169,9 @@ func (h *CozeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) { streamResponse := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, } diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index e1210fd3..28185345 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -145,9 +146,9 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque } openaiResponse = &types.ChatCompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: request.Model, Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), } @@ -191,9 +192,9 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) { streamResponse := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, // Choices: choices, } diff --git a/providers/gemini/type.go b/providers/gemini/type.go index 798d39da..79d1d680 100644 --- a/providers/gemini/type.go +++ b/providers/gemini/type.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/common/image" + "one-api/common/utils" "one-api/types" ) @@ -120,7 +121,7 @@ func (g *GeminiFunctionCall) ToOpenAITool() *types.ChatCompletionToolCalls { args, _ := json.Marshal(g.Args) return &types.ChatCompletionToolCalls{ - Id: "call_" + common.GetRandomString(24), + Id: "call_" + utils.GetRandomString(24), Type: types.ChatMessageRoleFunction, Index: 0, Function: &types.ChatCompletionToolCallsFunction{ diff --git a/providers/ollama/chat.go b/providers/ollama/chat.go index 0143b40b..5b01cf79 100644 --- a/providers/ollama/chat.go +++ b/providers/ollama/chat.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/common/image" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -100,9 +101,9 @@ func (p *OllamaProvider) convertToChatOpenai(response *ChatResponse, request *ty } openaiResponse = &types.ChatCompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: request.Model, Choices: []types.ChatCompletionChoice{choices}, Usage: &types.Usage{ @@ -195,9 +196,9 @@ func (h *ollamaStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin } chatCompletion := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } diff --git a/providers/palm/chat.go b/providers/palm/chat.go index b4c721bf..41577c96 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -177,11 +178,11 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp choice.FinishReason = types.FinishReasonStop streamResponse := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", Model: h.Request.Model, Choices: []types.ChatCompletionStreamChoice{choice}, - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), } responseBody, _ := json.Marshal(streamResponse) diff --git a/providers/stabilityAI/image_generations.go b/providers/stabilityAI/image_generations.go index 52987c95..e79a46bd 100644 --- a/providers/stabilityAI/image_generations.go +++ b/providers/stabilityAI/image_generations.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/storage" + "one-api/common/utils" "one-api/types" "time" ) @@ -71,7 +72,7 @@ func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest if request.ResponseFormat == "" || request.ResponseFormat == "url" { body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image) if err == nil { - imgUrl = storage.Upload(body, common.GetUUID()+".png") + imgUrl = storage.Upload(body, utils.GetUUID()+".png") } } diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 18e009ab..25625cf1 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" ) @@ -101,7 +102,7 @@ func (p *TencentProvider) convertToChatOpenai(response *TencentChatResponse, req openaiResponse = &types.ChatCompletionResponse{ Object: "chat.completion", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Usage: response.Usage, Model: request.Model, } @@ -137,9 +138,9 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq stream = 1 } return &TencentChatRequest{ - Timestamp: common.GetTimestamp(), - Expired: common.GetTimestamp() + 24*60*60, - QueryID: common.GetUUID(), + Timestamp: utils.GetTimestamp(), + Expired: utils.GetTimestamp() + 24*60*60, + QueryID: utils.GetUUID(), Temperature: request.Temperature, TopP: request.TopP, Stream: stream, @@ -178,7 +179,7 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string) { streamResponse := types.ChatCompletionStreamResponse{ Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, } if len(tencentChatResponse.Choices) > 0 { diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index 0578d73f..ee8d289f 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/types" "strings" @@ -194,7 +195,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa ID: xunfeiResponse.Header.Sid, Object: "chat.completion", Model: h.Request.Model, - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Choices: []types.ChatCompletionChoice{choice}, Usage: &xunfeiResponse.Payload.Usage.Text, } @@ -310,7 +311,7 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp chatCompletion := types.ChatCompletionStreamResponse{ ID: xunfeiChatResponse.Header.Sid, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: h.Request.Model, } diff --git a/relay/base.go b/relay/base.go index 72d42aae..fc9da25a 100644 --- a/relay/base.go +++ b/relay/base.go @@ -1,7 +1,7 @@ package relay import ( - "one-api/relay/util" + "one-api/relay/relay_util" "one-api/types" providersBase "one-api/providers/base" @@ -14,7 +14,7 @@ type relayBase struct { provider providersBase.ProviderInterface originalModel string modelName string - cache *util.ChatCacheProps + cache *relay_util.ChatCacheProps } type RelayBaseInterface interface { @@ -28,14 +28,14 @@ type RelayBaseInterface interface { getModelName() string getContext() *gin.Context SetChatCache(allow bool) - GetChatCache() *util.ChatCacheProps + GetChatCache() *relay_util.ChatCacheProps } func (r *relayBase) SetChatCache(allow bool) { - r.cache = util.NewChatCacheProps(r.c, allow) + r.cache = relay_util.NewChatCacheProps(r.c, allow) } -func (r *relayBase) GetChatCache() *util.ChatCacheProps { +func (r *relayBase) GetChatCache() *relay_util.ChatCacheProps { return r.cache } diff --git a/relay/chat.go b/relay/chat.go index f8e1eb1b..e7bb6c26 100644 --- a/relay/chat.go +++ b/relay/chat.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" providersBase "one-api/providers/base" "one-api/types" @@ -100,9 +101,9 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) { func (r *relayChat) getUsageResponse() string { if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage { usageResponse := types.ChatCompletionStreamResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: r.chatRequest.Model, Choices: []types.ChatCompletionStreamChoice{}, Usage: r.provider.GetUsage(), diff --git a/relay/common.go b/relay/common.go index cebfcac5..45907804 100644 --- a/relay/common.go +++ b/relay/common.go @@ -9,11 +9,12 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" "one-api/controller" "one-api/model" "one-api/providers" providersBase "one-api/providers/base" - "one-api/relay/util" + "one-api/relay/relay_util" "one-api/types" "strings" @@ -142,7 +143,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith type StreamEndHandler func() string -func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) { +func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *relay_util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) { requester.SetEventStreamHeaders(c) dataChan, errChan := stream.Recv() @@ -257,7 +258,7 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { requestId := c.GetString(common.RequestIdKey) - err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) + err.OpenAIError.Message = utils.MessageWithRequestId(err.OpenAIError.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.OpenAIError, }) diff --git a/relay/completions.go b/relay/completions.go index 591b3c84..4aa69136 100644 --- a/relay/completions.go +++ b/relay/completions.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/common/utils" providersBase "one-api/providers/base" "one-api/types" @@ -93,9 +94,9 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo func (r *relayCompletions) getUsageResponse() string { if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage { usageResponse := types.CompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: utils.GetTimestamp(), Model: r.request.Model, Choices: []types.CompletionChoice{}, Usage: r.provider.GetUsage(), diff --git a/relay/main.go b/relay/main.go index 5123715c..6ed4d385 100644 --- a/relay/main.go +++ b/relay/main.go @@ -5,7 +5,7 @@ import ( "net/http" "one-api/common" "one-api/model" - "one-api/relay/util" + "one-api/relay/relay_util" "one-api/types" "time" @@ -96,8 +96,8 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod relay.getProvider().SetUsage(usage) - var quota *util.Quota - quota, err = util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens) + var quota *relay_util.Quota + quota, err = relay_util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens) if err != nil { done = true return @@ -119,7 +119,7 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod return } -func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) { +func cacheProcessing(c *gin.Context, cacheProps *relay_util.ChatCacheProps) { responseCache(c, cacheProps.Response) // 写入日志 diff --git a/relay/midjourney/relay-mj.go b/relay/midjourney/relay-mj.go index 4ff6f40f..078e0ff4 100644 --- a/relay/midjourney/relay-mj.go +++ b/relay/midjourney/relay-mj.go @@ -14,7 +14,7 @@ import ( "one-api/model" provider "one-api/providers/midjourney" "one-api/relay" - "one-api/relay/util" + "one-api/relay/relay_util" "one-api/types" "strconv" "strings" @@ -539,10 +539,10 @@ func getMjRequestPath(path string) string { return requestURL } -func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) { +func getQuota(c *gin.Context, action string) (*relay_util.Quota, *types.OpenAIErrorWithStatusCode) { modelName := CoverActionToModelName(action) - return util.NewQuota(c, modelName, 1000) + return relay_util.NewQuota(c, modelName, 1000) } func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { diff --git a/relay/model.go b/relay/model.go index a4effbba..de9acc0f 100644 --- a/relay/model.go +++ b/relay/model.go @@ -5,7 +5,7 @@ import ( "net/http" "one-api/common" "one-api/model" - "one-api/relay/util" + "one-api/relay/relay_util" "one-api/types" "sort" @@ -90,7 +90,7 @@ func ListModels(c *gin.Context) { } func ListModelsForAdmin(c *gin.Context) { - prices := util.PricingInstance.GetAllPrices() + prices := relay_util.PricingInstance.GetAllPrices() var openAIModels []OpenAIModels for modelId, price := range prices { openAIModels = append(openAIModels, OpenAIModels{ @@ -123,7 +123,7 @@ func ListModelsForAdmin(c *gin.Context) { func RetrieveModel(c *gin.Context) { modelName := c.Param("model") openaiModel := getOpenAIModelWithName(modelName) - if *openaiModel.OwnedBy != util.UnknownOwnedBy { + if *openaiModel.OwnedBy != relay_util.UnknownOwnedBy { c.JSON(200, openaiModel) } else { openAIError := types.OpenAIError{ @@ -139,15 +139,15 @@ func RetrieveModel(c *gin.Context) { } func getModelOwnedBy(channelType int) (ownedBy *string) { - if ownedByName, ok := util.ModelOwnedBy[channelType]; ok { + if ownedByName, ok := relay_util.ModelOwnedBy[channelType]; ok { return &ownedByName } - return &util.UnknownOwnedBy + return &relay_util.UnknownOwnedBy } func getOpenAIModelWithName(modelName string) *OpenAIModels { - price := util.PricingInstance.GetPrice(modelName) + price := relay_util.PricingInstance.GetPrice(modelName) return &OpenAIModels{ Id: modelName, @@ -164,6 +164,6 @@ func GetModelOwnedBy(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": util.ModelOwnedBy, + "data": relay_util.ModelOwnedBy, }) } diff --git a/relay/util/cache.go b/relay/relay_util/cache.go similarity index 93% rename from relay/util/cache.go rename to relay/relay_util/cache.go index 8cb3a6d2..8315e7a3 100644 --- a/relay/util/cache.go +++ b/relay/relay_util/cache.go @@ -1,10 +1,11 @@ -package util +package relay_util import ( "crypto/md5" "encoding/hex" "fmt" "one-api/common" + "one-api/common/utils" "one-api/model" "github.com/gin-gonic/gin" @@ -37,7 +38,7 @@ func GetDebugList(userId int) ([]*ChatCacheProps, error) { var props []*ChatCacheProps for _, cache := range caches { - prop, err := common.UnmarshalString[ChatCacheProps](cache.Data) + prop, err := utils.UnmarshalString[ChatCacheProps](cache.Data) if err != nil { continue } @@ -77,7 +78,7 @@ func (p *ChatCacheProps) SetHash(request any) { return } - p.hash(common.Marshal(request)) + p.hash(utils.Marshal(request)) } func (p *ChatCacheProps) SetResponse(response any) { @@ -90,7 +91,7 @@ func (p *ChatCacheProps) SetResponse(response any) { return } - responseStr := common.Marshal(response) + responseStr := utils.Marshal(response) if responseStr == "" { return } diff --git a/relay/util/cache_db.go b/relay/relay_util/cache_db.go similarity index 84% rename from relay/util/cache_db.go rename to relay/relay_util/cache_db.go index 22466449..0a1a16a2 100644 --- a/relay/util/cache_db.go +++ b/relay/relay_util/cache_db.go @@ -1,8 +1,8 @@ -package util +package relay_util import ( "errors" - "one-api/common" + "one-api/common/utils" "one-api/model" "time" ) @@ -15,7 +15,7 @@ func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps { return nil } - props, err := common.UnmarshalString[ChatCacheProps](cache.Data) + props, err := utils.UnmarshalString[ChatCacheProps](cache.Data) if err != nil { return nil } @@ -28,7 +28,7 @@ func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) err } func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error { - data := common.Marshal(props) + data := utils.Marshal(props) if data == "" { return errors.New("marshal error") } diff --git a/relay/util/cache_redis.go b/relay/relay_util/cache_redis.go similarity index 84% rename from relay/util/cache_redis.go rename to relay/relay_util/cache_redis.go index a9eaa81f..d5ad7a59 100644 --- a/relay/util/cache_redis.go +++ b/relay/relay_util/cache_redis.go @@ -1,9 +1,10 @@ -package util +package relay_util import ( "errors" "fmt" "one-api/common" + "one-api/common/utils" "time" ) @@ -17,7 +18,7 @@ func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps { return nil } - props, err := common.UnmarshalString[ChatCacheProps](cache) + props, err := utils.UnmarshalString[ChatCacheProps](cache) if err != nil { return nil } @@ -31,7 +32,7 @@ func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) e return nil } - data := common.Marshal(&props) + data := utils.Marshal(&props) if data == "" { return errors.New("marshal error") } diff --git a/relay/util/pricing.go b/relay/relay_util/pricing.go similarity index 97% rename from relay/util/pricing.go rename to relay/relay_util/pricing.go index 7588428d..67c46e54 100644 --- a/relay/util/pricing.go +++ b/relay/relay_util/pricing.go @@ -1,9 +1,10 @@ -package util +package relay_util import ( "encoding/json" "errors" "one-api/common" + "one-api/common/utils" "one-api/model" "sort" "strings" @@ -98,7 +99,7 @@ func (p *Pricing) GetPrice(modelName string) *model.Price { return price } - matchModel := common.GetModelsWithMatch(&p.Match, modelName) + matchModel := utils.GetModelsWithMatch(&p.Match, modelName) if price, ok := p.Prices[matchModel]; ok { return price } @@ -281,7 +282,7 @@ func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []stri var updatePrices []string for _, model := range originalModels { - if !common.Contains(model, batchPrices.Models) { + if !utils.Contains(model, batchPrices.Models) { deletePrices = append(deletePrices, model) } else { updatePrices = append(updatePrices, model) @@ -289,7 +290,7 @@ func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []stri } for _, model := range batchPrices.Models { - if !common.Contains(model, originalModels) { + if !utils.Contains(model, originalModels) { addPrice := batchPrices.Price addPrice.Model = model addPrices = append(addPrices, &addPrice) diff --git a/relay/util/quota.go b/relay/relay_util/quota.go similarity index 99% rename from relay/util/quota.go rename to relay/relay_util/quota.go index 79ba1a9a..c859361e 100644 --- a/relay/util/quota.go +++ b/relay/relay_util/quota.go @@ -1,4 +1,4 @@ -package util +package relay_util import ( "context" diff --git a/relay/util/type.go b/relay/relay_util/type.go similarity index 98% rename from relay/util/type.go rename to relay/relay_util/type.go index 6794ca16..9c2d72d3 100644 --- a/relay/util/type.go +++ b/relay/relay_util/type.go @@ -1,4 +1,4 @@ -package util +package relay_util import "one-api/common"