71 lines
1.7 KiB
Go
71 lines
1.7 KiB
Go
package user
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
|
entity2 "github.com/milvus-io/milvus-sdk-go/v2/entity"
|
|
"leafdev.top/Ecosystem/recommender/internal/entity"
|
|
)
|
|
|
|
func (s *Service) SuggestPosts(c context.Context, externalUserEntity *entity.ExternalUser, categoryEntity *entity.Category) ([]*entity.Post, error) {
|
|
emb, err := s.embedding.TextEmbedding(c, []string{externalUserEntity.Summary})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var filter = fmt.Sprintf("application_id == %d && category_id == %s", externalUserEntity.ApplicationId, categoryEntity.Id)
|
|
sp, err := entity2.NewIndexAUTOINDEXSearchParam(1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vector := entity2.FloatVector(emb[0])
|
|
postResults, err := s.milvus.Search(c, s.config.Milvus.PostCollection,
|
|
[]string{},
|
|
filter,
|
|
[]string{"post_id", "category_id"},
|
|
[]entity2.Vector{vector},
|
|
"vector",
|
|
entity2.L2,
|
|
3,
|
|
sp, client.WithLimit(7))
|
|
|
|
var ids []uint
|
|
|
|
for _, res := range postResults {
|
|
// 没找到,直接返回空的
|
|
if res.ResultCount == 0 {
|
|
return make([]*entity.Post, 0), nil
|
|
}
|
|
|
|
var blockIdColumn *entity2.ColumnInt64
|
|
for _, field := range res.Fields {
|
|
if field.Name() == "post_id" {
|
|
c, ok := field.(*entity2.ColumnInt64)
|
|
if ok {
|
|
blockIdColumn = c
|
|
}
|
|
}
|
|
}
|
|
|
|
// 没有记录
|
|
if blockIdColumn == nil {
|
|
return make([]*entity.Post, 0), nil
|
|
//return nil, fmt.Errorf("block_id column not found")
|
|
}
|
|
|
|
for i := 0; i < res.ResultCount; i++ {
|
|
id, err := blockIdColumn.ValueByIdx(i)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ids = append(ids, uint(id))
|
|
|
|
}
|
|
}
|
|
|
|
posts, err := s.dao.Post.Where(s.dao.Post.Where(s.dao.Post.Id.In(ids...))).Find()
|
|
return posts, err
|
|
}
|