diff --git a/common/email.go b/common/email.go deleted file mode 100644 index b915f0f9..00000000 --- a/common/email.go +++ /dev/null @@ -1,86 +0,0 @@ -package common - -import ( - "crypto/rand" - "crypto/tls" - "encoding/base64" - "fmt" - "net/smtp" - "strings" - "time" -) - -func SendEmail(subject string, receiver string, content string) error { - if SMTPFrom == "" { // for compatibility - SMTPFrom = SMTPAccount - } - encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) - - // Extract domain from SMTPFrom - parts := strings.Split(SMTPFrom, "@") - var domain string - if len(parts) > 1 { - domain = parts[1] - } - // Generate a unique Message-ID - buf := make([]byte, 16) - _, err := rand.Read(buf) - if err != nil { - return err - } - messageId := fmt.Sprintf("<%x@%s>", buf, domain) - - mail := []byte(fmt.Sprintf("To: %s\r\n"+ - "From: %s<%s>\r\n"+ - "Subject: %s\r\n"+ - "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 - "Date: %s\r\n"+ - "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) - auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) - addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) - to := strings.Split(receiver, ";") - - if SMTPPort == 465 { - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - ServerName: SMTPServer, - } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) - if err != nil { - return err - } - client, err := smtp.NewClient(conn, SMTPServer) - if err != nil { - return err - } - defer client.Close() - if err = client.Auth(auth); err != nil { - return err - } - if err = client.Mail(SMTPFrom); err != nil { - return err - } - receiverEmails := strings.Split(receiver, ";") - for _, receiver := range receiverEmails { - if err = client.Rcpt(receiver); err != nil { - return err - } - } - w, err := client.Data() - if err != nil { - return err - } - _, err = w.Write(mail) - if err != nil { - return err - } - err = w.Close() - if err != nil { - return err - } - } else { - err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) - } - return err -} diff --git a/common/notify/channel/channel_test.go b/common/notify/channel/channel_test.go new file mode 100644 index 00000000..79a9bdc6 --- /dev/null +++ b/common/notify/channel/channel_test.go @@ -0,0 +1,128 @@ +package channel_test + +import ( + "context" + "fmt" + "testing" + + "one-api/common/notify/channel" + "one-api/common/requester" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func InitConfig() { + viper.AddConfigPath("/one-api") + viper.SetConfigName("config") + viper.ReadInConfig() + requester.InitHttpClient() +} + +func TestDingTalkSend(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.dingtalk.token") + secret := viper.GetString("notify.dingtalk.secret") + dingTalk := channel.NewDingTalk(access_token, secret) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestDingTalkSendWithKeyWord(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.dingtalk.token") + keyWord := viper.GetString("notify.dingtalk.keyWord") + + dingTalk := channel.NewDingTalkWithKeyWord(access_token, keyWord) + + err := dingTalk.Send(context.Background(), "Test Title", "Test Message") + assert.Nil(t, err) +} + +func TestDingTalkSendError(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.dingtalk.token") + secret := "test" + dingTalk := channel.NewDingTalk(access_token, secret) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Error(t, err) +} + +func TestLarkSend(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.lark.token") + secret := viper.GetString("notify.lark.secret") + dingTalk := channel.NewLark(access_token, secret) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestLarkSendWithKeyWord(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.lark.token") + keyWord := viper.GetString("notify.lark.keyWord") + + dingTalk := channel.NewLarkWithKeyWord(access_token, keyWord) + + err := dingTalk.Send(context.Background(), "Test Title", "Test Message\n\n- 111\n- 222") + assert.Nil(t, err) +} + +func TestLarkSendError(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.lark.token") + secret := "test" + dingTalk := channel.NewLark(access_token, secret) + + err := dingTalk.Send(context.Background(), "Title", "*Message*") + fmt.Println(err) + assert.Error(t, err) +} + +func TestPushdeerSend(t *testing.T) { + InitConfig() + pushkey := viper.GetString("notify.pushdeer.pushkey") + dingTalk := channel.NewPushdeer(pushkey, "") + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestPushdeerSendError(t *testing.T) { + InitConfig() + pushkey := "test" + dingTalk := channel.NewPushdeer(pushkey, "") + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Error(t, err) +} + +func TestTelegramSend(t *testing.T) { + InitConfig() + secret := viper.GetString("notify.telegram.bot_api_key") + chatID := viper.GetString("notify.telegram.chat_id") + dingTalk := channel.NewTelegram(secret, chatID) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestTelegramSendError(t *testing.T) { + InitConfig() + secret := "test" + chatID := viper.GetString("notify.telegram.chat_id") + dingTalk := channel.NewTelegram(secret, chatID) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Error(t, err) +} diff --git a/common/notify/channel/dingTalk.go b/common/notify/channel/dingTalk.go new file mode 100644 index 00000000..d20beddf --- /dev/null +++ b/common/notify/channel/dingTalk.go @@ -0,0 +1,127 @@ +package channel + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "one-api/common/requester" + "one-api/types" + "time" +) + +const dingTalkURL = "https://oapi.dingtalk.com/robot/send?" + +type DingTalk struct { + token string + secret string + keyWord string +} + +type dingTalkMessage struct { + MsgType string `json:"msgtype"` + Markdown struct { + Title string `json:"title"` + Text string `json:"text"` + } `json:"markdown"` +} + +type dingTalkResponse struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +func NewDingTalk(token string, secret string) *DingTalk { + return &DingTalk{ + token: token, + secret: secret, + } +} + +func NewDingTalkWithKeyWord(token string, keyWord string) *DingTalk { + return &DingTalk{ + token: token, + keyWord: keyWord, + } +} + +func (d *DingTalk) Name() string { + return "DingTalk" +} + +func (d *DingTalk) Send(ctx context.Context, title, message string) error { + msg := dingTalkMessage{ + MsgType: "markdown", + } + msg.Markdown.Title = title + msg.Markdown.Text = message + + if d.keyWord != "" { + msg.Markdown.Text = fmt.Sprintf("%s\n%s", d.keyWord, msg.Markdown.Text) + } + + query := url.Values{} + query.Set("access_token", d.token) + if d.secret != "" { + t := time.Now().UnixMilli() + query.Set("timestamp", fmt.Sprintf("%d", t)) + query.Set("sign", d.sign(t)) + } + uri := dingTalkURL + query.Encode() + + client := requester.NewHTTPRequester("", dingtalkErrFunc) + client.Context = ctx + client.IsOpenAI = false + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + resp, errWithOP := client.SendRequestRaw(req) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + defer resp.Body.Close() + + dingtalkErr := dingtalkErrFunc(resp) + if dingtalkErr != nil { + return fmt.Errorf("%s", dingtalkErr.Message) + } + + return nil +} + +func (d *DingTalk) sign(timestamp int64) string { + stringToHash := fmt.Sprintf("%d\n%s", timestamp, d.secret) + hmac256 := hmac.New(sha256.New, []byte(d.secret)) + hmac256.Write([]byte(stringToHash)) + data := hmac256.Sum(nil) + signature := base64.StdEncoding.EncodeToString(data) + + return url.QueryEscape(signature) +} + +func dingtalkErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &dingTalkResponse{} + + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + fmt.Println(err) + return nil + } + + if respMsg.ErrCode == 0 { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.ErrMsg), + Type: "dingtalk_error", + Code: fmt.Sprintf("%d", respMsg.ErrCode), + } +} diff --git a/common/notify/channel/email.go b/common/notify/channel/email.go new file mode 100644 index 00000000..6fde98a5 --- /dev/null +++ b/common/notify/channel/email.go @@ -0,0 +1,50 @@ +package channel + +import ( + "context" + "errors" + "one-api/common" + "one-api/common/stmp" + + "github.com/gomarkdown/markdown" + "github.com/gomarkdown/markdown/html" + "github.com/gomarkdown/markdown/parser" +) + +type Email struct { + To string +} + +func NewEmail(to string) *Email { + return &Email{ + To: to, + } +} + +func (e *Email) Name() string { + return "Email" +} + +func (e *Email) Send(ctx context.Context, title, message string) error { + to := e.To + if to == "" { + to = common.RootUserEmail + } + + if common.SMTPServer == "" || common.SMTPAccount == "" || common.SMTPToken == "" || to == "" { + return errors.New("smtp config is not set, skip send email notifier") + } + + p := parser.NewWithExtensions(parser.CommonExtensions | parser.DefinitionLists | parser.OrderedListStart) + doc := p.Parse([]byte(message)) + + htmlFlags := html.CommonFlags | html.HrefTargetBlank + opts := html.RendererOptions{Flags: htmlFlags} + renderer := html.NewRenderer(opts) + + body := markdown.Render(doc, renderer) + + emailClient := stmp.NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom) + + return emailClient.Send(to, title, string(body)) +} diff --git a/common/notify/channel/lark.go b/common/notify/channel/lark.go new file mode 100644 index 00000000..cddcfc1b --- /dev/null +++ b/common/notify/channel/lark.go @@ -0,0 +1,149 @@ +package channel + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/types" + "strconv" + "time" +) + +const larkURL = "https://open.feishu.cn/open-apis/bot/v2/hook/" + +type Lark struct { + token string + secret string + keyWord string +} + +type larkMessage struct { + MessageType string `json:"msg_type"` + Timestamp string `json:"timestamp,omitempty"` + Sign string `json:"sign,omitempty"` + Card larkCardContent `json:"card"` +} + +type larkCardContent struct { + Config struct { + WideScreenMode bool `json:"wide_screen_mode"` + EnableForward bool `json:"enable_forward"` + } + Elements []larkMessageRequestCardElement `json:"elements"` +} + +type larkMessageRequestCardElementText struct { + Content string `json:"content"` + Tag string `json:"tag"` +} + +type larkMessageRequestCardElement struct { + Tag string `json:"tag"` + Text larkMessageRequestCardElementText `json:"text"` +} + +type larkResponse struct { + Code int `json:"code"` + Message string `json:"msg"` +} + +func NewLark(token, secret string) *Lark { + return &Lark{ + token: token, + secret: secret, + } +} + +func NewLarkWithKeyWord(token, keyWord string) *Lark { + return &Lark{ + token: token, + keyWord: keyWord, + } +} + +func (l *Lark) Name() string { + return "Lark" +} + +func (l *Lark) Send(ctx context.Context, title, message string) error { + msg := larkMessage{ + MessageType: "interactive", + } + + if l.keyWord != "" { + title = fmt.Sprintf("%s(%s)", title, l.keyWord) + } + + msg.Card.Config.WideScreenMode = true + msg.Card.Config.EnableForward = true + msg.Card.Elements = append(msg.Card.Elements, larkMessageRequestCardElement{ + Tag: "div", + Text: larkMessageRequestCardElementText{ + Content: fmt.Sprintf("**%s**\n%s", title, message), + Tag: "lark_md", + }, + }) + + if l.secret != "" { + t := time.Now().Unix() + msg.Timestamp = strconv.FormatInt(t, 10) + msg.Sign = l.sign(t) + } + + uri := larkURL + l.token + client := requester.NewHTTPRequester("", larkErrFunc) + client.Context = ctx + client.IsOpenAI = false + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + resp, errWithOP := client.SendRequestRaw(req) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + defer resp.Body.Close() + + larkErr := larkErrFunc(resp) + if larkErr != nil { + return fmt.Errorf("%s", larkErr.Message) + } + + return nil + +} + +func (l *Lark) sign(timestamp int64) string { + //timestamp + key 做sha256, 再进行base64 encode + stringToSign := fmt.Sprintf("%v", timestamp) + "\n" + l.secret + var data []byte + h := hmac.New(sha256.New, []byte(stringToSign)) + h.Write(data) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func larkErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &larkResponse{} + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + return nil + } + + if respMsg.Code == 0 { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.Message), + Type: "lark_error", + Code: respMsg.Code, + } +} diff --git a/common/notify/channel/pushdeer.go b/common/notify/channel/pushdeer.go new file mode 100644 index 00000000..20a5f1b2 --- /dev/null +++ b/common/notify/channel/pushdeer.go @@ -0,0 +1,96 @@ +package channel + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/types" + "strings" +) + +const pushdeerURL = "https://api2.pushdeer.com" + +type Pushdeer struct { + url string + pushkey string +} + +type pushdeerMessage struct { + Text string `json:"text"` + Desp string `json:"desp"` + Type string `json:"type"` +} + +type pushdeerResponse struct { + Code int `json:"code,omitempty"` + Error string `json:"error,omitempty"` + Message string `json:"message,omitempty"` +} + +func NewPushdeer(pushkey, url string) *Pushdeer { + return &Pushdeer{ + url: url, + pushkey: pushkey, + } +} + +func (p *Pushdeer) Name() string { + return "Pushdeer" +} + +func (p *Pushdeer) Send(ctx context.Context, title, message string) error { + msg := pushdeerMessage{ + Text: title, + Desp: message, + Type: "markdown", + } + + url := p.url + if url == "" { + url = pushdeerURL + } + + // 去除最后一个/ + url = strings.TrimSuffix(url, "/") + uri := fmt.Sprintf("%s/message/push?pushkey=%s", url, p.pushkey) + + client := requester.NewHTTPRequester("", pushdeerErrFunc) + client.Context = ctx + client.IsOpenAI = false + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + respMsg := &pushdeerResponse{} + _, errWithOP := client.SendRequest(req, respMsg, false) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + + if respMsg.Code != 0 { + return fmt.Errorf("send msg err. err msg: %s", respMsg.Error) + } + + return nil +} + +func pushdeerErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &pushdeerResponse{} + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + return nil + } + + if respMsg.Message == "" { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.Message), + Type: "pushdeer_error", + } +} diff --git a/common/notify/channel/telegram.go b/common/notify/channel/telegram.go new file mode 100644 index 00000000..26889479 --- /dev/null +++ b/common/notify/channel/telegram.go @@ -0,0 +1,114 @@ +package channel + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/types" +) + +const telegramURL = "https://api.telegram.org/bot" + +type Telegram struct { + secret string + chatID string +} + +type telegramMessage struct { + ChatID string `json:"chat_id"` + Text string `json:"text"` + ParseMode string `json:"parse_mode"` +} + +type telegramResponse struct { + Ok bool `json:"ok"` + Description string `json:"description"` +} + +func NewTelegram(secret string, chatID string) *Telegram { + return &Telegram{ + secret: secret, + chatID: chatID, + } +} + +func (t *Telegram) Name() string { + return "Telegram" +} + +func (t *Telegram) Send(ctx context.Context, title, message string) error { + const maxMessageLength = 4096 + message = fmt.Sprintf("*%s*\n%s", title, message) + messages := splitTelegramMessageIntoParts(message, maxMessageLength) + + client := requester.NewHTTPRequester("", telegramErrFunc) + client.Context = ctx + client.IsOpenAI = false + + for _, msg := range messages { + err := t.sendMessage(msg, client) + if err != nil { + return err + } + } + + return nil +} + +func (t *Telegram) sendMessage(message string, client *requester.HTTPRequester) error { + msg := telegramMessage{ + ChatID: t.chatID, + Text: message, + ParseMode: "Markdown", + } + + uri := telegramURL + t.secret + "/sendMessage" + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + resp, errWithOP := client.SendRequestRaw(req) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + defer resp.Body.Close() + + telegramErr := telegramErrFunc(resp) + if telegramErr != nil { + return fmt.Errorf("%s", telegramErr.Message) + } + + return nil +} + +func splitTelegramMessageIntoParts(message string, partSize int) []string { + var parts []string + for len(message) > partSize { + parts = append(parts, message[:partSize]) + message = message[partSize:] + } + parts = append(parts, message) + + return parts +} + +func telegramErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &telegramResponse{} + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + return nil + } + + if respMsg.Ok { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.Description), + Type: "telegram_error", + } +} diff --git a/common/notify/notifier.go b/common/notify/notifier.go new file mode 100644 index 00000000..c70c4eb3 --- /dev/null +++ b/common/notify/notifier.go @@ -0,0 +1,98 @@ +package notify + +import ( + "context" + "one-api/common" + "one-api/common/notify/channel" + + "github.com/spf13/viper" +) + +type Notifier interface { + Send(context.Context, string, string) error + Name() string +} + +func InitNotifier() { + InitEmailNotifier() + InitDingTalkNotifier() + InitLarkNotifier() + InitPushdeerNotifier() + InitTelegramNotifier() +} + +func InitEmailNotifier() { + if viper.GetBool("notify.email.disable") { + common.SysLog("email notifier disabled") + return + } + smtp_to := viper.GetString("notify.email.smtp_to") + emailNotifier := channel.NewEmail(smtp_to) + AddNotifiers(emailNotifier) + common.SysLog("email notifier enable") +} + +func InitDingTalkNotifier() { + access_token := viper.GetString("notify.dingtalk.token") + secret := viper.GetString("notify.dingtalk.secret") + keyWord := viper.GetString("notify.dingtalk.keyWord") + if access_token == "" || (secret == "" && keyWord == "") { + return + } + + var dingTalkNotifier Notifier + + if secret != "" { + dingTalkNotifier = channel.NewDingTalk(access_token, secret) + } else { + dingTalkNotifier = channel.NewDingTalkWithKeyWord(access_token, keyWord) + } + + AddNotifiers(dingTalkNotifier) + common.SysLog("dingtalk notifier enable") +} + +func InitLarkNotifier() { + access_token := viper.GetString("notify.lark.token") + secret := viper.GetString("notify.lark.secret") + keyWord := viper.GetString("notify.lark.keyWord") + if access_token == "" || (secret == "" && keyWord == "") { + return + } + + var larkNotifier Notifier + + if secret != "" { + larkNotifier = channel.NewLark(access_token, secret) + } else { + larkNotifier = channel.NewLarkWithKeyWord(access_token, keyWord) + } + + AddNotifiers(larkNotifier) + common.SysLog("lark notifier enable") +} + +func InitPushdeerNotifier() { + pushkey := viper.GetString("notify.pushdeer.pushkey") + if pushkey == "" { + return + } + + pushdeerNotifier := channel.NewPushdeer(pushkey, viper.GetString("notify.pushdeer.url")) + + AddNotifiers(pushdeerNotifier) + common.SysLog("pushdeer notifier enable") +} + +func InitTelegramNotifier() { + bot_token := viper.GetString("notify.telegram.bot_api_key") + chat_id := viper.GetString("notify.telegram.chat_id") + if bot_token == "" || chat_id == "" { + return + } + + telegramNotifier := channel.NewTelegram(bot_token, chat_id) + + AddNotifiers(telegramNotifier) + common.SysLog("telegram notifier enable") +} diff --git a/common/notify/notify.go b/common/notify/notify.go new file mode 100644 index 00000000..a6c9b9f1 --- /dev/null +++ b/common/notify/notify.go @@ -0,0 +1,35 @@ +package notify + +var notifyChannels = New() + +type Notify struct { + notifiers map[string]Notifier +} + +func (n *Notify) addChannel(channel Notifier) { + if channel != nil { + channelName := channel.Name() + if _, ok := n.notifiers[channelName]; ok { + return + } + n.notifiers[channelName] = channel + } +} + +func (n *Notify) addChannels(channel ...Notifier) { + for _, s := range channel { + n.addChannel(s) + } +} + +func New() *Notify { + notify := &Notify{ + notifiers: make(map[string]Notifier, 0), + } + + return notify +} + +func AddNotifiers(channel ...Notifier) { + notifyChannels.addChannels(channel...) +} diff --git a/common/notify/send.go b/common/notify/send.go new file mode 100644 index 00000000..6246af6d --- /dev/null +++ b/common/notify/send.go @@ -0,0 +1,30 @@ +package notify + +import ( + "context" + "fmt" + "one-api/common" +) + +func (n *Notify) Send(ctx context.Context, title, message string) { + if ctx == nil { + ctx = context.Background() + } + + for channelName, channel := range n.notifiers { + if channel == nil { + continue + } + err := channel.Send(ctx, title, message) + if err != nil { + common.LogError(ctx, fmt.Sprintf("%s err: %s", channelName, err.Error())) + } + } +} + +func Send(title, message string) { + //lint:ignore SA1029 reason: 需要使用该类型作为错误处理 + ctx := context.WithValue(context.Background(), common.RequestIdKey, "NotifyTask") + + notifyChannels.Send(ctx, title, message) +} diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go index ccc27eb8..f36601fe 100644 --- a/common/requester/http_requester.go +++ b/common/requester/http_requester.go @@ -24,6 +24,7 @@ type HTTPRequester struct { ErrorHandler HttpErrorHandler proxyAddr string Context context.Context + IsOpenAI bool } // NewHTTPRequester 创建一个新的 HTTPRequester 实例。 @@ -39,6 +40,7 @@ func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequ ErrorHandler: errorHandler, proxyAddr: proxyAddr, Context: context.Background(), + IsOpenAI: true, } } @@ -94,10 +96,14 @@ func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp // 处理响应 if r.IsFailureStatusCode(resp) { - return nil, HandleErrorResp(resp, r.ErrorHandler) + return nil, HandleErrorResp(resp, r.ErrorHandler, r.IsOpenAI) } // 解析响应 + if response == nil { + return resp, nil + } + if outputResp { var buf bytes.Buffer tee := io.TeeReader(resp.Body, &buf) @@ -126,7 +132,7 @@ func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *type // 处理响应 if r.IsFailureStatusCode(resp) { - return nil, HandleErrorResp(resp, r.ErrorHandler) + return nil, HandleErrorResp(resp, r.ErrorHandler, r.IsOpenAI) } return resp, nil @@ -136,7 +142,7 @@ func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *type func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) { // 如果返回的头是json格式 说明有错误 if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { - return nil, HandleErrorResp(resp, requester.ErrorHandler) + return nil, HandleErrorResp(resp, requester.ErrorHandler, requester.IsOpenAI) } stream := &streamReader[T]{ @@ -180,7 +186,7 @@ func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool { } // 处理错误响应 -func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode { +func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler, isPrefix bool) *types.OpenAIErrorWithStatusCode { openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, @@ -199,12 +205,19 @@ func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types if errorResponse != nil && errorResponse.Message != "" { openAIErrorWithStatusCode.OpenAIError = *errorResponse - openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: %s", openAIErrorWithStatusCode.OpenAIError.Message) + if isPrefix { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: %s", openAIErrorWithStatusCode.OpenAIError.Message) + } } } if openAIErrorWithStatusCode.OpenAIError.Message == "" { openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: bad response status code %d", resp.StatusCode) + if isPrefix { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: bad response status code %d", resp.StatusCode) + } else { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } } return openAIErrorWithStatusCode @@ -218,6 +231,12 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +func GetJsonHeaders() map[string]string { + return map[string]string{ + "Content-type": "application/json", + } +} + type Stringer interface { GetString() *string } diff --git a/common/stmp/email.go b/common/stmp/email.go new file mode 100644 index 00000000..8d513452 --- /dev/null +++ b/common/stmp/email.go @@ -0,0 +1,168 @@ +package stmp + +import ( + "fmt" + "one-api/common" + "strings" + + "github.com/wneessen/go-mail" +) + +type StmpConfig struct { + Host string + Port int + Username string + Password string + From string +} + +func NewStmp(host string, port int, username string, password string, from string) *StmpConfig { + if from == "" { + from = username + } + + return &StmpConfig{ + Host: host, + Port: port, + Username: username, + Password: password, + From: from, + } +} + +func (s *StmpConfig) Send(to, subject, body string) error { + message := mail.NewMsg() + message.From(s.From) + message.To(to) + message.Subject(subject) + message.SetGenHeader("References", s.getReferences()) + message.SetBodyString(mail.TypeTextHTML, body) + message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", common.Version)) + + client, err := mail.NewClient( + s.Host, + mail.WithPort(s.Port), + mail.WithUsername(s.Username), + mail.WithPassword(s.Password), + mail.WithSMTPAuth(mail.SMTPAuthPlain)) + + if err != nil { + return err + } + + switch s.Port { + case 465: + client.SetSSL(true) + case 587: + client.SetTLSPolicy(mail.TLSMandatory) + } + + if err := client.DialAndSend(message); err != nil { + return err + } + + return nil +} + +func (s *StmpConfig) getReferences() string { + froms := strings.Split(s.From, "@") + return fmt.Sprintf("<%s.%s@%s>", froms[0], common.GetUUID(), froms[1]) +} + +func (s *StmpConfig) Render(to, subject, content string) error { + body := getDefaultTemplate(content) + + return s.Send(to, subject, body) +} + +func GetSystemStmp() (*StmpConfig, error) { + if common.SMTPServer == "" || common.SMTPPort == 0 || common.SMTPAccount == "" || common.SMTPToken == "" { + return nil, fmt.Errorf("SMTP 信息未配置") + } + + return NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom), nil +} + +func SendPasswordResetEmail(userName, email, link string) error { + stmp, err := GetSystemStmp() + + if err != nil { + return err + } + + contentTemp := `
Hi %s,
++ 您正在进行密码重置。点击下方按钮以重置密码。 +
+ ++ 重置密码 +
+ +
+ 如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开
%s
+
重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
` + + subject := fmt.Sprintf("%s密码重置", common.SystemName) + content := fmt.Sprintf(contentTemp, userName, link, link, common.VerificationValidMinutes) + + return stmp.Render(email, subject, content) +} + +func SendVerificationCodeEmail(email, code string) error { + stmp, err := GetSystemStmp() + + if err != nil { + return err + } + + contentTemp := ` ++ 您正在进行邮箱验证。您的验证码为: +
+ ++ %s +
+ ++ 验证码 %d 分钟内有效,如果不是本人操作,请忽略。 +
` + + subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes) + + return stmp.Render(email, subject, content) +} + +func SendQuotaWarningCodeEmail(userName, email string, quota int, noMoreQuota bool) error { + stmp, err := GetSystemStmp() + + if err != nil { + return err + } + + contentTemp := `Hi %s,
++ %s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。 +
+ ++ 点击充值 +
+ +
+ 如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开
%s
+
+ |
+
+ ` + content + ` + | +
+ © ` + getSystemName() + ` + |
+
您好,你正在进行%s邮箱验证。
"+ - "您的验证码为: %s
"+ - "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, code, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := stmp.SendVerificationCodeEmail(email, code) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -136,22 +133,29 @@ func SendPasswordResetEmail(c *gin.Context) { }) return } - if !model.IsEmailAlreadyTaken(email) { + + user := &model.User{ + Email: email, + } + + if err := user.FillUserByEmail(); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该邮箱地址未注册", }) return } + + userName := user.DisplayName + if userName == "" { + userName = user.Username + } + code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) - subject := fmt.Sprintf("%s密码重置", common.SystemName) - content := fmt.Sprintf("您好,你正在进行%s密码重置。
"+ - "点击 此处 进行密码重置。
"+ - "如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s
重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, link, link, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := stmp.SendPasswordResetEmail(userName, email, link) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -173,6 +177,14 @@ type PasswordResetRequest struct { func ResetPassword(c *gin.Context) { var req PasswordResetRequest err := json.NewDecoder(c.Request.Body).Decode(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if req.Email == "" || req.Token == "" { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/user.go b/controller/user.go index 42faad3e..7f69998f 100644 --- a/controller/user.go +++ b/controller/user.go @@ -696,9 +696,7 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { - common.RootUserEmail = email - } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/go.mod b/go.mod index 290c82dc..da655611 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -42,6 +43,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.18.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/wneessen/go-mail v0.4.1 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect diff --git a/go.sum b/go.sum index a426c315..a1fa358e 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 h1:4gjrh/PN2MuWCCElk8/I4OCKRKWCCo2zEct3VKCbibU= +github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -262,6 +264,8 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/wneessen/go-mail v0.4.1 h1:m2rSg/sc8FZQCdtrV5M8ymHYOFrC6KJAQAIcgrXvqoo= +github.com/wneessen/go-mail v0.4.1/go.mod h1:zxOlafWCP/r6FEhAaRgH4IC1vg2YXxO0Nar9u0IScZ8= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/main.go b/main.go index 02417e03..df5478df 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/common/config" + "one-api/common/notify" "one-api/common/requester" "one-api/common/telegram" "one-api/controller" @@ -46,6 +47,7 @@ func main() { telegram.InitTelegramBot() controller.InitMidjourneyTask() + notify.InitNotifier() initHttpServer() } diff --git a/model/channel.go b/model/channel.go index 7dc58115..c7cd9a75 100644 --- a/model/channel.go +++ b/model/channel.go @@ -265,6 +265,19 @@ func (channel *Channel) Delete() error { return err } +func (channel *Channel) StatusToStr() string { + switch channel.Status { + case common.ChannelStatusEnabled: + return "启用" + case common.ChannelStatusAutoDisabled: + return "自动禁用" + case common.ChannelStatusManuallyDisabled: + return "手动禁用" + } + + return "禁用" +} + func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { diff --git a/model/main.go b/model/main.go index 61e8c9b6..5e682af2 100644 --- a/model/main.go +++ b/model/main.go @@ -22,6 +22,7 @@ func SetupDB() { common.FatalLog("failed to initialize database: " + err.Error()) } ChannelGroup.Load() + common.RootUserEmail = GetRootUserEmail() if viper.GetBool("BATCH_UPDATE_ENABLED") { common.BatchUpdateEnabled = true diff --git a/model/token.go b/model/token.go index 11f10d1f..689ec49e 100644 --- a/model/token.go +++ b/model/token.go @@ -2,8 +2,8 @@ package model import ( "errors" - "fmt" "one-api/common" + "one-api/common/stmp" "gorm.io/gorm" ) @@ -114,15 +114,13 @@ func GetTokenById(id int) (*Token, error) { } func (token *Token) Insert() error { - var err error - err = DB.Create(token).Error + err := DB.Create(token).Error return err } // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { - var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error return err } @@ -132,8 +130,7 @@ func (token *Token) SelectUpdate() error { } func (token *Token) Delete() error { - var err error - err = DB.Delete(token).Error + err := DB.Delete(token).Error return err } @@ -228,26 +225,35 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { } func sendQuotaWarningEmail(userId int, userQuota int, noMoreQuota bool) { - email, err := GetUserEmail(userId) - if err != nil { + user := User{Id: userId} + + if err := user.FillUserById(); err != nil { common.SysError("failed to fetch user email: " + err.Error()) + return } - prompt := "您的额度即将用尽" - if noMoreQuota { - prompt = "您的额度已用尽" + + if user.Email == "" { + common.SysError("user email is empty") + return } - if email != "" { - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) - err = common.SendEmail(prompt, email, - fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。