Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-9370: change param prefix check in module generation #4654

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 33 additions & 57 deletions cli/module_generate/scripts/generate_stubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
//go:embed tmpl-module
var goTmpl string

// typePrefixes lists possible prefixes before function parameter and return types.
var typePrefixes = []string{"*", "[]*", "[]", "chan "}

// getClientCode grabs client.go code of component type.
func getClientCode(module modulegen.ModuleInputs) (string, error) {
url := fmt.Sprintf("https://raw.githubusercontent.com/viamrobotics/rdk/refs/tags/v%s/%ss/%s/client.go",
Expand Down Expand Up @@ -90,7 +93,6 @@ func setGoModuleTemplate(clientCode string, module modulegen.ModuleInputs) (*mod
if funcDecl, ok := n.(*ast.FuncDecl); ok {
name, receiver, args, returns := parseFunctionSignature(
module.ResourceSubtype,
module.ResourceSubtypePascal,
module.ModuleCamel+module.ModelPascal,
funcDecl,
)
Expand Down Expand Up @@ -118,28 +120,41 @@ func setGoModuleTemplate(clientCode string, module modulegen.ModuleInputs) (*mod
return &goTmplInputs, nil
}

// formatType outputs typeExpr as readable string.
func formatType(typeExpr ast.Expr) string {
// formatType formats typeExpr as readable string with correct attribution if applicable.
func formatType(typeExpr ast.Expr, resourceSubtype string) string {
var buf bytes.Buffer
err := printer.Fprint(&buf, token.NewFileSet(), typeExpr)
if err != nil {
return fmt.Sprintf("Error formatting type: %v", err)
}
return buf.String()
}
typeString := buf.String()

func handleMapType(str, resourceSubtype string) string {
endStr := strings.Index(str, "]")
keyType := strings.TrimSpace(str[4:endStr])
valueType := strings.TrimSpace(str[endStr+1:])
if unicode.IsUpper(rune(keyType[0])) {
keyType = fmt.Sprintf("%s.%s", resourceSubtype, keyType)
// checkUpper adds "<resourceSubtype>." to the type if type is capitalized after prefix.
checkUpper := func(str, prefix string) string {
prefixLen := len(prefix)
if unicode.IsUpper(rune(str[prefixLen])) {
return fmt.Sprintf("%s%s.%s", prefix, resourceSubtype, str[prefixLen:])
}
return str
}
for _, prefix := range typePrefixes {
if strings.HasPrefix(typeString, prefix) {
return checkUpper(typeString, prefix)
}
}
if unicode.IsUpper(rune(valueType[0])) {
valueType = fmt.Sprintf("%s.%s", resourceSubtype, valueType)
if strings.HasPrefix(typeString, "map[") {
endStr := strings.Index(typeString, "]")
keyType := strings.TrimSpace(typeString[4:endStr])
valueType := strings.TrimSpace(typeString[endStr+1:])
if unicode.IsUpper(rune(keyType[0])) {
keyType = checkUpper(keyType, "")
}
if unicode.IsUpper(rune(valueType[0])) {
valueType = checkUpper(valueType, "")
}
return fmt.Sprintf("map[%s]%s", keyType, valueType)
}

return fmt.Sprintf("map[%s]%s", keyType, valueType)
return checkUpper(typeString, "")
}

func formatStruct(typeSpec *ast.TypeSpec, modelType string) string {
Expand All @@ -153,8 +168,7 @@ func formatStruct(typeSpec *ast.TypeSpec, modelType string) string {

// parseFunctionSignature parses function declarations into the function name, the arguments, and the return types.
func parseFunctionSignature(
resourceSubtype,
resourceSubtypePascal string,
resourceSubtype string,
modelType string,
funcDecl *ast.FuncDecl,
) (name, receiver, args string, returns []string) {
Expand Down Expand Up @@ -188,20 +202,7 @@ func parseFunctionSignature(
var params []string
if funcDecl.Type.Params != nil {
for _, param := range funcDecl.Type.Params.List {
paramType := formatType(param.Type)

// Check if `paramType` is a type that is capitalized.
// If so, attribute the type to <resourceSubtype>.
switch {
case unicode.IsUpper(rune(paramType[0])):
paramType = fmt.Sprintf("%s.%s", resourceSubtype, paramType)
// IF `paramType` has a prefix, check if type is capitalized after prefix.
case strings.HasPrefix(paramType, "[]") && unicode.IsUpper(rune(paramType[2])):
paramType = fmt.Sprintf("[]%s.%s", resourceSubtype, paramType[2:])
case strings.HasPrefix(paramType, "chan ") && unicode.IsUpper(rune(paramType[5])):
paramType = fmt.Sprintf("chan %s.%s", resourceSubtype, paramType[5:])
}

paramType := formatType(param.Type, resourceSubtype)
for _, name := range param.Names {
params = append(params, name.Name+" "+paramType)
}
Expand All @@ -211,32 +212,7 @@ func parseFunctionSignature(
// Return types
if funcDecl.Type.Results != nil {
for _, result := range funcDecl.Type.Results.List {
str := formatType(result.Type)
isPointer := false
isMapPointer := false
if str[0] == '*' {
str = str[1:]
isPointer = true
} else if str[2] == '*' {
str = str[3:]
isMapPointer = true
}

switch {
case strings.HasPrefix(str, "map["):
str = handleMapType(str, resourceSubtype)
case unicode.IsUpper(rune(str[0])):
str = fmt.Sprintf("%s.%s", resourceSubtype, str)
case strings.HasPrefix(str, "[]") && unicode.IsUpper(rune(str[2])):
str = fmt.Sprintf("[]%s.%s", resourceSubtype, str[2:])
case str == resourceSubtypePascal:
str = fmt.Sprintf("%s.%s", resourceSubtype, resourceSubtypePascal)
}
if isPointer {
str = fmt.Sprintf("*%s", str)
} else if isMapPointer {
str = fmt.Sprintf("[]*%s", str)
}
str := formatType(result.Type, resourceSubtype)
// fixing vision service package imports
if strings.Contains(str, "vision.Object") {
str = strings.ReplaceAll(str, "vision.Object", "vis.Object")
Expand Down
23 changes: 23 additions & 0 deletions cli/module_generate/scripts/generate_stubs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package scripts

import (
"fmt"
"go/ast"
"testing"

"go.viam.com/test"
)

func TestGenerateStubs(t *testing.T) {
t.Run("test type formatting", func(t *testing.T) {
subtype := "resource"
testType := "Test"

paramType := formatType(ast.NewIdent(testType), subtype)
test.That(t, paramType, test.ShouldEqual, fmt.Sprintf("%s.%s", subtype, testType))
for _, prefix := range typePrefixes {
paramType := formatType(ast.NewIdent(prefix+testType), subtype)
test.That(t, paramType, test.ShouldEqual, fmt.Sprintf("%s%s.%s", prefix, subtype, testType))
}
})
}
Loading