Skip to content

Commit

Permalink
add property extraction by id, vote mechanism and generation enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
ili16 committed Feb 25, 2024
1 parent fafff56 commit e58a3e5
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 32 deletions.
31 changes: 20 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
)

func main() {

err := weaviate.InitSchema()
if err != nil {
panic(err)
Expand Down Expand Up @@ -66,38 +65,48 @@ func main() {
})

router.POST("/generate", func(c *gin.Context) {
// Parse the JSON request body
var requestBody map[string]interface{}
if err := c.BindJSON(&requestBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Call ollama.GenerateResponse with the parsed request body
response, err := ollama.GenerateResponse(requestBody)
if err != nil {
return
}

// Return the response
c.JSON(http.StatusOK, response)
})

router.POST("/vote", func(c *gin.Context) {
// Parse the JSON request body

upvoteStr := c.Query("upvote")
upvote := upvoteStr == "true"

var requestBody map[string]interface{}
if err := c.BindJSON(&requestBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Call ollama.GenerateResponse with the parsed request body
err := weaviate.UpdateRankPrompt(requestBody)
if err != nil {
return
log.Printf("%v\n", requestBody)

if upvote {
err := weaviate.UpdateRankPrompt(requestBody, true)
if err != nil {
log.Printf("%v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
} else {
err := weaviate.UpdateRankPrompt(requestBody, false)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
log.Printf("%v", err)
return
}
}

// Return the response
c.JSON(http.StatusOK, "OK")
})

Expand Down
41 changes: 23 additions & 18 deletions ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/rwth-acis/modernizer/redis"
"github.com/rwth-acis/modernizer/weaviate"
"io"
Expand All @@ -13,22 +12,24 @@ import (
"os"
)

func GenerateResponse(prompt map[string]interface{}) (string, error) {
url := os.Getenv("OLLAMA_URL") + "/api/generate"
type ResponseData struct {
Response string `json:"response"`
PromptID string `json:"promptID"`
}

// TODO add possibility to differentiate between system prompt roles/creativity
// TODO add routes to show and add prompts
func GenerateResponse(prompt map[string]interface{}) (ResponseData, error) {
url := os.Getenv("OLLAMA_URL") + "/api/generate"

instruct, err := redis.GetSetMember("default")
if err != nil {
return "", err
return ResponseData{}, errors.New("no data available")
}

log.Printf("Prompt: %s\n", instruct)

code, ok := prompt["prompt"].(string)
if !ok {
return "", errors.New("prompt field is not a string")
return ResponseData{}, errors.New("prompt field is not a string")
}

log.Printf("Code: %s\n", code)
Expand All @@ -38,25 +39,25 @@ func GenerateResponse(prompt map[string]interface{}) (string, error) {
requestBody := map[string]interface{}{
"model": prompt["model"],
"prompt": completePrompt,
"stream": prompt["stream"],
"stream": false,
}

jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", err
return ResponseData{}, err
}

req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", err
return ResponseData{}, err
}

req.Header.Set("Content-Type", "application/json")

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", err
return ResponseData{}, err
}
defer func(Body io.ReadCloser) {
err := Body.Close()
Expand All @@ -67,38 +68,42 @@ func GenerateResponse(prompt map[string]interface{}) (string, error) {

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
return ResponseData{}, err
}

var responseJSON map[string]interface{}
err = json.Unmarshal(body, &responseJSON)
if err != nil {
fmt.Println("Error decoding JSON response:", err)
return "", err
return ResponseData{}, err
}

response, ok := responseJSON["response"].(string)
if !ok {
log.Println("Error: 'response' field is not a string array")
return "", errors.New("invalid response format")
return ResponseData{}, errors.New("invalid response format")
}

log.Printf("Reponse: %s\n", response)

PromptID, err := weaviate.CreatePromptObject(instruct, code, "Prompt")
if err != nil {
return "", err
return ResponseData{}, err
}

ResponseID, err := weaviate.CreateResponseObject(response, "Response")
if err != nil {
return "", err
return ResponseData{}, err
}

err = weaviate.CreateReferences(PromptID, ResponseID)
if err != nil {
panic(err)
}

return response, nil
responseData := ResponseData{
Response: response,
PromptID: PromptID,
}

return responseData, nil
}
59 changes: 56 additions & 3 deletions weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,76 @@ func CreatePromptObject(prompt string, code string, class string) (string, error
return string(weaviateObject.Object.ID), nil
}

func UpdateRankPrompt(prompt map[string]interface{}) error {
type PromptProperties struct {
Code string `json:"code"`
HasResponse []map[string]interface{} `json:"hasResponse"`
Instruct string `json:"instruct"`
Rank int `json:"rank"`
}

func GetProperties(id string) (PromptProperties, error) {

client, err := loadClient()
if err != nil {
return PromptProperties{}, err
}

objects, err := client.Data().ObjectsGetter().
WithID(id).
WithClassName("Prompt").
Do(context.Background())
if err != nil {
return PromptProperties{}, err
}

properties := objects[0].Properties

propertiesJSON, err := json.Marshal(properties)
if err != nil {
return PromptProperties{}, err
}

var promptProperties PromptProperties
err = json.Unmarshal(propertiesJSON, &promptProperties)
if err != nil {
return PromptProperties{}, err
}

return promptProperties, nil

}

func UpdateRankPrompt(prompt map[string]interface{}, upvote bool) error {
id, ok := prompt["id"].(string)
if !ok {
return errors.New("ID not found in request body")
}

log.Printf("%v\n", id)

client, err := loadClient()
if err != nil {
return err
}

promptProperties, err := GetProperties(id)
if err != nil {
return err
}

var rank int
if upvote {
rank = promptProperties.Rank + 1
} else {
rank = promptProperties.Rank - 1
}

err = client.Data().Updater().
WithMerge().
WithID(id).
WithClassName("Prompt").
WithProperties(map[string]interface{}{
"rank": 2,
"rank": rank,
}).
Do(context.Background())
if err != nil {
Expand Down Expand Up @@ -232,7 +285,7 @@ func RetrievePromptCount(code string) (int, error) {
WithWhere(where).
Do(ctx)
if err != nil {
panic(err)
return 0, err
}

if len(result.Errors) > 0 {
Expand Down

0 comments on commit e58a3e5

Please sign in to comment.