Skip to content

Commit

Permalink
Merge pull request #845 from nanjj/copyDriverFiles
Browse files Browse the repository at this point in the history
Refactor copy driver files
  • Loading branch information
stgraber authored May 13, 2024
2 parents 1713fa9 + 14721a3 commit 57ef557
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 94 deletions.
167 changes: 78 additions & 89 deletions distrobuilder/main_repack-windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ func (c *cmdRepackWindows) modifyWimIndex(wimFile string, index int, name string

logger.Info("Modifying")
// Create registry entries and copy files
err = c.injectDrivers(dirs)
err = c.injectDrivers(dirs["inf"], dirs["drivers"], dirs["filerepository"], dirs["config"])
if err != nil {
return fmt.Errorf("Failed to inject drivers: %w", err)
}
Expand Down Expand Up @@ -460,7 +460,7 @@ func (c *cmdRepackWindows) getWindowsDirectories(wimPath string) (dirs map[strin
return
}

func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error {
func (c *cmdRepackWindows) injectDrivers(infDir, driversDir, filerepositoryDir, configDir string) error {
logger := c.global.logger

driverPath := filepath.Join(c.global.flagCacheDir, "drivers")
Expand All @@ -469,132 +469,94 @@ func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error {
driversRegistry := "Windows Registry Editor Version 5.00"
systemRegistry := "Windows Registry Editor Version 5.00"
softwareRegistry := "Windows Registry Editor Version 5.00"

for driver, info := range windows.Drivers {
logger.WithField("driver", driver).Debug("Injecting driver")

ctx := pongo2.Context{
"infFile": fmt.Sprintf("oem%d.inf", i),
"packageName": info.PackageName,
"driverName": driver,
}

sourceDir := filepath.Join(driverPath, driver, c.flagWindowsVersion, c.flagWindowsArchitecture)
targetBasePath := filepath.Join(dirs["filerepository"], info.PackageName)

if !incus.PathExists(targetBasePath) {
err := os.MkdirAll(targetBasePath, 0755)
for driverName, driverInfo := range windows.Drivers {
logger.WithField("driver", driverName).Debug("Injecting driver")
infFilename := fmt.Sprintf("oem%d.inf", i)
sourceDir := filepath.Join(driverPath, driverName, c.flagWindowsVersion, c.flagWindowsArchitecture)
targetBaseDir := filepath.Join(filerepositoryDir, driverInfo.PackageName)
if !incus.PathExists(targetBaseDir) {
err := os.MkdirAll(targetBaseDir, 0755)
if err != nil {
return fmt.Errorf("Failed to create directory %q: %w", targetBasePath, err)
logger.Error(err)
return err
}
}

err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
ext := filepath.Ext(path)
targetPath := filepath.Join(targetBasePath, filepath.Base(path))

// Copy driver files
if slices.Contains([]string{".cat", ".dll", ".inf", ".sys"}, ext) {
logger.WithFields(logrus.Fields{"src": path, "dest": targetPath}).Debug("Copying file")

err := shared.Copy(path, targetPath)
if err != nil {
return fmt.Errorf("Failed to copy %q to %q: %w", filepath.Base(path), targetPath, err)
}
for ext, dir := range map[string]string{"inf": infDir, "cat": driversDir, "dll": driversDir, "sys": driversDir} {
driverPath, err := shared.FindFirstMatch(sourceDir, fmt.Sprintf("*.%s", ext))
if err != nil {
logger.Debugf("failed to find first match %q %q", driverName, ext)
continue
}

// Copy .inf file
if ext == ".inf" {
target := filepath.Join(dirs["inf"], ctx["infFile"].(string))
logger.WithFields(logrus.Fields{"src": path, "dest": target}).Debug("Copying file")

err := shared.Copy(path, target)
if err != nil {
return fmt.Errorf("Failed to copy %q to %q: %w", filepath.Base(path), target, err)
}

// Retrieve the ClassGuid which is needed for the Windows registry entries.
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("Failed to open %s: %w", path, err)
}

re := regexp.MustCompile(`(?i)^ClassGuid[ ]*=[ ]*(.+)$`)
scanner := bufio.NewScanner(file)

for scanner.Scan() {
matches := re.FindStringSubmatch(scanner.Text())

if len(matches) > 0 {
ctx["classGuid"] = strings.TrimSpace(matches[1])
}
}

file.Close()

_, ok := ctx["classGuid"]
if !ok {
return fmt.Errorf("Failed to determine classGUID for driver %q", driver)
}
targetName := filepath.Base(driverPath)
if err = shared.Copy(driverPath, filepath.Join(targetBaseDir, targetName)); err != nil {
return err
}

// Copy .sys and .dll files
if ext == ".dll" || ext == ".sys" {
target := filepath.Join(dirs["drivers"], filepath.Base(path))
logger.WithFields(logrus.Fields{"src": path, "dest": target}).Debug("Copying file")
if ext == "cat" {
continue
} else if ext == "inf" {
targetName = infFilename
}

err := shared.Copy(path, target)
if err != nil {
return fmt.Errorf("Failed to copy %q to %q: %w", filepath.Base(path), target, err)
}
if err = shared.Copy(driverPath, filepath.Join(dir, targetName)); err != nil {
return err
}
}

return nil
})
classGuid, err := parseDriverClassGuid(driverName, filepath.Join(infDir, infFilename))
if err != nil {
return fmt.Errorf("Failed to copy driver files: %w", err)
return err
}

ctx := pongo2.Context{
"infFile": infFilename,
"packageName": driverInfo.PackageName,
"driverName": driverName,
"classGuid": classGuid,
}

// Update Windows DRIVERS registry
if info.DriversRegistry != "" {
tpl, err := pongo2.FromString(info.DriversRegistry)
if driverInfo.DriversRegistry != "" {
tpl, err := pongo2.FromString(driverInfo.DriversRegistry)
if err != nil {
return fmt.Errorf("Failed to parse template for driver %q: %w", driver, err)
return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err)
}

out, err := tpl.Execute(ctx)
if err != nil {
return fmt.Errorf("Failed to render template for driver %q: %w", driver, err)
return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err)
}

driversRegistry = fmt.Sprintf("%s\n\n%s", driversRegistry, out)
}

// Update Windows SYSTEM registry
if info.SystemRegistry != "" {
tpl, err := pongo2.FromString(info.SystemRegistry)
if driverInfo.SystemRegistry != "" {
tpl, err := pongo2.FromString(driverInfo.SystemRegistry)
if err != nil {
return fmt.Errorf("Failed to parse template for driver %q: %w", driver, err)
return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err)
}

out, err := tpl.Execute(ctx)
if err != nil {
return fmt.Errorf("Failed to render template for driver %q: %w", driver, err)
return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err)
}

systemRegistry = fmt.Sprintf("%s\n\n%s", systemRegistry, out)
}

// Update Windows SOFTWARE registry
if info.SoftwareRegistry != "" {
tpl, err := pongo2.FromString(info.SoftwareRegistry)
if driverInfo.SoftwareRegistry != "" {
tpl, err := pongo2.FromString(driverInfo.SoftwareRegistry)
if err != nil {
return fmt.Errorf("Failed to parse template for driver %q: %w", driver, err)
return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err)
}

out, err := tpl.Execute(ctx)
if err != nil {
return fmt.Errorf("Failed to render template for driver %q: %w", driver, err)
return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err)
}

softwareRegistry = fmt.Sprintf("%s\n\n%s", softwareRegistry, out)
Expand All @@ -605,28 +567,55 @@ func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error {

logger.WithField("hivefile", "DRIVERS").Debug("Updating Windows registry")

err := shared.RunCommand(c.global.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(dirs["config"], "DRIVERS"))
err := shared.RunCommand(c.global.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(configDir, "DRIVERS"))
if err != nil {
return fmt.Errorf("Failed to edit Windows DRIVERS registry: %w", err)
}

logger.WithField("hivefile", "SYSTEM").Debug("Updating Windows registry")

err = shared.RunCommand(c.global.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(dirs["config"], "SYSTEM"))
err = shared.RunCommand(c.global.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(configDir, "SYSTEM"))
if err != nil {
return fmt.Errorf("Failed to edit Windows SYSTEM registry: %w", err)
}

logger.WithField("hivefile", "SOFTWARE").Debug("Updating Windows registry")

err = shared.RunCommand(c.global.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(dirs["config"], "SOFTWARE"))
err = shared.RunCommand(c.global.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(configDir, "SOFTWARE"))
if err != nil {
return fmt.Errorf("Failed to edit Windows SOFTWARE registry: %w", err)
}

return nil
}

func parseDriverClassGuid(driverName, infPath string) (classGuid string, err error) {
// Retrieve the ClassGuid which is needed for the Windows registry entries.
file, err := os.Open(infPath)
if err != nil {
err = fmt.Errorf("Failed to open driver %s inf %s: %w", driverName, infPath, err)
return
}

defer func() {
file.Close()
if classGuid == "" {
err = fmt.Errorf("Failed to parse driver %s classGuid %s", driverName, infPath)
}
}()
re := regexp.MustCompile(`(?i)^ClassGuid[ ]*=[ ]*(.+)$`)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
matches := re.FindStringSubmatch(scanner.Text())
if len(matches) > 1 {
classGuid = strings.TrimSpace(matches[1])
return
}
}

return
}

// toHex is a pongo2 filter which converts the provided value to a hex value understood by the Windows registry.
func toHex(in *pongo2.Value, param *pongo2.Value) (out *pongo2.Value, err *pongo2.Error) {
dst := make([]byte, hex.EncodedLen(len(in.String())))
Expand Down
11 changes: 6 additions & 5 deletions shared/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ func CaseInsensitive(s string) (pattern string) {
b := s2[i : i+1]
if a != b {
pattern += "[" + a + b + "]"
} else if a != "/" {
pattern += "\\" + a
} else if strings.Contains("?*[]/", a) {
pattern += a
} else {
pattern += "/"
pattern += "\\" + a
}
}
return
Expand All @@ -70,13 +70,14 @@ func FindFirstMatch(dir string, elem ...string) (found string, err error) {
names = append(names, CaseInsensitive(name))
}

matches, err := filepath.Glob(filepath.Join(names...))
pattern := filepath.Join(names...)
matches, err := filepath.Glob(pattern)
if err != nil {
return
}

if len(matches) == 0 {
err = fmt.Errorf("No match found")
err = fmt.Errorf("No match found %s", pattern)
return
}

Expand Down

0 comments on commit 57ef557

Please sign in to comment.