🔖 chore: Rename relay/util to relay/relay_util package and add utils package

This commit is contained in:
MartialBE 2024-05-29 00:36:54 +08:00
parent 853f2681f4
commit 79524108a3
No known key found for this signature in database
GPG Key ID: 27C0267EC84B0A5C
61 changed files with 309 additions and 265 deletions

View File

@ -3,13 +3,13 @@ package cli
import ( import (
"encoding/json" "encoding/json"
"one-api/common" "one-api/common"
"one-api/relay/util" "one-api/relay/relay_util"
"os" "os"
"sort" "sort"
) )
func ExportPrices() { func ExportPrices() {
prices := util.GetPricesList("default") prices := relay_util.GetPricesList("default")
if len(prices) == 0 { if len(prices) == 0 {
common.SysError("No prices found") common.SysError("No prices found")

View File

@ -6,6 +6,7 @@ import (
"one-api/cli" "one-api/cli"
"one-api/common" "one-api/common"
"one-api/common/utils"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -22,11 +23,11 @@ func InitConf() {
common.IsMasterNode = viper.GetString("node_type") != "slave" common.IsMasterNode = viper.GetString("node_type") != "slave"
common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second
common.SessionSecret = common.GetOrDefault("session_secret", common.SessionSecret) common.SessionSecret = utils.GetOrDefault("session_secret", common.SessionSecret)
} }
func setConfigFile() { func setConfigFile() {
if !common.IsFileExist(*cli.Config) { if !utils.IsFileExist(*cli.Config) {
return return
} }

View File

@ -16,6 +16,15 @@ import (
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
) )
// var ImageHttpClients = &http.Client{
// Transport: &http.Transport{
// DialContext: requester.Socks5ProxyFunc,
// Proxy: requester.ProxyFunc,
// },
// //
// // // Timeout: 30 * time.Second,
// }
func IsImageUrl(url string) (bool, error) { func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url) resp, err := http.Head(url)
if err != nil { if err != nil {

View File

@ -10,6 +10,8 @@ import (
"sync" "sync"
"time" "time"
"one-api/common/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -62,13 +64,13 @@ func getLogDir() string {
} }
var err error var err error
logDir, err = filepath.Abs(viper.GetString("log_dir")) logDir, err = filepath.Abs(logDir)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
return "" return ""
} }
if !IsFileExist(logDir) { if !utils.IsFileExist(logDir) {
err = os.Mkdir(logDir, 0777) err = os.Mkdir(logDir, 0777)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -1,87 +1,24 @@
package requester package requester
import ( import (
"context"
"fmt"
"net"
"net/http" "net/http"
"net/url" "one-api/common/utils"
"one-api/common"
"time" "time"
"golang.org/x/net/proxy"
) )
type ContextKey string
const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
const ProxySock5AddrKey ContextKey = "proxySock5Addr"
func proxyFunc(req *http.Request) (*url.URL, error) {
proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
if proxyAddr == nil {
return nil, nil
}
proxyURL, err := url.Parse(proxyAddr.(string))
if err != nil {
return nil, fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
return proxyURL, nil
}
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 设置TCP超时
dialer := &net.Dialer{
Timeout: time.Duration(common.GetOrDefault("connect_timeout", 5)) * time.Second,
KeepAlive: 30 * time.Second,
}
// 从上下文中获取代理地址
proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string)
if !ok {
return dialer.DialContext(ctx, network, addr)
}
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return nil, fmt.Errorf("error parsing proxy address: %w", err)
}
var auth *proxy.Auth = nil
password, isSetPassword := proxyURL.User.Password()
if isSetPassword {
auth = &proxy.Auth{
User: proxyURL.User.Username(),
Password: password,
}
}
proxyDialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("error creating socks5 dialer: %w", err)
}
return proxyDialer.Dial(network, addr)
}
var HTTPClient *http.Client var HTTPClient *http.Client
func InitHttpClient() { func InitHttpClient() {
trans := &http.Transport{ trans := &http.Transport{
DialContext: socks5ProxyFunc, DialContext: utils.Socks5ProxyFunc,
Proxy: proxyFunc, Proxy: utils.ProxyFunc,
} }
HTTPClient = &http.Client{ HTTPClient = &http.Client{
Transport: trans, Transport: trans,
} }
relayTimeout := common.GetOrDefault("relay_timeout", 600) relayTimeout := utils.GetOrDefault("relay_timeout", 600)
if relayTimeout != 0 { if relayTimeout != 0 {
HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second
} }

View File

@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strconv" "strconv"
"strings" "strings"
@ -52,18 +53,7 @@ type requestOptions struct {
type requestOption func(*requestOptions) type requestOption func(*requestOptions)
func (r *HTTPRequester) setProxy() context.Context { func (r *HTTPRequester) setProxy() context.Context {
if r.proxyAddr == "" { return utils.SetProxy(r.Context, r.proxyAddr)
return r.Context
}
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
if strings.HasPrefix(r.proxyAddr, "socks5://") {
return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr)
}
// 否则使用 http 代理
return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr)
} }
// 创建请求 // 创建请求

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/common/utils"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -14,7 +15,7 @@ import (
func GetWSClient(proxyAddr string) *websocket.Dialer { func GetWSClient(proxyAddr string) *websocket.Dialer {
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
HandshakeTimeout: time.Duration(common.GetOrDefault("connect_timeout", 5)) * time.Second, HandshakeTimeout: time.Duration(utils.GetOrDefault("connect_timeout", 5)) * time.Second,
} }
if proxyAddr != "" { if proxyAddr != "" {
@ -38,20 +39,16 @@ func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
case "http", "https": case "http", "https":
dialer.Proxy = http.ProxyURL(proxyURL) dialer.Proxy = http.ProxyURL(proxyURL)
case "socks5": case "socks5":
var auth *proxy.Auth = nil proxyDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
password, isSetPassword := proxyURL.User.Password()
if isSetPassword {
auth = &proxy.Auth{
User: proxyURL.User.Username(),
Password: password,
}
}
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if err != nil { if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err) return fmt.Errorf("error creating proxy dialer: %w", err)
} }
originalNetDial := dialer.NetDial
dialer.NetDial = func(network, addr string) (net.Conn, error) { dialer.NetDial = func(network, addr string) (net.Conn, error) {
return socks5Proxy.Dial(network, addr) if originalNetDial != nil {
return originalNetDial(network, addr)
}
return proxyDialer.Dial(network, addr)
} }
default: default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)

View File

@ -3,6 +3,7 @@ package stmp
import ( import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"strings" "strings"
"github.com/wneessen/go-mail" "github.com/wneessen/go-mail"
@ -67,7 +68,7 @@ func (s *StmpConfig) Send(to, subject, body string) error {
func (s *StmpConfig) getReferences() string { func (s *StmpConfig) getReferences() string {
froms := strings.Split(s.From, "@") froms := strings.Split(s.From, "@")
return fmt.Sprintf("<%s.%s@%s>", froms[0], common.GetUUID(), froms[1]) return fmt.Sprintf("<%s.%s@%s>", froms[0], utils.GetUUID(), froms[1])
} }
func (s *StmpConfig) Render(to, subject, content string) error { func (s *StmpConfig) Render(to, subject, content string) error {

View File

@ -5,7 +5,8 @@ import (
"fmt" "fmt"
"testing" "testing"
"one-api/common" "one-api/common/utils"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/storage/drives" "one-api/common/storage/drives"
@ -32,7 +33,7 @@ func TestSMMSUpload(t *testing.T) {
fmt.Println(err) fmt.Println(err)
} }
url, err := smUpload.Upload(image, common.GetUUID()+".png") url, err := smUpload.Upload(image, utils.GetUUID()+".png")
fmt.Println(url) fmt.Println(url)
fmt.Println(err) fmt.Println(err)
assert.Nil(t, err) assert.Nil(t, err)
@ -48,7 +49,7 @@ func TestImgurUpload(t *testing.T) {
fmt.Println(err) fmt.Println(err)
} }
url, err := imgurUpload.Upload(image, common.GetUUID()+".png") url, err := imgurUpload.Upload(image, utils.GetUUID()+".png")
fmt.Println(url) fmt.Println(url)
fmt.Println(err) fmt.Println(err)
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -2,6 +2,7 @@ package telegram
import ( import (
"one-api/common" "one-api/common"
"one-api/common/utils"
"strings" "strings"
"github.com/PaulSonOfLars/gotgbot/v2" "github.com/PaulSonOfLars/gotgbot/v2"
@ -15,7 +16,7 @@ func commandAffStart(b *gotgbot.Bot, ctx *ext.Context) error {
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = common.GetRandomString(4) user.AffCode = utils.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
ctx.EffectiveMessage.Reply(b, "系统错误,请稍后再试", nil) ctx.EffectiveMessage.Reply(b, "系统错误,请稍后再试", nil)
return nil return nil

View File

@ -1,8 +1,10 @@
package telegram package telegram
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
@ -243,22 +245,16 @@ func getHttpClient() (httpClient *http.Client) {
}, },
} }
case "socks5": case "socks5":
var auth *proxy.Auth = nil dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
password, isSetPassword := proxyURL.User.Password()
if isSetPassword {
auth = &proxy.Auth{
User: proxyURL.User.Username(),
Password: password,
}
}
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if err != nil { if err != nil {
common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error()) common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error())
return return
} }
httpClient = &http.Client{ httpClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Dial: dialer.Dial, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
},
}, },
} }
default: default:

View File

@ -119,7 +119,7 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model string) in
imageTokens, err := countImageTokens(url, detail) imageTokens, err := countImageTokens(url, detail)
if err != nil { if err != nil {
//Due to the excessive length of the error information, only extract and record the most critical part. //Due to the excessive length of the error information, only extract and record the most critical part.
SysError("error counting image tokens: " + err.Error()) SysError("error counting image tokens: " + err.Error())
} else { } else {
tokenNum += imageTokens tokenNum += imageTokens
} }

View File

@ -1,4 +1,4 @@
package common package utils
import ( import (
"encoding/json" "encoding/json"
@ -109,13 +109,13 @@ func Seconds2Time(num int) (time string) {
} }
func Interface2String(inter interface{}) string { func Interface2String(inter interface{}) string {
switch inter.(type) { switch inter := inter.(type) {
case string: case string:
return inter.(string) return inter
case int: case int:
return fmt.Sprintf("%d", inter.(int)) return fmt.Sprintf("%d", inter)
case float64: case float64:
return fmt.Sprintf("%f", inter.(float64)) return fmt.Sprintf("%f", inter)
} }
return "Not Implemented" return "Not Implemented"
} }
@ -140,12 +140,7 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string { func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48) key := make([]byte, 48)
for i := 0; i < 16; i++ { for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))] key[i] = keyChars[rand.Intn(len(keyChars))]
@ -162,7 +157,6 @@ func GenerateKey() string {
} }
func GetRandomString(length int) string { func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length) key := make([]byte, length)
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))] key[i] = keyChars[rand.Intn(len(keyChars))]

77
common/utils/proxy.go Normal file
View File

@ -0,0 +1,77 @@
package utils
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/net/proxy"
)
type ContextKey string
const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
const ProxySock5AddrKey ContextKey = "proxySock5Addr"
func ProxyFunc(req *http.Request) (*url.URL, error) {
proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
if proxyAddr == nil {
return nil, nil
}
proxyURL, err := url.Parse(proxyAddr.(string))
if err != nil {
return nil, fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
return proxyURL, nil
}
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
func Socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: time.Duration(GetOrDefault("connect_timeout", 5)) * time.Second,
KeepAlive: 30 * time.Second,
}
proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string)
if !ok {
return dialer.DialContext(ctx, network, addr)
}
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return nil, fmt.Errorf("error parsing proxy address: %w", err)
}
proxyDialer, err := proxy.FromURL(proxyURL, dialer)
if err != nil {
return nil, fmt.Errorf("error creating proxy dialer: %w", err)
}
return proxyDialer.Dial(network, addr)
}
func SetProxy(ctx context.Context, proxyAddr string) context.Context {
if proxyAddr == "" {
return ctx
}
key := ProxyHTTPAddrKey
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
if strings.HasPrefix(proxyAddr, "socks5://") {
key = ProxySock5AddrKey
}
// 否则使用 http 代理
return context.WithValue(ctx, key, proxyAddr)
}

View File

@ -8,6 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"one-api/common" "one-api/common"
"one-api/common/notify" "one-api/common/notify"
"one-api/common/utils"
"one-api/model" "one-api/model"
"one-api/providers" "one-api/providers"
providers_base "one-api/providers/base" providers_base "one-api/providers/base"
@ -153,7 +154,7 @@ func testAllChannels(isNotify bool) error {
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", common.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr()) sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr())
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel, "") err, openaiErr := testChannel(channel, "")
tok := time.Now() tok := time.Now()
@ -161,7 +162,7 @@ func testAllChannels(isNotify bool) error {
// 通道为禁用状态,并且还是请求错误 或者 响应时间超过阈值 直接跳过,也不需要更新响应时间。 // 通道为禁用状态,并且还是请求错误 或者 响应时间超过阈值 直接跳过,也不需要更新响应时间。
if !isChannelEnabled { if !isChannelEnabled {
if err != nil { if err != nil {
sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", common.EscapeMarkdownText(err.Error())) sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", utils.EscapeMarkdownText(err.Error()))
continue continue
} }
if milliseconds > disableThreshold { if milliseconds > disableThreshold {
@ -187,13 +188,13 @@ func testAllChannels(isNotify bool) error {
} }
if ShouldDisableChannel(openaiErr, -1) { if ShouldDisableChannel(openaiErr, -1) {
sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", common.EscapeMarkdownText(err.Error())) sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", utils.EscapeMarkdownText(err.Error()))
DisableChannel(channel.Id, channel.Name, err.Error(), false) DisableChannel(channel.Id, channel.Name, err.Error(), false)
continue continue
} }
if err != nil { if err != nil {
sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", common.EscapeMarkdownText(err.Error())) sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", utils.EscapeMarkdownText(err.Error()))
continue continue
} }
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
@ -64,7 +65,7 @@ func AddChannel(c *gin.Context) {
}) })
return return
} }
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = utils.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))
for _, key := range keys { for _, key := range keys {

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
@ -216,7 +217,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
state := common.GetRandomString(12) state := utils.GetRandomString(12)
session.Set("oauth_state", state) session.Set("oauth_state", state)
err := session.Save() err := session.Save()
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strings" "strings"
@ -19,7 +20,7 @@ func GetOptions(c *gin.Context) {
} }
options = append(options, &model.Option{ options = append(options, &model.Option{
Key: k, Key: k,
Value: common.Interface2String(v), Value: utils.Interface2String(v),
}) })
} }
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()

View File

@ -6,7 +6,7 @@ import (
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/relay_util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -14,7 +14,7 @@ import (
func GetPricesList(c *gin.Context) { func GetPricesList(c *gin.Context) {
pricesType := c.DefaultQuery("type", "db") pricesType := c.DefaultQuery("type", "db")
prices := util.GetPricesList(pricesType) prices := relay_util.GetPricesList(pricesType)
if len(prices) == 0 { if len(prices) == 0 {
common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found")) common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found"))
@ -33,7 +33,7 @@ func GetPricesList(c *gin.Context) {
} }
func GetAllModelList(c *gin.Context) { func GetAllModelList(c *gin.Context) {
prices := util.PricingInstance.GetAllPrices() prices := relay_util.PricingInstance.GetAllPrices()
channelModel := model.ChannelGroup.Rule channelModel := model.ChannelGroup.Rule
modelsMap := make(map[string]bool) modelsMap := make(map[string]bool)
@ -68,7 +68,7 @@ func AddPrice(c *gin.Context) {
return return
} }
if err := util.PricingInstance.AddPrice(&price); err != nil { if err := relay_util.PricingInstance.AddPrice(&price); err != nil {
common.APIRespondWithError(c, http.StatusOK, err) common.APIRespondWithError(c, http.StatusOK, err)
return return
} }
@ -94,7 +94,7 @@ func UpdatePrice(c *gin.Context) {
return return
} }
if err := util.PricingInstance.UpdatePrice(modelName, &price); err != nil { if err := relay_util.PricingInstance.UpdatePrice(modelName, &price); err != nil {
common.APIRespondWithError(c, http.StatusOK, err) common.APIRespondWithError(c, http.StatusOK, err)
return return
} }
@ -114,7 +114,7 @@ func DeletePrice(c *gin.Context) {
modelName = modelName[1:] modelName = modelName[1:]
modelName, _ = url.PathUnescape(modelName) modelName, _ = url.PathUnescape(modelName)
if err := util.PricingInstance.DeletePrice(modelName); err != nil { if err := relay_util.PricingInstance.DeletePrice(modelName); err != nil {
common.APIRespondWithError(c, http.StatusOK, err) common.APIRespondWithError(c, http.StatusOK, err)
return return
} }
@ -127,7 +127,7 @@ func DeletePrice(c *gin.Context) {
type PriceBatchRequest struct { type PriceBatchRequest struct {
OriginalModels []string `json:"original_models"` OriginalModels []string `json:"original_models"`
util.BatchPrices relay_util.BatchPrices
} }
func BatchSetPrices(c *gin.Context) { func BatchSetPrices(c *gin.Context) {
@ -137,7 +137,7 @@ func BatchSetPrices(c *gin.Context) {
return return
} }
if err := util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil { if err := relay_util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil {
common.APIRespondWithError(c, http.StatusOK, err) common.APIRespondWithError(c, http.StatusOK, err)
return return
} }
@ -159,7 +159,7 @@ func BatchDeletePrices(c *gin.Context) {
return return
} }
if err := util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil { if err := relay_util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil {
common.APIRespondWithError(c, http.StatusOK, err) common.APIRespondWithError(c, http.StatusOK, err)
return return
} }
@ -184,7 +184,7 @@ func SyncPricing(c *gin.Context) {
return return
} }
err := util.PricingInstance.SyncPricing(prices, overwrite == "true") err := relay_util.PricingInstance.SyncPricing(prices, overwrite == "true")
if err != nil { if err != nil {
common.APIRespondWithError(c, http.StatusOK, err) common.APIRespondWithError(c, http.StatusOK, err)
return return

View File

@ -3,6 +3,7 @@ package controller
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
@ -85,12 +86,12 @@ func AddRedemption(c *gin.Context) {
} }
var keys []string var keys []string
for i := 0; i < redemption.Count; i++ { for i := 0; i < redemption.Count; i++ {
key := common.GetUUID() key := utils.GetUUID()
cleanRedemption := model.Redemption{ cleanRedemption := model.Redemption{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: redemption.Name, Name: redemption.Name,
Key: key, Key: key,
CreatedTime: common.GetTimestamp(), CreatedTime: utils.GetTimestamp(),
Quota: redemption.Quota, Quota: redemption.Quota,
} }
err = cleanRedemption.Insert() err = cleanRedemption.Insert()

View File

@ -3,6 +3,7 @@ package controller
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
@ -62,9 +63,9 @@ func GetPlaygroundToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: userId, UserId: userId,
Name: tokenName, Name: tokenName,
Key: common.GenerateKey(), Key: utils.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: utils.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: utils.GetTimestamp(),
ExpiredTime: 0, ExpiredTime: 0,
RemainQuota: 0, RemainQuota: 0,
UnlimitedQuota: true, UnlimitedQuota: true,
@ -132,9 +133,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: common.GenerateKey(), Key: utils.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: utils.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: utils.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota, RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
@ -199,7 +200,7 @@ func UpdateToken(c *gin.Context) {
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
@ -261,7 +262,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = common.GetUUID() user.AccessToken = utils.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -297,7 +298,7 @@ func GetAffCode(c *gin.Context) {
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = common.GetRandomString(4) user.AffCode = utils.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -13,7 +13,7 @@ import (
"one-api/cron" "one-api/cron"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/relay_util"
"one-api/router" "one-api/router"
"time" "time"
@ -40,7 +40,7 @@ func main() {
common.InitRedisClient() common.InitRedisClient()
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
util.NewPricing() relay_util.NewPricing()
initMemoryCache() initMemoryCache()
initSync() initSync()
@ -112,6 +112,6 @@ func SyncChannelCache(frequency int) {
time.Sleep(time.Duration(frequency) * time.Second) time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database") common.SysLog("syncing channels from database")
model.ChannelGroup.Load() model.ChannelGroup.Load()
util.PricingInstance.Init() relay_util.PricingInstance.Init()
} }
} }

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"strings" "strings"
@ -109,10 +110,10 @@ func tokenAuth(c *gin.Context, key string) {
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
if strings.HasPrefix(parts[1], "!") { if strings.HasPrefix(parts[1], "!") {
channelId := common.String2Int(parts[1][1:]) channelId := utils.String2Int(parts[1][1:])
c.Set("skip_channel_id", channelId) c.Set("skip_channel_id", channelId)
} else { } else {
channelId := common.String2Int(parts[1]) channelId := utils.String2Int(parts[1])
if channelId == 0 { if channelId == 0 {
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
return return

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/utils"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -99,11 +100,11 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
func GlobalWebRateLimit() func(c *gin.Context) { func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GetOrDefault("global.web_rate_limit", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW") return rateLimitFactory(utils.GetOrDefault("global.web_rate_limit", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW")
} }
func GlobalAPIRateLimit() func(c *gin.Context) { func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GetOrDefault("global.api_rate_limit", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA") return rateLimitFactory(utils.GetOrDefault("global.api_rate_limit", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA")
} }
func CriticalRateLimit() func(c *gin.Context) { func CriticalRateLimit() func(c *gin.Context) {

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"one-api/common" "one-api/common"
"one-api/common/utils"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -10,7 +11,7 @@ import (
func RequestId() func(c *gin.Context) { func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8) id := utils.GetTimeString() + utils.GetRandomString(8)
c.Set(common.RequestIdKey, id) c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
ctx = context.WithValue(ctx, "requestStartTime", time.Now()) ctx = context.WithValue(ctx, "requestStartTime", time.Now())

View File

@ -1,14 +1,16 @@
package middleware package middleware
import ( import (
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/common/utils"
"github.com/gin-gonic/gin"
) )
func abortWithMessage(c *gin.Context, statusCode int, message string) { func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": utils.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "one_api_error", "type": "one_api_error",
}, },
}) })

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"math/rand" "math/rand"
"one-api/common" "one-api/common"
"one-api/common/utils"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -105,7 +106,7 @@ func (cc *ChannelsChooser) Next(group, modelName string, filters ...ChannelsFilt
channelsPriority, ok := cc.Rule[group][modelName] channelsPriority, ok := cc.Rule[group][modelName]
if !ok { if !ok {
matchModel := common.GetModelsWithMatch(&cc.Match, modelName) matchModel := utils.GetModelsWithMatch(&cc.Match, modelName)
channelsPriority, ok = cc.Rule[group][matchModel] channelsPriority, ok = cc.Rule[group][matchModel]
if !ok { if !ok {
return nil, errors.New("model not found") return nil, errors.New("model not found")
@ -199,7 +200,7 @@ func (cc *ChannelsChooser) Load() {
// 逗号分割 ability.ChannelId // 逗号分割 ability.ChannelId
channelIds := strings.Split(ability.ChannelIds, ",") channelIds := strings.Split(ability.ChannelIds, ",")
for _, channelId := range channelIds { for _, channelId := range channelIds {
priorityIds = append(priorityIds, common.String2Int(channelId)) priorityIds = append(priorityIds, utils.String2Int(channelId))
} }
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds) newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)

View File

@ -2,6 +2,7 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/common/utils"
"strings" "strings"
"gorm.io/datatypes" "gorm.io/datatypes"
@ -235,7 +236,7 @@ func (channel *Channel) UpdateRaw(overwrite bool) error {
func (channel *Channel) UpdateResponseTime(responseTime int64) { func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: common.GetTimestamp(), TestTime: utils.GetTimestamp(),
ResponseTime: int(responseTime), ResponseTime: int(responseTime),
}).Error }).Error
if err != nil { if err != nil {
@ -245,7 +246,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
func (channel *Channel) UpdateBalance(balance float64) { func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(), BalanceUpdatedTime: utils.GetTimestamp(),
Balance: balance, Balance: balance,
}).Error }).Error
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -41,7 +42,7 @@ func RecordLog(userId int, logType int, content string) {
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(), CreatedAt: utils.GetTimestamp(),
Type: logType, Type: logType,
Content: content, Content: content,
} }
@ -59,7 +60,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(), CreatedAt: utils.GetTimestamp(),
Type: LogTypeConsume, Type: LogTypeConsume,
Content: content, Content: content,
PromptTokens: promptTokens, PromptTokens: promptTokens,

View File

@ -3,6 +3,7 @@ package model
import ( import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -26,7 +27,7 @@ func SetupDB() {
if viper.GetBool("batch_update_enabled") { if viper.GetBool("batch_update_enabled") {
common.BatchUpdateEnabled = true common.BatchUpdateEnabled = true
common.BatchUpdateInterval = common.GetOrDefault("batch_update_interval", 5) common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
InitBatchUpdater() InitBatchUpdater()
} }
@ -47,7 +48,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: common.GetUUID(), AccessToken: utils.GetUUID(),
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@ -78,7 +79,7 @@ func chooseDB() (*gorm.DB, error) {
// Use SQLite // Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database") common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.GetOrDefault("sqlite_busy_timeout", 3000)) config := fmt.Sprintf("?_busy_timeout=%d", utils.GetOrDefault("sqlite_busy_timeout", 3000))
return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{ return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
@ -96,9 +97,9 @@ func InitDB() (err error) {
return err return err
} }
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxIdleConns(utils.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode { if !common.IsMasterNode {
return nil return nil

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -33,7 +34,7 @@ func GetRedemptionsList(params *GenericParams) (*DataResult[Redemption], error)
var redemptions []*Redemption var redemptions []*Redemption
db := DB db := DB
if params.Keyword != "" { if params.Keyword != "" {
db = db.Where("id = ? or name LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%") db = db.Where("id = ? or name LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%")
} }
return PaginateAndOrder[Redemption](db, &params.PaginationParams, &redemptions, allowedRedemptionslOrderFields) return PaginateAndOrder[Redemption](db, &params.PaginationParams, &redemptions, allowedRedemptionslOrderFields)
@ -75,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return err return err
} }
redemption.RedeemedTime = common.GetTimestamp() redemption.RedeemedTime = utils.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error err = tx.Save(redemption).Error
return err return err

View File

@ -2,7 +2,7 @@ package model
import ( import (
"errors" "errors"
"one-api/common" "one-api/common/utils"
) )
type TelegramMenu struct { type TelegramMenu struct {
@ -22,7 +22,7 @@ func GetTelegramMenusList(params *GenericParams) (*DataResult[TelegramMenu], err
var menus []*TelegramMenu var menus []*TelegramMenu
db := DB db := DB
if params.Keyword != "" { if params.Keyword != "" {
db = db.Where("id = ? or command LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%") db = db.Where("id = ? or command LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%")
} }
return PaginateAndOrder[TelegramMenu](db, &params.PaginationParams, &menus, allowedTelegramMenusOrderFields) return PaginateAndOrder[TelegramMenu](db, &params.PaginationParams, &menus, allowedTelegramMenusOrderFields)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/stmp" "one-api/common/stmp"
"one-api/common/utils"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -71,7 +72,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = common.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
@ -188,7 +189,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota), "remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(), "accessed_time": utils.GetTimestamp(),
}, },
).Error ).Error
return err return err
@ -210,7 +211,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota), "remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(), "accessed_time": utils.GetTimestamp(),
}, },
).Error ).Error
return err return err

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
@ -55,7 +56,7 @@ func GetUsersList(params *GenericParams) (*DataResult[User], error) {
var users []*User var users []*User
db := DB.Omit("password") db := DB.Omit("password")
if params.Keyword != "" { if params.Keyword != "" {
db = db.Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%", params.Keyword+"%", params.Keyword+"%") db = db.Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%", params.Keyword+"%", params.Keyword+"%")
} }
return PaginateAndOrder[User](db, &params.PaginationParams, &users, allowedUserOrderFields) return PaginateAndOrder[User](db, &params.PaginationParams, &users, allowedUserOrderFields)
@ -115,9 +116,9 @@ func (user *User) Insert(inviterId int) error {
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID() user.AccessToken = utils.GetUUID()
user.AffCode = common.GetRandomString(4) user.AffCode = utils.GetRandomString(4)
user.CreatedTime = common.GetTimestamp() user.CreatedTime = utils.GetTimestamp()
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
@ -165,7 +166,7 @@ func (user *User) Delete() error {
} }
// 不改变当前数据库索引,通过更改用户名来删除用户 // 不改变当前数据库索引,通过更改用户名来删除用户
user.Username = user.Username + "_del_" + common.GetRandomString(6) user.Username = user.Username + "_del_" + utils.GetRandomString(6)
err := user.Update(false) err := user.Update(false)
if err != nil { if err != nil {
return err return err

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -92,7 +93,7 @@ func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *ty
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: response.RequestId, ID: response.RequestId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: request.Model, Model: request.Model,
Choices: response.Output.ToChatCompletionChoices(), Choices: response.Output.ToChatCompletionChoices(),
Usage: &types.Usage{ Usage: &types.Usage{
@ -223,7 +224,7 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, d
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: aliResponse.RequestId, ID: aliResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }

View File

@ -8,6 +8,7 @@ import (
"one-api/common" "one-api/common"
"one-api/common/image" "one-api/common/image"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/providers/base" "one-api/providers/base"
"one-api/types" "one-api/types"
"strings" "strings"
@ -172,7 +173,7 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeRespon
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: response.Id, ID: response.Id,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice}, Choices: []types.ChatCompletionChoice{choice},
Model: request.Model, Model: request.Model,
Usage: &types.Usage{ Usage: &types.Usage{
@ -264,9 +265,9 @@ func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStream
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
chatCompletion := types.ChatCompletionStreamResponse{ chatCompletion := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -85,9 +86,9 @@ func (p *CloudflareAIProvider) convertToChatOpenai(response *ChatRespone, reques
} }
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: request.Model, Model: request.Model,
Choices: []types.ChatCompletionChoice{{ Choices: []types.ChatCompletionChoice{{
Index: 0, Index: 0,
@ -155,9 +156,9 @@ func (h *CloudflareAIStreamHandler) handlerStream(rawLine *[]byte, dataChan chan
func (h *CloudflareAIStreamHandler) convertToOpenaiStream(chatResponse *ChatResult, dataChan chan string, isStop bool) { func (h *CloudflareAIStreamHandler) convertToOpenaiStream(chatResponse *ChatResult, dataChan chan string, isStop bool) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
} }

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/storage" "one-api/common/storage"
"one-api/common/utils"
"one-api/types" "one-api/types"
"time" "time"
) )
@ -46,7 +47,7 @@ func (p *CloudflareAIProvider) CreateImageGenerations(request *types.ImageReques
url := "" url := ""
if request.ResponseFormat == "" || request.ResponseFormat == "url" { if request.ResponseFormat == "" || request.ResponseFormat == "url" {
url = storage.Upload(body, common.GetUUID()+".png") url = storage.Upload(body, utils.GetUUID()+".png")
} }
openaiResponse := &types.ImageResponse{ openaiResponse := &types.ImageResponse{

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/providers/base" "one-api/providers/base"
"one-api/types" "one-api/types"
"strings" "strings"
@ -138,7 +139,7 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereRespon
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: response.GenerationID, ID: response.GenerationID,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice}, Choices: []types.ChatCompletionChoice{choice},
Model: request.Model, Model: request.Model,
Usage: &types.Usage{}, Usage: &types.Usage{},
@ -190,9 +191,9 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream
} }
chatCompletion := types.ChatCompletionStreamResponse{ chatCompletion := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -89,9 +90,9 @@ func (p *CozeProvider) convertToChatOpenai(response *CozeResponse, request *type
} }
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: request.Model, Model: request.Model,
Choices: []types.ChatCompletionChoice{{ Choices: []types.ChatCompletionChoice{{
Index: 0, Index: 0,
@ -168,9 +169,9 @@ func (h *CozeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string,
func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) { func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
} }

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -145,9 +146,9 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
} }
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: request.Model, Model: request.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
} }
@ -191,9 +192,9 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) { func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
// Choices: choices, // Choices: choices,
} }

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/image" "one-api/common/image"
"one-api/common/utils"
"one-api/types" "one-api/types"
) )
@ -120,7 +121,7 @@ func (g *GeminiFunctionCall) ToOpenAITool() *types.ChatCompletionToolCalls {
args, _ := json.Marshal(g.Args) args, _ := json.Marshal(g.Args)
return &types.ChatCompletionToolCalls{ return &types.ChatCompletionToolCalls{
Id: "call_" + common.GetRandomString(24), Id: "call_" + utils.GetRandomString(24),
Type: types.ChatMessageRoleFunction, Type: types.ChatMessageRoleFunction,
Index: 0, Index: 0,
Function: &types.ChatCompletionToolCallsFunction{ Function: &types.ChatCompletionToolCallsFunction{

View File

@ -7,6 +7,7 @@ import (
"one-api/common" "one-api/common"
"one-api/common/image" "one-api/common/image"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -100,9 +101,9 @@ func (p *OllamaProvider) convertToChatOpenai(response *ChatResponse, request *ty
} }
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: request.Model, Model: request.Model,
Choices: []types.ChatCompletionChoice{choices}, Choices: []types.ChatCompletionChoice{choices},
Usage: &types.Usage{ Usage: &types.Usage{
@ -195,9 +196,9 @@ func (h *ollamaStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
} }
chatCompletion := types.ChatCompletionStreamResponse{ chatCompletion := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -177,11 +178,11 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp
choice.FinishReason = types.FinishReasonStop choice.FinishReason = types.FinishReasonStop
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Model: h.Request.Model, Model: h.Request.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
} }
responseBody, _ := json.Marshal(streamResponse) responseBody, _ := json.Marshal(streamResponse)

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/storage" "one-api/common/storage"
"one-api/common/utils"
"one-api/types" "one-api/types"
"time" "time"
) )
@ -71,7 +72,7 @@ func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest
if request.ResponseFormat == "" || request.ResponseFormat == "url" { if request.ResponseFormat == "" || request.ResponseFormat == "url" {
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image) body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
if err == nil { if err == nil {
imgUrl = storage.Upload(body, common.GetUUID()+".png") imgUrl = storage.Upload(body, utils.GetUUID()+".png")
} }
} }

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
) )
@ -101,7 +102,7 @@ func (p *TencentProvider) convertToChatOpenai(response *TencentChatResponse, req
openaiResponse = &types.ChatCompletionResponse{ openaiResponse = &types.ChatCompletionResponse{
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Usage: response.Usage, Usage: response.Usage,
Model: request.Model, Model: request.Model,
} }
@ -137,9 +138,9 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq
stream = 1 stream = 1
} }
return &TencentChatRequest{ return &TencentChatRequest{
Timestamp: common.GetTimestamp(), Timestamp: utils.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60, Expired: utils.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(), QueryID: utils.GetUUID(),
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
Stream: stream, Stream: stream,
@ -178,7 +179,7 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string) { func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
} }
if len(tencentChatResponse.Choices) > 0 { if len(tencentChatResponse.Choices) > 0 {

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/types" "one-api/types"
"strings" "strings"
@ -194,7 +195,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa
ID: xunfeiResponse.Header.Sid, ID: xunfeiResponse.Header.Sid,
Object: "chat.completion", Object: "chat.completion",
Model: h.Request.Model, Model: h.Request.Model,
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice}, Choices: []types.ChatCompletionChoice{choice},
Usage: &xunfeiResponse.Payload.Usage.Text, Usage: &xunfeiResponse.Payload.Usage.Text,
} }
@ -310,7 +311,7 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp
chatCompletion := types.ChatCompletionStreamResponse{ chatCompletion := types.ChatCompletionStreamResponse{
ID: xunfeiChatResponse.Header.Sid, ID: xunfeiChatResponse.Header.Sid,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
} }

View File

@ -1,7 +1,7 @@
package relay package relay
import ( import (
"one-api/relay/util" "one-api/relay/relay_util"
"one-api/types" "one-api/types"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
@ -14,7 +14,7 @@ type relayBase struct {
provider providersBase.ProviderInterface provider providersBase.ProviderInterface
originalModel string originalModel string
modelName string modelName string
cache *util.ChatCacheProps cache *relay_util.ChatCacheProps
} }
type RelayBaseInterface interface { type RelayBaseInterface interface {
@ -28,14 +28,14 @@ type RelayBaseInterface interface {
getModelName() string getModelName() string
getContext() *gin.Context getContext() *gin.Context
SetChatCache(allow bool) SetChatCache(allow bool)
GetChatCache() *util.ChatCacheProps GetChatCache() *relay_util.ChatCacheProps
} }
func (r *relayBase) SetChatCache(allow bool) { func (r *relayBase) SetChatCache(allow bool) {
r.cache = util.NewChatCacheProps(r.c, allow) r.cache = relay_util.NewChatCacheProps(r.c, allow)
} }
func (r *relayBase) GetChatCache() *util.ChatCacheProps { func (r *relayBase) GetChatCache() *relay_util.ChatCacheProps {
return r.cache return r.cache
} }

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
"one-api/types" "one-api/types"
@ -100,9 +101,9 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
func (r *relayChat) getUsageResponse() string { func (r *relayChat) getUsageResponse() string {
if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage { if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
usageResponse := types.ChatCompletionStreamResponse{ usageResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: r.chatRequest.Model, Model: r.chatRequest.Model,
Choices: []types.ChatCompletionStreamChoice{}, Choices: []types.ChatCompletionStreamChoice{},
Usage: r.provider.GetUsage(), Usage: r.provider.GetUsage(),

View File

@ -9,11 +9,12 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
"one-api/controller" "one-api/controller"
"one-api/model" "one-api/model"
"one-api/providers" "one-api/providers"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
"one-api/relay/util" "one-api/relay/relay_util"
"one-api/types" "one-api/types"
"strings" "strings"
@ -142,7 +143,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
type StreamEndHandler func() string type StreamEndHandler func() string
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) { func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *relay_util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
requester.SetEventStreamHeaders(c) requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv() dataChan, errChan := stream.Recv()
@ -257,7 +258,7 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) err.OpenAIError.Message = utils.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{ c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError, "error": err.OpenAIError,
}) })

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/common/utils"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
"one-api/types" "one-api/types"
@ -93,9 +94,9 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
func (r *relayCompletions) getUsageResponse() string { func (r *relayCompletions) getUsageResponse() string {
if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage { if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
usageResponse := types.CompletionResponse{ usageResponse := types.CompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: utils.GetTimestamp(),
Model: r.request.Model, Model: r.request.Model,
Choices: []types.CompletionChoice{}, Choices: []types.CompletionChoice{},
Usage: r.provider.GetUsage(), Usage: r.provider.GetUsage(),

View File

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/relay_util"
"one-api/types" "one-api/types"
"time" "time"
@ -96,8 +96,8 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
relay.getProvider().SetUsage(usage) relay.getProvider().SetUsage(usage)
var quota *util.Quota var quota *relay_util.Quota
quota, err = util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens) quota, err = relay_util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens)
if err != nil { if err != nil {
done = true done = true
return return
@ -119,7 +119,7 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
return return
} }
func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) { func cacheProcessing(c *gin.Context, cacheProps *relay_util.ChatCacheProps) {
responseCache(c, cacheProps.Response) responseCache(c, cacheProps.Response)
// 写入日志 // 写入日志

View File

@ -14,7 +14,7 @@ import (
"one-api/model" "one-api/model"
provider "one-api/providers/midjourney" provider "one-api/providers/midjourney"
"one-api/relay" "one-api/relay"
"one-api/relay/util" "one-api/relay/relay_util"
"one-api/types" "one-api/types"
"strconv" "strconv"
"strings" "strings"
@ -539,10 +539,10 @@ func getMjRequestPath(path string) string {
return requestURL return requestURL
} }
func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) { func getQuota(c *gin.Context, action string) (*relay_util.Quota, *types.OpenAIErrorWithStatusCode) {
modelName := CoverActionToModelName(action) modelName := CoverActionToModelName(action)
return util.NewQuota(c, modelName, 1000) return relay_util.NewQuota(c, modelName, 1000)
} }
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {

View File

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/relay_util"
"one-api/types" "one-api/types"
"sort" "sort"
@ -90,7 +90,7 @@ func ListModels(c *gin.Context) {
} }
func ListModelsForAdmin(c *gin.Context) { func ListModelsForAdmin(c *gin.Context) {
prices := util.PricingInstance.GetAllPrices() prices := relay_util.PricingInstance.GetAllPrices()
var openAIModels []OpenAIModels var openAIModels []OpenAIModels
for modelId, price := range prices { for modelId, price := range prices {
openAIModels = append(openAIModels, OpenAIModels{ openAIModels = append(openAIModels, OpenAIModels{
@ -123,7 +123,7 @@ func ListModelsForAdmin(c *gin.Context) {
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context) {
modelName := c.Param("model") modelName := c.Param("model")
openaiModel := getOpenAIModelWithName(modelName) openaiModel := getOpenAIModelWithName(modelName)
if *openaiModel.OwnedBy != util.UnknownOwnedBy { if *openaiModel.OwnedBy != relay_util.UnknownOwnedBy {
c.JSON(200, openaiModel) c.JSON(200, openaiModel)
} else { } else {
openAIError := types.OpenAIError{ openAIError := types.OpenAIError{
@ -139,15 +139,15 @@ func RetrieveModel(c *gin.Context) {
} }
func getModelOwnedBy(channelType int) (ownedBy *string) { func getModelOwnedBy(channelType int) (ownedBy *string) {
if ownedByName, ok := util.ModelOwnedBy[channelType]; ok { if ownedByName, ok := relay_util.ModelOwnedBy[channelType]; ok {
return &ownedByName return &ownedByName
} }
return &util.UnknownOwnedBy return &relay_util.UnknownOwnedBy
} }
func getOpenAIModelWithName(modelName string) *OpenAIModels { func getOpenAIModelWithName(modelName string) *OpenAIModels {
price := util.PricingInstance.GetPrice(modelName) price := relay_util.PricingInstance.GetPrice(modelName)
return &OpenAIModels{ return &OpenAIModels{
Id: modelName, Id: modelName,
@ -164,6 +164,6 @@ func GetModelOwnedBy(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": util.ModelOwnedBy, "data": relay_util.ModelOwnedBy,
}) })
} }

View File

@ -1,10 +1,11 @@
package util package relay_util
import ( import (
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -37,7 +38,7 @@ func GetDebugList(userId int) ([]*ChatCacheProps, error) {
var props []*ChatCacheProps var props []*ChatCacheProps
for _, cache := range caches { for _, cache := range caches {
prop, err := common.UnmarshalString[ChatCacheProps](cache.Data) prop, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil { if err != nil {
continue continue
} }
@ -77,7 +78,7 @@ func (p *ChatCacheProps) SetHash(request any) {
return return
} }
p.hash(common.Marshal(request)) p.hash(utils.Marshal(request))
} }
func (p *ChatCacheProps) SetResponse(response any) { func (p *ChatCacheProps) SetResponse(response any) {
@ -90,7 +91,7 @@ func (p *ChatCacheProps) SetResponse(response any) {
return return
} }
responseStr := common.Marshal(response) responseStr := utils.Marshal(response)
if responseStr == "" { if responseStr == "" {
return return
} }

View File

@ -1,8 +1,8 @@
package util package relay_util
import ( import (
"errors" "errors"
"one-api/common" "one-api/common/utils"
"one-api/model" "one-api/model"
"time" "time"
) )
@ -15,7 +15,7 @@ func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
return nil return nil
} }
props, err := common.UnmarshalString[ChatCacheProps](cache.Data) props, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil { if err != nil {
return nil return nil
} }
@ -28,7 +28,7 @@ func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) err
} }
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error { func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
data := common.Marshal(props) data := utils.Marshal(props)
if data == "" { if data == "" {
return errors.New("marshal error") return errors.New("marshal error")
} }

View File

@ -1,9 +1,10 @@
package util package relay_util
import ( import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/utils"
"time" "time"
) )
@ -17,7 +18,7 @@ func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
return nil return nil
} }
props, err := common.UnmarshalString[ChatCacheProps](cache) props, err := utils.UnmarshalString[ChatCacheProps](cache)
if err != nil { if err != nil {
return nil return nil
} }
@ -31,7 +32,7 @@ func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) e
return nil return nil
} }
data := common.Marshal(&props) data := utils.Marshal(&props)
if data == "" { if data == "" {
return errors.New("marshal error") return errors.New("marshal error")
} }

View File

@ -1,9 +1,10 @@
package util package relay_util
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"one-api/common" "one-api/common"
"one-api/common/utils"
"one-api/model" "one-api/model"
"sort" "sort"
"strings" "strings"
@ -98,7 +99,7 @@ func (p *Pricing) GetPrice(modelName string) *model.Price {
return price return price
} }
matchModel := common.GetModelsWithMatch(&p.Match, modelName) matchModel := utils.GetModelsWithMatch(&p.Match, modelName)
if price, ok := p.Prices[matchModel]; ok { if price, ok := p.Prices[matchModel]; ok {
return price return price
} }
@ -281,7 +282,7 @@ func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []stri
var updatePrices []string var updatePrices []string
for _, model := range originalModels { for _, model := range originalModels {
if !common.Contains(model, batchPrices.Models) { if !utils.Contains(model, batchPrices.Models) {
deletePrices = append(deletePrices, model) deletePrices = append(deletePrices, model)
} else { } else {
updatePrices = append(updatePrices, model) updatePrices = append(updatePrices, model)
@ -289,7 +290,7 @@ func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []stri
} }
for _, model := range batchPrices.Models { for _, model := range batchPrices.Models {
if !common.Contains(model, originalModels) { if !utils.Contains(model, originalModels) {
addPrice := batchPrices.Price addPrice := batchPrices.Price
addPrice.Model = model addPrice.Model = model
addPrices = append(addPrices, &addPrice) addPrices = append(addPrices, &addPrice)

View File

@ -1,4 +1,4 @@
package util package relay_util
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package util package relay_util
import "one-api/common" import "one-api/common"