diff --git a/go.mod b/go.mod index ae2a03c..a0697e6 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.21.5 require ( github.com/gin-gonic/gin v1.9.1 github.com/go-redis/redis/v8 v8.11.5 - github.com/weaviate/weaviate v1.23.3 + github.com/weaviate/weaviate v1.24.1 github.com/weaviate/weaviate-go-client/v4 v4.12.1 ) @@ -46,14 +46,15 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect go.mongodb.org/mongo-driver v1.13.1 // indirect golang.org/x/arch v0.7.0 // indirect - golang.org/x/crypto v0.18.0 // indirect + golang.org/x/crypto v0.19.0 // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect - golang.org/x/sys v0.16.0 // indirect + golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240108191215-35c7eff3a6b1 // indirect diff --git a/go.sum b/go.sum index 8969196..e91fb34 100644 --- a/go.sum +++ b/go.sum @@ -105,6 +105,8 @@ github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -126,6 +128,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/weaviate/weaviate v1.23.3 h1:9nbPTzwheh2XYk6hPCwCJnkuBDq5Ob/dXZWbBuLztaY= github.com/weaviate/weaviate v1.23.3/go.mod h1:afludwbcyIZa9HEBELvHNb8zjH+KcjcW/jb4SZ5C2T4= +github.com/weaviate/weaviate v1.24.1 h1:Cl/NnqgFlNfyC7KcjFtETf1bwtTQPLF3oz5vavs+Jq0= +github.com/weaviate/weaviate v1.24.1/go.mod h1:wcg1vJgdIQL5MWBN+871DFJQa+nI2WzyXudmGjJ8cG4= github.com/weaviate/weaviate-go-client/v4 v4.12.1 h1:XFKL49BgSOcxrFs5IV+Q5pydLTsh0HQHuWbKNSLMWLU= github.com/weaviate/weaviate-go-client/v4 v4.12.1/go.mod h1:r1PlU5sAZKFvAPgymEHQj0hjSAuEV9X77PJ/ffZ6cEo= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= @@ -143,6 +147,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -164,6 +170,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/main.go b/main.go index 87145b4..96d56b0 100644 --- a/main.go +++ b/main.go @@ -44,6 +44,9 @@ func main() { }) router.GET("/weaviate/retrieveresponse", func(c *gin.Context) { + + //TODO: allow user annotation + searchQuery := c.Query("query") upvoteStr := c.Query("best") best := upvoteStr == "true" @@ -76,6 +79,66 @@ func main() { c.JSON(http.StatusOK, response) }) + router.GET("/weaviate/retrieveresponselist", func(c *gin.Context) { + searchQuery := c.Query("query") + + decodedQuery, err := url.QueryUnescape(searchQuery) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid search query"}) + return + } + + log.Printf("Decoded Query: %s", decodedQuery) + + responseList, err := weaviate.ResponseList(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, responseList) + }) + + router.GET("/weaviate/responsebyid", func(c *gin.Context) { + searchQuery := c.Query("id") + + decodedQuery, err := url.QueryUnescape(searchQuery) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid search query"}) + return + } + + log.Printf("Decoded Query: %s", decodedQuery) + + response, err := weaviate.RetrieveResponseByID(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, response) + }) + + router.GET("/weaviate/propertiesbyid", func(c *gin.Context) { + searchQuery := c.Query("id") + + decodedQuery, err := url.QueryUnescape(searchQuery) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid search query"}) + return + } + + log.Printf("Decoded Query: %s", decodedQuery) + + response, err := weaviate.RetrieveProperties(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, response) + }) + router.POST("/generate", func(c *gin.Context) { var requestBody map[string]interface{} if err := c.BindJSON(&requestBody); err != nil { diff --git a/ollama/ollama.go b/ollama/ollama.go index 7d357e0..b67a16b 100644 --- a/ollama/ollama.go +++ b/ollama/ollama.go @@ -12,13 +12,7 @@ import ( "os" ) -type ResponseData struct { - Response string `json:"response"` - PromptID string `json:"promptID"` - Instruct string `json:"instruct"` -} - -func GenerateResponse(prompt map[string]interface{}) (ResponseData, error) { +func GenerateResponse(prompt map[string]interface{}) (weaviate.ResponseData, error) { url := os.Getenv("OLLAMA_URL") + "/api/generate" set, ok := prompt["instructType"].(string) @@ -26,36 +20,43 @@ func GenerateResponse(prompt map[string]interface{}) (ResponseData, error) { set = "default" } - instruct, err := redis.GetSetMember(set) - if err != nil { - return ResponseData{}, errors.New("no data available") + instruct, ok := prompt["instruct"].(string) + + log.Printf("ok: %v", ok) + if !ok { + instruct, _ = redis.GetSetMember(set) } log.Printf("Prompt: %s\n", instruct) code, ok := prompt["prompt"].(string) if !ok { - return ResponseData{}, errors.New("prompt field is not a string") + return weaviate.ResponseData{}, errors.New("prompt field is not a string") } log.Printf("Code: %s\n", code) completePrompt := instruct + " " + code + model, ok := prompt["model"].(string) + if !ok { + model = "codellama:13b-instruct" + } + requestBody := map[string]interface{}{ - "model": prompt["model"], + "model": model, "prompt": completePrompt, "stream": false, } jsonData, err := json.Marshal(requestBody) if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } req.Header.Set("Content-Type", "application/json") @@ -63,7 +64,7 @@ func GenerateResponse(prompt map[string]interface{}) (ResponseData, error) { client := &http.Client{} resp, err := client.Do(req) if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } defer func(Body io.ReadCloser) { err := Body.Close() @@ -74,31 +75,33 @@ func GenerateResponse(prompt map[string]interface{}) (ResponseData, error) { body, err := io.ReadAll(resp.Body) if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } var responseJSON map[string]interface{} err = json.Unmarshal(body, &responseJSON) if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } response, ok := responseJSON["response"].(string) + log.Printf("Reponse: %s\n", response) + if !ok { log.Println("Error: 'response' field is not a string array") - return ResponseData{}, errors.New("invalid response format") + return weaviate.ResponseData{}, errors.New("invalid response format") } - log.Printf("Reponse: %s\n", response) - PromptID, err := weaviate.CreatePromptObject(instruct, code, "Prompt") if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } + log.Printf("PromptID: %s\n", PromptID) + ResponseID, err := weaviate.CreateResponseObject(response, "Response") if err != nil { - return ResponseData{}, err + return weaviate.ResponseData{}, err } err = weaviate.CreateReferences(PromptID, ResponseID) @@ -106,7 +109,7 @@ func GenerateResponse(prompt map[string]interface{}) (ResponseData, error) { panic(err) } - responseData := ResponseData{ + responseData := weaviate.ResponseData{ Response: response, PromptID: PromptID, Instruct: instruct, diff --git a/redis/redis.go b/redis/redis.go index 63d02b4..3d44b71 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -26,6 +26,7 @@ func InitRedis() { members := []interface{}{ "Explain me this:", "How does the following code work?", + "Explain me step by step how this code works:", } rdb.SAdd(ctx, "default", members...) @@ -75,7 +76,7 @@ func AddInstruct(c *gin.Context) { return } - c.Status(http.StatusOK) + c.JSON(http.StatusOK, "added Item to list: "+listName) } func DeleteInstruct(c *gin.Context) { diff --git a/weaviate/weaviate.go b/weaviate/weaviate.go index ba7a19e..fabaa25 100644 --- a/weaviate/weaviate.go +++ b/weaviate/weaviate.go @@ -11,6 +11,29 @@ import ( "strings" ) +type Position struct { + Line int `json:"line"` + Character int `json:"character"` +} + +type Range struct { + Start Position `json:"start"` + End Position `json:"end"` +} + +type ResponseData struct { + Response string `json:"response"` + PromptID string `json:"promptID"` + Instruct string `json:"instruct"` +} + +type PromptProperties struct { + Code string `json:"code"` + HasResponse string `json:"hasResponse"` + Instruct string `json:"instruct"` + Rank int `json:"rank"` +} + func InitSchema() error { client, err := loadClient() @@ -135,14 +158,14 @@ func createClass(className, description, vectorizer string, properties []*models return nil } -func CreatePromptObject(prompt string, code string, class string) (string, error) { +func CreatePromptObject(instruct string, code string, class string) (string, error) { client, err := loadClient() if err != nil { return "", err } dataSchema := map[string]interface{}{ - "instruct": prompt, + "instruct": instruct, "code": code, "rank": 1, } diff --git a/weaviate/weaviate_retrieval.go b/weaviate/weaviate_retrieval.go index 032a412..31744ac 100644 --- a/weaviate/weaviate_retrieval.go +++ b/weaviate/weaviate_retrieval.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/weaviate/weaviate-go-client/v4/weaviate/filters" "github.com/weaviate/weaviate-go-client/v4/weaviate/graphql" "github.com/weaviate/weaviate/entities/models" @@ -12,20 +13,7 @@ import ( "time" ) -type PromptProperties struct { - Code string `json:"code"` - HasResponse []map[string]interface{} `json:"hasResponse"` - Instruct string `json:"instruct"` - Rank int `json:"rank"` -} - -type ResponseData struct { - ID string `json:"id"` - Response string `json:"response"` -} - func RetrieveProperties(id string) (PromptProperties, error) { - client, err := loadClient() if err != nil { return PromptProperties{}, err @@ -39,21 +27,52 @@ func RetrieveProperties(id string) (PromptProperties, error) { return PromptProperties{}, err } - properties := objects[0].Properties + if len(objects) == 0 { + return PromptProperties{}, fmt.Errorf("no object found with ID: %s", id) + } - propertiesJSON, err := json.Marshal(properties) + propertiesJSON, err := json.Marshal(objects[0].Properties) if err != nil { return PromptProperties{}, err } - var promptProperties PromptProperties - err = json.Unmarshal(propertiesJSON, &promptProperties) + var temp struct { + Code string `json:"code"` + HasResponse []map[string]interface{} `json:"hasResponse"` + Instruct string `json:"instruct"` + Rank int `json:"rank"` + } + + if err := json.Unmarshal(propertiesJSON, &temp); err != nil { + return PromptProperties{}, err + } + + var responseText string + if len(temp.HasResponse) > 0 { + responseTextBytes, err := json.Marshal(temp.HasResponse) + if err != nil { + return PromptProperties{}, err + } + responseText = string(responseTextBytes) + } + + response, err := RetrieveResponseByID(id) if err != nil { return PromptProperties{}, err } + responseText, ok := response.(string) + if !ok { + return PromptProperties{}, fmt.Errorf("response from RetrieveResponseByID is not a string") + } - return promptProperties, nil + promptProperties := PromptProperties{ + Code: temp.Code, + HasResponse: responseText, + Instruct: temp.Instruct, + Rank: temp.Rank, + } + return promptProperties, nil } func RetrievePromptCount(code string) (int, error) { @@ -120,6 +139,87 @@ func RetrievePromptCount(code string) (int, error) { return int(countFloat), nil } +func RetrieveResponseByID(id string) (interface{}, error) { + client, err := loadClient() + if err != nil { + return nil, err + } + + fields := []graphql.Field{ + {Name: "hasResponse", Fields: []graphql.Field{ + {Name: "... on Response", Fields: []graphql.Field{ + {Name: "response"}, + }}, + }}, + {Name: "_additional", Fields: []graphql.Field{ + {Name: "id"}, + }}, + {Name: "rank"}, + {Name: "instruct"}, + } + + withNearObject := &graphql.NearObjectArgumentBuilder{} + + withNearObject.WithID(id) + + ctx := context.Background() + result, err := client.GraphQL().Get(). + WithClassName("Prompt"). + WithFields(fields...). + WithLimit(1). + WithNearObject(withNearObject). + Do(ctx) + if err != nil { + return nil, err + } + + response, err := ExtractResponseFromGraphQL(result) + if err != nil { + return nil, err + } + + return response, nil +} + +func ResponseList(code string) ([]string, error) { + responses, err := RetrieveResponsesRankDesc(code) + if err != nil { + return nil, err + } + + getPrompt, ok := responses.Data["Get"].(map[string]interface{}) + if !ok { + return nil, errors.New("unexpected response format: 'Get' field not found or not a map") + } + + promptData, ok := getPrompt["Prompt"].([]interface{}) + if !ok { + return nil, errors.New("unexpected response format: 'Prompt' field not found") + } + + if len(promptData) == 0 { + return nil, errors.New("no prompt found") + } + + var RankIDs []string + + for _, prompt := range promptData { + promptMap, ok := prompt.(map[string]interface{}) + if !ok { + return nil, errors.New("unexpected response format: prompt data is not a map") + } + + id, err := ExtractID(promptMap) + if err != nil { + return nil, err + } + + RankIDs = append(RankIDs, id) + } + + return RankIDs, nil +} + func RetrieveBestResponse(code string) (ResponseData, error) { responses, err := RetrieveResponsesRankDesc(code) @@ -146,8 +246,6 @@ func RetrieveBestResponse(code string) (ResponseData, error) { return ResponseData{}, errors.New("unexpected response format: prompt data is not a map") } - log.Printf("%v", promptMap) - rankInterface, ok := promptMap["rank"] log.Printf("%v", rankInterface) if !ok { @@ -186,9 +284,15 @@ func RetrieveBestResponse(code string) (ResponseData, error) { return ResponseData{}, err } + instruct, err := ExtractInstruct(selectedPrompt) + if err != nil { + return ResponseData{}, err + } + responseData := ResponseData{ - ID: id, + PromptID: id, Response: response, + Instruct: instruct, } return responseData, nil @@ -235,9 +339,15 @@ func RetrieveRandomResponse(code string) (ResponseData, error) { return ResponseData{}, err } + instruct, err := ExtractInstruct(selectedPromptMap) + if err != nil { + return ResponseData{}, err + } + responseData := ResponseData{ - ID: id, + PromptID: id, Response: response, + Instruct: instruct, } return responseData, nil @@ -260,6 +370,7 @@ func RetrieveResponsesRankDesc(code string) (*models.GraphQLResponse, error) { {Name: "id"}, }}, {Name: "rank"}, + {Name: "instruct"}, } where := filters.Where(). @@ -322,3 +433,40 @@ func ExtractResponse(selectedPromptMap map[string]interface{}) (string, error) { return response, nil } + +func ExtractInstruct(selectedPrompt map[string]interface{}) (string, error) { + hasInstruct, ok := selectedPrompt["instruct"] + if !ok { + return "", errors.New("instruct field not found in prompt data") + } + + instruct, ok := hasInstruct.(string) + if !ok { + return "", errors.New("id field is not a string in _additional data") + } + + return instruct, nil +} + +func ExtractResponseFromGraphQL(query *models.GraphQLResponse) (string, error) { + getPrompt, ok := query.Data["Get"].(map[string]interface{}) + if !ok { + return "", errors.New("unexpected response format: 'Get' field not found or not a map") + } + + promptData, ok := getPrompt["Prompt"].([]interface{}) + if !ok || len(promptData) == 0 { + return "", errors.New("unexpected response format: 'Prompt' field not found or empty list") + } + selectedPrompt := promptData[0].(map[string]interface{}) + if !ok { + return "", errors.New("unexpected response format: selected prompt data is not a map") + } + + response, err := ExtractResponse(selectedPrompt) + if err != nil { + return "", err + } + return response, nil + +}