Skip to content

Commit

Permalink
change embeddings to built in text-2-vec and add upvote/retrieval fun…
Browse files Browse the repository at this point in the history
…ctionality
  • Loading branch information
ili16 committed Feb 23, 2024
1 parent 30b44fa commit cb4c004
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 26 deletions.
39 changes: 39 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,27 @@ func main() {
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(strconv.Itoa(count)))
})

router.GET("/weaviate/retrieveresponse", func(c *gin.Context) {
searchQuery := c.Query("query")

// Decode the search query
decodedQuery, err := url.QueryUnescape(searchQuery)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid search query"})
return
}

log.Printf("Decoded Query: %s", decodedQuery)

response, err := weaviate.RetrieveResponse(decodedQuery)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(response))
})

router.POST("/generate", func(c *gin.Context) {
// Parse the JSON request body
var requestBody map[string]interface{}
Expand All @@ -68,6 +89,24 @@ func main() {
c.JSON(http.StatusOK, response)
})

router.POST("/vote", func(c *gin.Context) {
// Parse the JSON request body
var requestBody map[string]interface{}
if err := c.BindJSON(&requestBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Call ollama.GenerateResponse with the parsed request body
err := weaviate.UpdateRankPrompt(requestBody)
if err != nil {
return
}

// Return the response
c.JSON(http.StatusOK, "OK")
})

err = router.Run(":8080")
if err != nil {
return
Expand Down
14 changes: 2 additions & 12 deletions ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,22 +177,12 @@ func GenerateResponse(prompt map[string]interface{}) (string, error) {

log.Printf("Reponse: %s\n", response)

vector, err := CreateEmbedding(code)
PromptID, err := weaviate.CreatePromptObject(chosenPrompt, code, "Prompt")
if err != nil {
return "", err
}

PromptID, err := weaviate.CreatePromptObject(vector, chosenPrompt, code, "Prompt")
if err != nil {
return "", err
}

vector, err = CreateEmbedding(response)
if err != nil {
return "", err
}

ResponseID, err := weaviate.CreateResponseObject(vector, response, "Response")
ResponseID, err := weaviate.CreateResponseObject(response, "Response")
if err != nil {
return "", err
}
Expand Down
191 changes: 177 additions & 14 deletions weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/weaviate/weaviate-go-client/v4/weaviate/filters"
"github.com/weaviate/weaviate-go-client/v4/weaviate/graphql"
"log"
"math/rand"
"os"
"strings"
"time"

"github.com/weaviate/weaviate-go-client/v4/weaviate"
"github.com/weaviate/weaviate-go-client/v4/weaviate/auth"
Expand Down Expand Up @@ -52,18 +55,16 @@ func InitSchema() error {
classObj := &models.Class{
Class: "Response",
Description: "This class contains the responses to prompts",
Vectorizer: "none",
Vectorizer: "text2vec-transformers",
ModuleConfig: map[string]interface{}{
"text2vec-transformers": map[string]interface{}{},
},
Properties: []*models.Property{
{
DataType: []string{"text"},
Description: "The generated response by the LLM",
Name: "response",
},
{
DataType: []string{"int"},
Description: "The relative rank for this response against other ones regarding the same code",
Name: "rank",
},
},
}

Expand All @@ -85,11 +86,20 @@ func InitSchema() error {
classObj := &models.Class{
Class: "Prompt",
Description: "This class holds information regarding the prompt, code and count of queries regarding ones codebase",
Vectorizer: "text2vec-transformers",
ModuleConfig: map[string]interface{}{
"text2vec-transformers": map[string]interface{}{},
},
Properties: []*models.Property{
{
DataType: []string{"text"},
Description: "The specific instruct or question prepended to the code",
Name: "instruct",
ModuleConfig: map[string]interface{}{
"text2vec-transformers": map[string]interface{}{
"skip": true,
},
},
},
{
DataType: []string{"text"},
Expand All @@ -99,9 +109,18 @@ func InitSchema() error {
{
DataType: []string{"Response"},
Name: "hasResponse",
ModuleConfig: map[string]interface{}{
"text2vec-transformers": map[string]interface{}{
"skip": true,
},
},
},
{
DataType: []string{"int"},
Description: "The relative rank for this response against other ones regarding the same code",
Name: "rank",
},
},
Vectorizer: "none",
}

err = client.Schema().ClassCreator().WithClass(classObj).Do(context.Background())
Expand Down Expand Up @@ -139,7 +158,7 @@ func createClass(className, description, vectorizer string, properties []*models
return nil
}

func CreatePromptObject(vector []float32, prompt string, code string, class string) (string, error) {
func CreatePromptObject(prompt string, code string, class string) (string, error) {
client, err := loadClient()
if err != nil {
return "", err
Expand All @@ -148,12 +167,12 @@ func CreatePromptObject(vector []float32, prompt string, code string, class stri
dataSchema := map[string]interface{}{
"instruct": prompt,
"code": code,
"rank": 1,
}

weaviateObject, err := client.Data().Creator().
WithClassName(class).
WithProperties(dataSchema).
WithVector(vector).
Do(context.Background())
if err != nil {
return "", err
Expand All @@ -162,21 +181,45 @@ func CreatePromptObject(vector []float32, prompt string, code string, class stri
return string(weaviateObject.Object.ID), nil
}

func CreateResponseObject(vector []float32, response string, class string) (string, error) {
func UpdateRankPrompt(prompt map[string]interface{}) error {
id, ok := prompt["id"].(string)
if !ok {
return errors.New("ID not found in request body")
}

client, err := loadClient()
if err != nil {
return err
}

err = client.Data().Updater().
WithMerge().
WithID(id).
WithClassName("Prompt").
WithProperties(map[string]interface{}{
"rank": 2,
}).
Do(context.Background())
if err != nil {
return err
}

return nil
}

func CreateResponseObject(response string, class string) (string, error) {
client, err := loadClient()
if err != nil {
return "", err
}

dataSchema := map[string]interface{}{
strings.ToLower(class): response,
"rank": 1,
}

weaviateObject, err := client.Data().Creator().
WithClassName(class).
WithProperties(dataSchema).
WithVector(vector).
Do(context.Background())

if err != nil {
Expand Down Expand Up @@ -250,6 +293,126 @@ func RetrievePromptCount(code string) (int, error) {
return int(countFloat), nil
}

func RetrieveResponse(code string) (string, error) {

client, err := loadClient()
if err != nil {
return "", err
}

fields := []graphql.Field{
{Name: "instruct"},
{Name: "rank"},
{Name: "hasResponse", Fields: []graphql.Field{
{Name: "... on Response", Fields: []graphql.Field{
{Name: "response"},
}},
}},
}

where := filters.Where().
WithPath([]string{"code"}).
WithOperator(filters.Like).
WithValueText(code)

byRankDesc := graphql.Sort{
Path: []string{"rank"}, Order: graphql.Desc,
}

ctx := context.Background()
result, err := client.GraphQL().Get().
WithClassName("Prompt").
WithSort(byRankDesc).
WithFields(fields...).
WithWhere(where).
Do(ctx)
if err != nil {
panic(err)
}

log.Printf("result= %v\n", result)

getPrompt, ok := result.Data["Get"].(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: 'Get' field not found or not a map")
}

promptData, ok := getPrompt["Prompt"].([]interface{})
if !ok || len(promptData) == 0 {
return "", errors.New("unexpected response format: 'Prompt' field not found or empty list")
}

// Initialize variables to track the prompt with the highest rank
var highestRank int
var highestRankPrompts []map[string]interface{}

// Iterate through each prompt to find the one with the highest rank
for _, prompt := range promptData {
promptMap, ok := prompt.(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: prompt data is not a map")
}

rankInterface, ok := promptMap["rank"]
if !ok {
return "", errors.New("rank field not found in prompt data")
}

rank, ok := rankInterface.(float64)
if !ok {
return "", errors.New("rank field is not a number")
}

// Convert float64 to int
rankInt := int(rank)

if rankInt > highestRank {
highestRank = rankInt
highestRankPrompts = []map[string]interface{}{promptMap}
} else if rankInt == highestRank {
highestRankPrompts = append(highestRankPrompts, promptMap)
}
}

// If there are prompts with the same highest rank, select one randomly
if len(highestRankPrompts) > 0 {
rand.Seed(time.Now().UnixNano())
randomIndex := rand.Intn(len(highestRankPrompts))
selectedPrompt := highestRankPrompts[randomIndex]

hasResponse, ok := selectedPrompt["hasResponse"].([]interface{})
if !ok || len(hasResponse) == 0 {
return "", errors.New("hasResponse field not found in prompt data or empty list")
}

firstResponseMap, ok := hasResponse[0].(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: response data is not a map")
}

response, ok := firstResponseMap["response"].(string)
if !ok {
return "", errors.New("response field not found in response data or not a string")
}

jsonData, err := json.Marshal(response)
if err != nil {
fmt.Println("Error:", err)
return "", err
}

log.Printf("Selected Response: %v\n", response)

// Add a newline character to the end of the string
jsonDataWithNewline := append(jsonData, '\n')

return string(jsonDataWithNewline), nil
}

return "", errors.New("no prompt found")

}

func CreateObject(vector []float32, body string, class string) error {
client, err := loadClient()
if err != nil {
Expand All @@ -275,8 +438,8 @@ func CreateObject(vector []float32, body string, class string) error {

func loadClient() (*weaviate.Client, error) {
cfg := weaviate.Config{
Host: os.Getenv("WEAVIATE_HOST"), // Replace with your endpoint
Scheme: "http",
Host: os.Getenv("WEAVIATE_HOST"),
Scheme: os.Getenv("WEAVIATE_SCHEME"),
AuthConfig: auth.ApiKey{Value: os.Getenv("WEAVIATE_KEY")},
}

Expand Down

0 comments on commit cb4c004

Please sign in to comment.