Skip to content

Commit

Permalink
feat: 支持新绘画模型及模型判断逻辑收敛 (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankCheungDev authored Nov 18, 2023
1 parent 3588050 commit 7d3a8f5
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 35 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ api_key: "xxxxxxxxx"
base_url: ""
# 指定模型,默认为 gpt-3.5-turbo , 可选参数有: "gpt-4-32k-0613", "gpt-4-32k-0314", "gpt-4-32k", "gpt-4-0613", "gpt-4-0314", "gpt-4", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0301", "gpt-3.5-turbo",如果使用gpt-4,请确认自己是否有接口调用白名单,如果你使用的是azure,则该配置项可以留空或者直接忽略
model: "gpt-3.5-turbo"
# 指定绘画模型,默认为 dall-e-2 , 可选参数有:"dall-e-2", "dall-e-3"
image_model: "dall-e-2"
# 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
session_timeout: 600
# 最大问题长度
Expand Down
2 changes: 2 additions & 0 deletions config.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ api_key: "xxxxxxxxx"
base_url: ""
# 指定模型,默认为 gpt-3.5-turbo , 可选参数有:"gpt-4-32k-0613", "gpt-4-32k-0314", "gpt-4-32k", "gpt-4-0613", "gpt-4-0314", "gpt-4-turbo-preview", "gpt-4-vision-preview", "gpt-4", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo",如果使用gpt-4,请确认自己是否有接口调用白名单,如果你使用的是azure,则该配置项可以留空或者直接忽略
model: "gpt-3.5-turbo"
# 指定绘画模型,默认为 dall-e-2 , 可选参数有:"dall-e-2", "dall-e-3"
image_model: "dall-e-2"
# 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
session_timeout: 600
# 最大问题长度
Expand Down
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type Configuration struct {
BaseURL string `yaml:"base_url"`
// 使用模型
Model string `yaml:"model"`
// 使用绘画模型
ImageModel string `yaml:"image_model"`
// 会话超时时间
SessionTimeout time.Duration `yaml:"session_timeout"`
// 最大问题长度
Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ services:
RUN_MODE: "stream" # 运行模式,http 或者 stream ,强烈建议你使用stream模式,通过此链接了解:https://open.dingtalk.com/document/isvapp/stream
BASE_URL: "" # 如果你使用官方的接口地址 https://api.openai.com,则留空即可,如果你想指定请求url的地址,可通过这个参数进行配置,注意需要带上 http 协议
MODEL: "gpt-3.5-turbo" # 指定模型,默认为 gpt-3.5-turbo , 可选参数有: "gpt-4-32k-0613", "gpt-4-32k-0314", "gpt-4-32k", "gpt-4-0613", "gpt-4-0314", "gpt-4", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0301", "gpt-3.5-turbo",如果使用gpt-4,请确认自己是否有接口调用白名单,如果你使用的是azure,则该配置项可以留空或者直接忽略
IMAGE_MODEL: "dall-e-2" # 指定绘画模型,默认为 dall-e-2 , 可选参数有:"dall-e-2", "dall-e-3"
SESSION_TIMEOUT: 600 # 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
MAX_QUESTION_LEN: 2048 # 最大问题长度,默认4096 token,正常情况默认值即可,如果使用gpt4-8k或gpt4-32k,可根据模型token上限修改。
MAX_ANSWER_LEN: 2048 # 最大回答长度,默认4096 token,正常情况默认值即可,如果使用gpt4-8k或gpt4-32k,可根据模型token上限修改。
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/avast/retry-go v3.0.0+incompatible // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.8.0 // indirect
github.com/chai2010/webp v1.1.1 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dlclark/regexp2 v1.9.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.8.0 h1:ea0Xadu+sHlu7x5O3gKhRpQ1IKiMrSiHttPF0ybECuA=
github.com/bytedance/sonic v1.8.0/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E=
github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c=
github.com/charmbracelet/log v0.2.1 h1:1z7jpkk4yKyjwlmKmKMM5qnEDSpV32E7XtWhuv0mTZE=
Expand Down
61 changes: 39 additions & 22 deletions pkg/chatgpt/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ import (
"encoding/gob"
"errors"
"fmt"

"github.com/chai2010/webp"
"image"
_ "image/gif"
_ "image/jpeg"
"image/png"

"os"
"strings"
"time"
Expand Down Expand Up @@ -137,6 +143,22 @@ func (c *ChatContext) SetPreset(preset string) {
c.preset = preset
}

// getImageTypeFromBase64 identifies an image format from the leading
// characters of its base64-encoded data.
//
// It returns "JPEG", "PNG", "GIF" or "WebP" when base64Str begins with the
// corresponding prefix, and "Unknown" for anything else, including the empty
// string. Only the prefix is inspected; the rest of the payload is never
// decoded or validated.
func getImageTypeFromBase64(base64Str string) string {
	// Each entry pairs a base64 prefix with the format name reported for it.
	// The prefixes are mutually exclusive, so at most one entry can match.
	signatures := [...]struct{ prefix, format string }{
		{"/9j/", "JPEG"},
		{"iVBOR", "PNG"},
		{"R0lG", "GIF"},
		{"UklG", "WebP"},
	}

	for _, sig := range signatures {
		if strings.HasPrefix(base64Str, sig.prefix) {
			return sig.format
		}
	}
	return "Unknown"
}

func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
question = question + "."
if tokenizer.MustCalToken(question) > c.maxQuestionLen {
Expand Down Expand Up @@ -181,20 +203,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
if public.Config.AzureOn {
userId = ""
}
if model == openai.GPT432K0613 ||
model == openai.GPT432K0314 ||
model == openai.GPT432K ||
model == openai.GPT40613 ||
model == openai.GPT40314 ||
model == openai.GPT4TurboPreview ||
model == openai.GPT4VisionPreview ||
model == openai.GPT4 ||
model == openai.GPT3Dot5Turbo1106 ||
model == openai.GPT3Dot5Turbo0613 ||
model == openai.GPT3Dot5Turbo0301 ||
model == openai.GPT3Dot5Turbo16K ||
model == openai.GPT3Dot5Turbo16K0613 ||
model == openai.GPT3Dot5Turbo {
if isModelSupportedChatCompletions(model) {
req := openai.ChatCompletionRequest{
Model: model,
Messages: []openai.ChatCompletionMessage{
Expand Down Expand Up @@ -248,14 +257,13 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
return resp.Choices[0].Text, nil
}
}
func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
func (c *ChatGPT) GenerateImage(ctx context.Context, prompt string) (string, error) {
model := public.Config.Model
if model == openai.GPT3Dot5Turbo || model == openai.GPT3Dot5Turbo0301 || model == openai.GPT3Dot5Turbo0613 ||
model == openai.GPT3Dot5Turbo16K || model == openai.GPT3Dot5Turbo16K0613 ||
model == openai.GPT4 || model == openai.GPT40314 || model == openai.GPT40613 ||
model == openai.GPT432K || model == openai.GPT432K0314 || model == openai.GPT432K0613 {
imageModel := public.Config.ImageModel
if isModelSupportedChatCompletions(model) {
req := openai.ImageRequest{
Prompt: prompt,
Model: imageModel,
Size: openai.CreateImageSize1024x1024,
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
N: 1,
Expand All @@ -271,9 +279,18 @@ func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, err
}

r := bytes.NewReader(imgBytes)
imgData, err := png.Decode(r)
if err != nil {
return "", err

// dall-e-3 返回的是 WebP 格式的图片,需要判断处理
imgType := getImageTypeFromBase64(respBase64.Data[0].B64JSON)
var imgData image.Image
var imgErr error
if imgType == "WebP" {
imgData, imgErr = webp.Decode(r)
} else {
imgData, _, imgErr = image.Decode(r)
}
if imgErr != nil {
return "", imgErr
}

imageName := time.Now().Format("20060102-150405") + ".png"
Expand Down
2 changes: 1 addition & 1 deletion pkg/chatgpt/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func ImageQa(ctx context.Context, question, userId string) (answer string, err e
// 使用重试策略进行重试
err = retry.Do(
func() error {
answer, err = chat.GenreateImage(ctx, question)
answer, err = chat.GenerateImage(ctx, question)
if err != nil {
return err
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/chatgpt/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ go 1.18

require (
github.com/avast/retry-go v3.0.0+incompatible
github.com/chai2010/webp v1.1.1
github.com/eryajf/chatgpt-dingtalk v1.0.11
github.com/pandodao/tokenizer-go v0.2.0
github.com/sashabaranov/go-openai v1.17.6
)

Expand All @@ -14,11 +16,16 @@ require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/charmbracelet/log v0.2.1 // indirect
github.com/dlclark/regexp2 v1.9.0 // indirect
github.com/dop251/goja v0.0.0-20230402114112-623f9dda9079 // indirect
github.com/dop251/goja_nodejs v0.0.0-20230322100729-2550c7b6c124 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/glebarez/go-sqlite v1.20.3 // indirect
github.com/glebarez/sqlite v1.7.0 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
github.com/google/pprof v0.0.0-20230406165453-00490a63f317 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
Expand All @@ -32,6 +39,7 @@ require (
github.com/rivo/uniseg v0.2.0 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.9.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gorm.io/gorm v1.24.6 // indirect
modernc.org/libc v1.22.3 // indirect
Expand Down
Loading

0 comments on commit 7d3a8f5

Please sign in to comment.