Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:ecodeclub/webook into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
juniaoshaonian committed Dec 5, 2024
2 parents dec2798 + bc975ec commit b237003
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 49 deletions.
9 changes: 6 additions & 3 deletions internal/ai/internal/domain/jd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const (
AnalysisJDTech = "analysis_jd_tech"
AnalysisJDBiz = "analysis_jd_biz"
AnalysisJDPosition = "analysis_jd_position"
AnalysisJDSubtext = "analysis_jd_subtext"
)

type JDEvaluation struct {
Expand All @@ -13,7 +14,9 @@ type JDEvaluation struct {

type JD struct {
Amount int64
TechScore *JDEvaluation
BizScore *JDEvaluation
PosScore *JDEvaluation
TechScore JDEvaluation
BizScore JDEvaluation
PosScore JDEvaluation
// 潜台词
Subtext string
}
53 changes: 34 additions & 19 deletions internal/ai/internal/integration/llm_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func (s *LLMServiceSuite) SetupSuite() {
err = s.db.Create(&dao.BizConfig{
Biz: domain.BizQuestionExamine,
MaxInput: 100,
Price: 1,
PromptTemplate: "这是问题 %s,这是问题内容 %s,这是用户输入 %s",
KnowledgeId: knowledgeId,
Ctime: now,
Expand All @@ -68,6 +69,7 @@ func (s *LLMServiceSuite) SetupSuite() {
err = s.db.Create(&dao.BizConfig{
Biz: domain.BizCaseExamine,
MaxInput: 100,
Price: 1,
PromptTemplate: "这是案例 %s,这是案例内容 %s,这是用户输入 %s",
KnowledgeId: knowledgeId,
Ctime: now,
Expand Down Expand Up @@ -105,6 +107,15 @@ func (s *LLMServiceSuite) SetupSuite() {
}).Error
s.NoError(err)

err = s.db.Create(&dao.BizConfig{
Biz: domain.AnalysisJDSubtext,
MaxInput: 100,
PromptTemplate: "这是岗位描述Subtext %s",
KnowledgeId: knowledgeId,
Ctime: now,
Utime: now,
}).Error
s.NoError(err)
}

func (s *LLMServiceSuite) TearDownSuite() {
Expand Down Expand Up @@ -639,31 +650,34 @@ func (s *LLMServiceSuite) TestHandler_AnalysisJD() {
llmHdl := hdlmocks.NewMockHandler(ctrl)
llmHdl.EXPECT().Handle(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, request domain.LLMRequest) (domain.LLMResponse, error) {
if request.Biz == "analysis_jd_tech" {
switch request.Biz {
case domain.AnalysisJDTech:
return domain.LLMResponse{
Tokens: 1000,
Amount: 100,
Answer: `score: 6
这是技术前景`,
Answer: `{"score":6, "summary":["这是技术前景"]}`,
}, nil
}
if request.Biz == "analysis_jd_biz" {
case domain.AnalysisJDBiz:
return domain.LLMResponse{
Tokens: 100,
Amount: 200,
Answer: `score: 7
这是业务前景`,
Answer: `{"score":7, "summary":["这是业务前景"]}`,
}, nil
}
if request.Biz == "analysis_jd_position" {
case domain.AnalysisJDPosition:
return domain.LLMResponse{
Tokens: 100,
Amount: 100,
Answer: `{"score":8, "summary":["这是公司地位"]}`,
}, nil
case domain.AnalysisJDSubtext:
return domain.LLMResponse{
Tokens: 100,
Amount: 100,
Answer: `score: 8
这是公司地位`,
Answer: `这是我的分析`,
}, nil
default:
return domain.LLMResponse{}, errors.New("unknown biz")
}
return domain.LLMResponse{}, errors.New("unknown biz")
}).AnyTimes()
creditSvc := creditmocks.NewMockService(ctrl)
creditSvc.EXPECT().GetCreditsByUID(gomock.Any(), gomock.Any()).Return(credit.Credit{
Expand All @@ -676,19 +690,20 @@ func (s *LLMServiceSuite) TestHandler_AnalysisJD() {
after: func(t *testing.T, resp web.JDResponse) {
// 校验response写入的内容是否正确
assert.Equal(t, web.JDResponse{
Amount: 400,
TechScore: &web.JDEvaluation{
Amount: 500,
TechScore: web.JDEvaluation{
Score: 6,
Analysis: "这是技术前景",
Analysis: "- 这是技术前景",
},
BizScore: &web.JDEvaluation{
BizScore: web.JDEvaluation{
Score: 7,
Analysis: "这是业务前景",
Analysis: "- 这是业务前景",
},
PosScore: &web.JDEvaluation{
PosScore: web.JDEvaluation{
Score: 8,
Analysis: "这是公司地位",
Analysis: "- 这是公司地位",
},
Subtext: "这是我的分析",
}, resp)

},
Expand Down
69 changes: 51 additions & 18 deletions internal/ai/internal/service/jd_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,45 @@ package service

import (
"context"
"errors"
"strconv"
"encoding/json"
"regexp"
"strings"
"sync/atomic"

"github.com/gotomicro/ego/core/elog"

"github.com/ecodeclub/webook/internal/ai/internal/domain"
"github.com/ecodeclub/webook/internal/ai/internal/service/llm"
"github.com/lithammer/shortuuid/v4"
"golang.org/x/sync/errgroup"
)

// 最简单的提取方式
const jsonExpr = `\{(.|\n|\r)+\}`

type JDService interface {
// Evaluate 测评
Evaluate(ctx context.Context, uid int64, jd string) (domain.JD, error)
}

type jdSvc struct {
aiSvc llm.Service
aiSvc llm.Service
logger *elog.Component
expr *regexp.Regexp
}

func NewJDService(aiSvc llm.Service) JDService {
return &jdSvc{
aiSvc: aiSvc,
aiSvc: aiSvc,
logger: elog.DefaultLogger,
expr: regexp.MustCompile(jsonExpr),
}
}

func (j *jdSvc) Evaluate(ctx context.Context, uid int64, jd string) (domain.JD, error) {
var techJD, bizJD, positionJD *domain.JDEvaluation
var techJD, bizJD, positionJD domain.JDEvaluation
var amount int64
var subtext string
var eg errgroup.Group
eg.Go(func() error {
var err error
Expand Down Expand Up @@ -62,6 +72,19 @@ func (j *jdSvc) Evaluate(ctx context.Context, uid int64, jd string) (domain.JD,
atomic.AddInt64(&amount, positionAmount)
return nil
})

eg.Go(func() error {
tid := shortuuid.New()
resp, err := j.aiSvc.Invoke(ctx, domain.LLMRequest{
Uid: uid,
Tid: tid,
Biz: domain.AnalysisJDSubtext,
Input: []string{jd},
})
subtext = resp.Answer
atomic.AddInt64(&amount, resp.Amount)
return err
})
if err := eg.Wait(); err != nil {
return domain.JD{}, err
}
Expand All @@ -70,10 +93,11 @@ func (j *jdSvc) Evaluate(ctx context.Context, uid int64, jd string) (domain.JD,
TechScore: techJD,
BizScore: bizJD,
PosScore: positionJD,
Subtext: subtext,
}, nil
}

func (j *jdSvc) analysisJd(ctx context.Context, uid int64, biz string, jd string) (int64, *domain.JDEvaluation, error) {
func (j *jdSvc) analysisJd(ctx context.Context, uid int64, biz string, jd string) (int64, domain.JDEvaluation, error) {
tid := shortuuid.New()
aiReq := domain.LLMRequest{
Uid: uid,
Expand All @@ -83,20 +107,29 @@ func (j *jdSvc) analysisJd(ctx context.Context, uid int64, biz string, jd string
}
resp, err := j.aiSvc.Invoke(ctx, aiReq)
if err != nil {
return 0, nil, err
return 0, domain.JDEvaluation{}, err
}
answer := strings.SplitN(resp.Answer, "\n", 2)
if len(answer) != 2 {
return 0, nil, errors.New("不符合预期的大模型响应")
}
score := answer[0]
scoreNum, err := strconv.ParseFloat(strings.TrimSpace(strings.TrimPrefix(score, "score:")), 64)
jsonStr := j.expr.FindString(resp.Answer)
var (
scoreResp ScoreResp
analysis string
)
err = json.Unmarshal([]byte(jsonStr), &scoreResp)
if err != nil {
return 0, nil, errors.New("分数返回的数据不对")
j.logger.Error("不符合预期的大模型响应",
elog.FieldErr(err),
elog.String("resp", resp.Answer))
} else {
analysis = "- " + strings.Join(scoreResp.Summary, "\n- ")
}

return resp.Amount, &domain.JDEvaluation{
Score: scoreNum,
Analysis: strings.TrimSpace(strings.TrimPrefix(answer[1], "analysis:")),
return resp.Amount, domain.JDEvaluation{
Score: scoreResp.Score,
// 按照 Markdown 的写法,拼接起来
Analysis: analysis,
}, nil
}

type ScoreResp struct {
Score float64 `json:"score"`
Summary []string `json:"summary"`
}
50 changes: 50 additions & 0 deletions internal/ai/internal/service/jd_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2023 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package service

import (
"regexp"
"testing"

"github.com/stretchr/testify/assert"
)

// TestJSONExpression 测试利用正则表达式提取 JSON 串
func TestJSONExpression(t *testing.T) {
testCases := []struct {
name string
input string
want string
}{
{
name: "本身就是JSON",
input: `{"abc": "bcd"}`,
want: `{"abc": "bcd"}`,
},
{
name: "有前缀后缀",
input: "```json{\"abc\": \"bcd\"}```",
want: `{"abc": "bcd"}`,
},
}

expr := regexp.MustCompile(jsonExpr)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
val := expr.FindString(tc.input)
assert.Equal(t, tc.want, val)
})
}
}
4 changes: 4 additions & 0 deletions internal/ai/internal/service/llm/handler/credit/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func NewHandlerBuilder(creSvc credit.Service, repo repository.LLMCreditLogRepo)

func (h *HandlerBuilder) Next(next handler.Handler) handler.Handler {
return handler.HandleFunc(func(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
// 不需要扣除积分
if req.Config.Price == 0 {
return next.Handle(ctx, req)
}
cre, err := h.creditSvc.GetCreditsByUID(ctx, req.Uid)
if err != nil {
return domain.LLMResponse{}, err
Expand Down
4 changes: 2 additions & 2 deletions internal/ai/internal/service/llm/handler/log/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ func (h *HandlerBuilder) Next(next handler.Handler) handler.Handler {
elog.Int64("uid", req.Uid),
elog.String("biz", req.Biz))
// 记录请求
logger.Info("请求 LLM")
logger.Debug("请求 LLM")
resp, err := next.Handle(ctx, req)
if err != nil {
// 记录错误
logger.Error("请求 LLM 服务失败", elog.FieldErr(err))
return resp, err
}
// 记录响应
logger.Info("请求 LLM 服务响应成功", elog.Int64("tokens", resp.Tokens))
logger.Debug("请求 LLM 服务响应成功", elog.Int64("tokens", resp.Tokens))
return resp, err
})
}
5 changes: 3 additions & 2 deletions internal/ai/internal/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (h *Handler) AnalysisJd(ctx *ginx.Context, req JDRequest, sess session.Sess
TechScore: h.newJD(resp.TechScore),
BizScore: h.newJD(resp.BizScore),
PosScore: h.newJD(resp.PosScore),
Subtext: resp.Subtext,
},
}, nil
default:
Expand All @@ -73,8 +74,8 @@ func (h *Handler) AnalysisJd(ctx *ginx.Context, req JDRequest, sess session.Sess

}

func (h *Handler) newJD(jd *domain.JDEvaluation) *JDEvaluation {
return &JDEvaluation{
func (h *Handler) newJD(jd domain.JDEvaluation) JDEvaluation {
return JDEvaluation{
Score: jd.Score,
Analysis: jd.Analysis,
}
Expand Down
9 changes: 5 additions & 4 deletions internal/ai/internal/web/vo.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ type JDRequest struct {
}

type JDResponse struct {
Amount int64 `json:"amount"`
TechScore *JDEvaluation `json:"techScore"`
BizScore *JDEvaluation `json:"bizScore"`
PosScore *JDEvaluation `json:"posScore"`
Amount int64 `json:"amount"`
TechScore JDEvaluation `json:"techScore"`
BizScore JDEvaluation `json:"bizScore"`
PosScore JDEvaluation `json:"posScore"`
Subtext string `json:"subtext"`
}

type JDEvaluation struct {
Expand Down
4 changes: 3 additions & 1 deletion ioc/private/nonsense/non_sense_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
package nonsense

import (
"log/slog"

"github.com/gin-gonic/gin"
)

// NonSenseV1
var NonSenseV1 gin.HandlerFunc = func(ct *gin.Context) {
// 啥也不做
println("hello")
slog.Debug("进来了 NonSenseV1")
}

0 comments on commit b237003

Please sign in to comment.