generated from charmbracelet/bubbletea-app-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.go
144 lines (127 loc) · 3.51 KB
/
cli.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package main
import (
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/priuatus/gpterm/internal/stdin"
"golang.org/x/term"
tea "github.com/charmbracelet/bubbletea"
gogpt "github.com/sashabaranov/go-openai"
)
// completionResponse carries the result of a completion request in one of
// two forms: an open stream (when the request had Stream: true) or a fully
// materialized response. The streaming flag tells the caller which embedded
// field is the active payload.
type completionResponse struct {
	*gogpt.CompletionStream  // non-nil when streaming; caller must Close it
	gogpt.CompletionResponse // populated for non-streaming requests
	streaming bool           // true when CompletionStream is the active payload
}
// completion bundles an OpenAI client with the request it will send.
type completion struct {
	client *gogpt.Client           // API client used to issue the request
	req    gogpt.CompletionRequest // request parameters; Prompt holds a string or []string
}
// addPromptString appends prompt to the request's existing prompt. The
// OpenAI API accepts either a single string or a slice of strings, so both
// representations are handled; an unset (nil) prompt is initialized to a
// plain string so piped-in text is not silently dropped.
func (c *completion) addPromptString(prompt string) {
	switch p := c.req.Prompt.(type) {
	case string:
		c.req.Prompt = p + prompt
	case []string:
		c.req.Prompt = append(p, prompt)
	default:
		// Prompt was never set (nil) or has an unexpected type; start fresh.
		c.req.Prompt = prompt
	}
}
// IsEmptyPrompt reports whether the request carries no prompt text: an
// empty string, an empty slice, or an unset/unsupported prompt value.
func (c *completion) IsEmptyPrompt() bool {
	switch p := c.req.Prompt.(type) {
	case string:
		return p == ""
	case []string:
		return len(p) == 0
	default:
		// nil or an unsupported type — treat as empty.
		return true
	}
}
// checkPromptType reports whether prompt has one of the two types the
// OpenAI completion API accepts for a prompt: string or []string.
func checkPromptType(prompt any) bool {
	switch prompt.(type) {
	case string, []string:
		return true
	default:
		return false
	}
}
// Create sends the completion request, first folding any text piped on
// stdin into the prompt. For streaming requests the returned response
// carries an open CompletionStream that the caller must Close; otherwise
// it carries the full CompletionResponse.
func (c completion) Create(ctx context.Context) (resp completionResponse, err error) {
	var stdIn string
	stdIn, err = stdin.Read()
	if err != nil && !errors.Is(err, stdin.ErrEmpty) {
		return resp, err
	}
	// An empty stdin is not an error; it just contributes nothing.
	if errors.Is(err, stdin.ErrEmpty) {
		stdIn = ""
	}
	// Only extend the prompt with real input: unconditionally appending ""
	// would inject a bogus empty element into a []string prompt.
	if stdIn != "" {
		c.addPromptString(stdIn)
	}
	if c.IsEmptyPrompt() {
		return resp, fmt.Errorf("missing prompt")
	}
	if c.req.Stream {
		resp.streaming = true
		resp.CompletionStream, err = c.client.CreateCompletionStream(ctx, c.req)
		return resp, err
	}
	resp.CompletionResponse, err = c.client.CreateCompletion(ctx, c.req)
	if err != nil {
		// On error Choices may be empty; indexing it here would panic.
		return resp, err
	}
	// Warn when generation stopped because the token budget ran out.
	if len(resp.Choices) > 0 && resp.Choices[0].FinishReason == "length" {
		fmt.Fprintf(os.Stderr, "%s: --max-tokens %d reached consider increasing the limit\n", os.Args[0], resp.Usage.CompletionTokens)
	}
	return resp, nil
}
// CLI describes the command-line interface. Fields are bound to flags,
// environment variables, and positional arguments via struct tags
// (kong-style: short flag, default value, env var, help text).
type CLI struct {
	APIKey    string   `short:"k" help:"OpenAI API Token." env:"OPENAI_API_KEY"`
	Model     string   `short:"m" default:"text-davinci-003" help:"The model which will generate the completion."`
	Temp      float32  `short:"t" default:"0.0" help:"Generation creativity. Higher is crazier."`
	MaxTokens int      `short:"n" default:"100" help:"Max number of tokens to generate."`
	Stream    bool     `short:"S" default:"true" help:"Whether to stream back partial progress."`
	Quiet     bool     `short:"q" default:"false" help:"Print only the model response."`
	Stop      []string `short:"s" help:"Up to 4 sequences where the model will stop generating further. The returned text will not contain the stop sequence."`
	Prompt    []string `arg:"" optional:"" help:"text prompt"`
}
// Run executes the CLI: it builds a completion request from the parsed
// flags and either launches the interactive TUI (when stdout is a terminal
// and streaming is off) or performs a one-shot request and prints the
// result to stdout.
func (t CLI) Run() error {
	cmpltn := completion{
		client: gogpt.NewClient(t.APIKey),
		req: gogpt.CompletionRequest{
			Model:       t.Model,
			Prompt:      strings.Join(t.Prompt, " "),
			MaxTokens:   t.MaxTokens,
			Temperature: t.Temp,
			TopP:        1.0,
			Echo:        true,
			Stop:        t.Stop,
			Stream:      t.Stream,
		},
	}
	// In quiet mode don't echo the prompt back with the response.
	if t.Quiet {
		cmpltn.req.Echo = false
	}
	if term.IsTerminal(int(os.Stdout.Fd())) && !t.Stream {
		model := initialModel(cmpltn)
		model.quiet = t.Quiet
		_, err := tea.NewProgram(model).Run()
		return err
	}
	// Non-interactive use: issue the request directly.
	resp, err := cmpltn.Create(context.Background())
	if err != nil {
		return err
	}
	if resp.streaming {
		defer resp.Close()
		for {
			response, err := resp.Recv()
			// Guard with len: a non-nil but empty Choices slice would
			// otherwise panic on the [0] index.
			if len(response.Choices) > 0 {
				fmt.Printf("%s", response.Choices[0].Text)
			}
			if errors.Is(err, io.EOF) {
				break
			}
			if err != nil {
				return fmt.Errorf("stream error: %w", err)
			}
		}
		return nil
	}
	if len(resp.Choices) > 0 {
		fmt.Printf("%s", strings.TrimLeft(resp.Choices[0].Text, "\n"))
	}
	return nil
}