109 lines
2.5 KiB
Go
109 lines
2.5 KiB
Go
package embedding
|
||
|
||
import (
|
||
"context"
|
||
)
|
||
|
||
func (s *Service) TextEmbedding(ctx context.Context, input []string) ([][]float32, error) {
|
||
|
||
var r = make([][]float32, len(input)-1)
|
||
|
||
for _, v := range input {
|
||
embedding2, err := s.OpenAI.CreateEmbedding(ctx, []string{v})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
r = append(r, embedding2[0])
|
||
|
||
//embedding, err := s.getCache(ctx, v)
|
||
//if err != nil {
|
||
// return r, err
|
||
//}
|
||
//
|
||
//if embedding != nil {
|
||
// r = append(r, embedding)
|
||
// continue
|
||
//} else {
|
||
// embedding2, err := s.OpenAI.CreateEmbedding(ctx, []string{v})
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
//
|
||
// r = append(r, embedding2[0])
|
||
//
|
||
// err = s.setCache(ctx, v, embedding2[0])
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
//}
|
||
}
|
||
|
||
return r, nil
|
||
}
|
||
|
||
//
|
||
//func (s *Service) getCache(ctx context.Context, input string) ([]float32, error) {
|
||
// md5Str, err := md5.Md5(input)
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
//
|
||
// c, err := s.dao.WithContext(ctx).Embedding.Where(s.dao.Embedding.TextMd5.Eq(md5Str)).
|
||
// Where(s.dao.Embedding.EmbeddingModel.Eq(s.config.OpenAI.EmbeddingModel)).
|
||
// Count()
|
||
// if c == 0 {
|
||
// return nil, err
|
||
// }
|
||
//
|
||
// first, err := s.dao.WithContext(ctx).Embedding.Where(s.dao.Embedding.TextMd5.Eq(md5Str)).
|
||
// Where(s.dao.Embedding.EmbeddingModel.Eq(s.config.OpenAI.EmbeddingModel)).
|
||
// First()
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
//
|
||
// // byte to float32
|
||
// return first.Vector, nil
|
||
//}
|
||
//
|
||
//func (s *Service) setCache(ctx context.Context, input string, embedding []float32) error {
|
||
// md5Str, err := md5.Md5(input)
|
||
// if err != nil {
|
||
// return err
|
||
// }
|
||
//
|
||
// // redis 锁
|
||
// var key = "lock_" + md5Str
|
||
// lock, err := s.redis.Locker.Obtain(ctx, key, 3*time.Second, nil)
|
||
// if errors.Is(err, redislock.ErrNotObtained) {
|
||
// s.Logger.Sugar.Warnf("redis lock %s not obtained", md5Str)
|
||
// } else if err != nil {
|
||
// return err
|
||
// }
|
||
// defer func(lock *redislock.Lock, ctx context.Context) {
|
||
// err := lock.Release(ctx)
|
||
// if err != nil {
|
||
// s.Logger.Sugar.Error(err)
|
||
// }
|
||
// }(lock, ctx)
|
||
//
|
||
// // 如果没有 cache,则设置
|
||
// c, err := s.dao.WithContext(ctx).Embedding.Where(s.dao.Embedding.TextMd5.Eq(md5Str)).
|
||
// Where(s.dao.Embedding.EmbeddingModel.Eq(s.config.OpenAI.EmbeddingModel)).
|
||
// Count()
|
||
// if err != nil {
|
||
// return err
|
||
// }
|
||
// if c == 0 {
|
||
// return s.dao.WithContext(ctx).Embedding.Create(&entity.Embedding{
|
||
// Text: input,
|
||
// TextMd5: md5Str,
|
||
// Vector: embedding,
|
||
// EmbeddingModel: s.config.OpenAI.EmbeddingModel,
|
||
// })
|
||
// }
|
||
//
|
||
// return nil
|
||
//}
|