From 405e3967579c9466009e8098dfe1abe1d4d843f6 Mon Sep 17 00:00:00 2001 From: JUN JIE NAN Date: Mon, 13 May 2024 14:34:58 +0800 Subject: [PATCH] Refactor copy driver files Signed-off-by: JUN JIE NAN --- distrobuilder/main_repack-windows.go | 189 +++++++++++++-------------- shared/util.go | 11 +- 2 files changed, 99 insertions(+), 101 deletions(-) diff --git a/distrobuilder/main_repack-windows.go b/distrobuilder/main_repack-windows.go index c5dfb717..4e16dd09 100644 --- a/distrobuilder/main_repack-windows.go +++ b/distrobuilder/main_repack-windows.go @@ -364,21 +364,21 @@ func (c *cmdRepackWindows) modifyWim(wimFile string, info shared.WimInfo) (err e wimName := filepath.Base(wimFile) // Injects the drivers for idx := 1; idx <= info.ImageCount(); idx++ { - name := info.Name(idx) - err = c.modifyWimIndex(wimFile, idx, name) + ext := info.Name(idx) + err = c.modifyWimIndex(wimFile, idx, ext) if err != nil { - return fmt.Errorf("Failed to modify index %d=%s of %q: %w", idx, name, wimName, err) + return fmt.Errorf("Failed to modify index %d=%s of %q: %w", idx, ext, wimName, err) } } return } -func (c *cmdRepackWindows) modifyWimIndex(wimFile string, index int, name string) error { +func (c *cmdRepackWindows) modifyWimIndex(wimFile string, index int, ext string) error { wimIndex := strconv.Itoa(index) wimPath := filepath.Join(c.global.flagCacheDir, "wim", wimIndex) wimName := filepath.Base(wimFile) logger := c.global.logger.WithFields(logrus.Fields{"wim": strings.TrimSuffix(wimName, ".wim"), - "idx": wimIndex + ":" + name}) + "idx": wimIndex + ":" + ext}) if !incus.PathExists(wimPath) { err := os.MkdirAll(wimPath, 0755) if err != nil { @@ -407,7 +407,7 @@ func (c *cmdRepackWindows) modifyWimIndex(wimFile string, index int, name string logger.Info("Modifying") // Create registry entries and copy files - err = c.injectDrivers(dirs) + err = c.injectDrivers(dirs["inf"], dirs["drivers"], dirs["filerepository"], dirs["config"]) if err != nil { return fmt.Errorf("Failed to inject drivers: %w", err) } @@ -460,7 +460,7 @@ func (c *cmdRepackWindows) getWindowsDirectories(wimPath string) (dirs map[strin return } -func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error { +func (c *cmdRepackWindows) injectDrivers(infDir, driversDir, filerepositoryDir, configDir string) error { logger := c.global.logger driverPath := filepath.Join(c.global.flagCacheDir, "drivers") @@ -469,132 +469,89 @@ func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error { driversRegistry := "Windows Registry Editor Version 5.00" systemRegistry := "Windows Registry Editor Version 5.00" softwareRegistry := "Windows Registry Editor Version 5.00" - - for driver, info := range windows.Drivers { - logger.WithField("driver", driver).Debug("Injecting driver") - - ctx := pongo2.Context{ - "infFile": fmt.Sprintf("oem%d.inf", i), - "packageName": info.PackageName, - "driverName": driver, - } - - sourceDir := filepath.Join(driverPath, driver, c.flagWindowsVersion, c.flagWindowsArchitecture) - targetBasePath := filepath.Join(dirs["filerepository"], info.PackageName) - - if !incus.PathExists(targetBasePath) { - err := os.MkdirAll(targetBasePath, 0755) + for driverName, driverInfo := range windows.Drivers { + logger.WithField("driver", driverName).Debug("Injecting driver") + infFilename := fmt.Sprintf("oem%d.inf", i) + sourceDir := filepath.Join(driverPath, driverName, c.flagWindowsVersion, c.flagWindowsArchitecture) + targetBaseDir := filepath.Join(filerepositoryDir, driverInfo.PackageName) + if !incus.PathExists(targetBaseDir) { + err := os.MkdirAll(targetBaseDir, 0755) if err != nil { - return fmt.Errorf("Failed to create directory %q: %w", targetBasePath, err) + logger.Error(err) + return err } } - err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error { - ext := filepath.Ext(path) - targetPath := filepath.Join(targetBasePath, filepath.Base(path)) - - // Copy driver files - if slices.Contains([]string{".cat", ".dll", ".inf", ".sys"}, ext) { - logger.WithFields(logrus.Fields{"src": path, "dest": targetPath}).Debug("Copying file") - - err := shared.Copy(path, targetPath) - if err != nil { - return fmt.Errorf("Failed to copy %q to %q: %w", filepath.Base(path), targetPath, err) - } + for ext, dir := range map[string]string{"inf": infDir, "cat": driversDir, "dll": driversDir, "sys": driversDir} { + driverPath, err := shared.FindFirstMatch(sourceDir, fmt.Sprintf("*.%s", ext)) + if err != nil { + logger.Debugf("failed to find first match %q %q", driverName, ext) + continue } - // Copy .inf file - if ext == ".inf" { - target := filepath.Join(dirs["inf"], ctx["infFile"].(string)) - logger.WithFields(logrus.Fields{"src": path, "dest": target}).Debug("Copying file") - - err := shared.Copy(path, target) - if err != nil { - return fmt.Errorf("Failed to copy %q to %q: %w", filepath.Base(path), target, err) - } - - // Retrieve the ClassGuid which is needed for the Windows registry entries. - file, err := os.Open(path) - if err != nil { - return fmt.Errorf("Failed to open %s: %w", path, err) - } - - re := regexp.MustCompile(`(?i)^ClassGuid[ ]*=[ ]*(.+)$`) - scanner := bufio.NewScanner(file) - - for scanner.Scan() { - matches := re.FindStringSubmatch(scanner.Text()) - - if len(matches) > 0 { - ctx["classGuid"] = strings.TrimSpace(matches[1]) - } - } - - file.Close() - - _, ok := ctx["classGuid"] - if !ok { - return fmt.Errorf("Failed to determine classGUID for driver %q", driver) - } + targetName := filepath.Base(driverPath) + c.targetCopy(driverPath, filepath.Join(targetBaseDir, targetName)) + if ext == "cat" { + continue + } else if ext == "inf" { + targetName = infFilename } - // Copy .sys and .dll files - if ext == ".dll" || ext == ".sys" { - target := filepath.Join(dirs["drivers"], filepath.Base(path)) - logger.WithFields(logrus.Fields{"src": path, "dest": target}).Debug("Copying file") - - err := shared.Copy(path, target) - if err != nil { - return fmt.Errorf("Failed to copy %q to %q: %w", filepath.Base(path), target, err) - } - } + c.targetCopy(driverPath, filepath.Join(dir, targetName)) + } - return nil - }) + classGuid, err := parseDriverClassGuid(driverName, filepath.Join(infDir, infFilename)) if err != nil { - return fmt.Errorf("Failed to copy driver files: %w", err) + return err + } + + ctx := pongo2.Context{ + "infFile": infFilename, + "packageName": driverInfo.PackageName, + "driverName": driverName, + "classGuid": classGuid, } // Update Windows DRIVERS registry - if info.DriversRegistry != "" { - tpl, err := pongo2.FromString(info.DriversRegistry) + if driverInfo.DriversRegistry != "" { + tpl, err := pongo2.FromString(driverInfo.DriversRegistry) if err != nil { - return fmt.Errorf("Failed to parse template for driver %q: %w", driver, err) + return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) } out, err := tpl.Execute(ctx) if err != nil { - return fmt.Errorf("Failed to render template for driver %q: %w", driver, err) + return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) } driversRegistry = fmt.Sprintf("%s\n\n%s", driversRegistry, out) } // Update Windows SYSTEM registry - if info.SystemRegistry != "" { - tpl, err := pongo2.FromString(info.SystemRegistry) + if driverInfo.SystemRegistry != "" { + tpl, err := pongo2.FromString(driverInfo.SystemRegistry) if err != nil { - return fmt.Errorf("Failed to parse template for driver %q: %w", driver, err) + return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) } out, err := tpl.Execute(ctx) if err != nil { - return fmt.Errorf("Failed to render template for driver %q: %w", driver, err) + return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) } systemRegistry = fmt.Sprintf("%s\n\n%s", systemRegistry, out) } // Update Windows SOFTWARE registry - if info.SoftwareRegistry != "" { - tpl, err := pongo2.FromString(info.SoftwareRegistry) + if driverInfo.SoftwareRegistry != "" { + tpl, err := pongo2.FromString(driverInfo.SoftwareRegistry) if err != nil { - return fmt.Errorf("Failed to parse template for driver %q: %w", driver, err) + return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) } out, err := tpl.Execute(ctx) if err != nil { - return fmt.Errorf("Failed to render template for driver %q: %w", driver, err) + return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) } softwareRegistry = fmt.Sprintf("%s\n\n%s", softwareRegistry, out) @@ -605,21 +562,21 @@ func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error { logger.WithField("hivefile", "DRIVERS").Debug("Updating Windows registry") - err := shared.RunCommand(c.global.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(dirs["config"], "DRIVERS")) + err := shared.RunCommand(c.global.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(configDir, "DRIVERS")) if err != nil { return fmt.Errorf("Failed to edit Windows DRIVERS registry: %w", err) } logger.WithField("hivefile", "SYSTEM").Debug("Updating Windows registry") - err = shared.RunCommand(c.global.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(dirs["config"], "SYSTEM")) + err = shared.RunCommand(c.global.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(configDir, "SYSTEM")) if err != nil { return fmt.Errorf("Failed to edit Windows SYSTEM registry: %w", err) } logger.WithField("hivefile", "SOFTWARE").Debug("Updating Windows registry") - err = shared.RunCommand(c.global.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(dirs["config"], "SOFTWARE")) + err = shared.RunCommand(c.global.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(configDir, "SOFTWARE")) if err != nil { return fmt.Errorf("Failed to edit Windows SOFTWARE registry: %w", err) } @@ -627,6 +584,46 @@ func (c *cmdRepackWindows) injectDrivers(dirs map[string]string) error { return nil } +func (c *cmdRepackWindows) targetCopy(sourcePath, targetPath string) { + logger := c.global.logger.WithFields(logrus.Fields{ + "sourcePath": sourcePath, + "targetPath": targetPath, + }) + + logger.Debug("Copying file") + err := shared.Copy(sourcePath, targetPath) + if err != nil { + logger.Error(err) + } +} + +func parseDriverClassGuid(driverName, infPath string) (classGuid string, err error) { + // Retrieve the ClassGuid which is needed for the Windows registry entries. + file, err := os.Open(infPath) + if err != nil { + err = fmt.Errorf("Failed to open driver %s inf %s: %w", driverName, infPath, err) + return + } + + defer func() { + file.Close() + if classGuid == "" { + err = fmt.Errorf("Failed to parse driver %s classGuid %s", driverName, infPath) + } + }() + re := regexp.MustCompile(`(?i)^ClassGuid[ ]*=[ ]*(.+)$`) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + matches := re.FindStringSubmatch(scanner.Text()) + if len(matches) > 1 { + classGuid = strings.TrimSpace(matches[1]) + return + } + } + + return +} + // toHex is a pongo2 filter which converts the provided value to a hex value understood by the Windows registry. func toHex(in *pongo2.Value, param *pongo2.Value) (out *pongo2.Value, err *pongo2.Error) { dst := make([]byte, hex.EncodedLen(len(in.String()))) diff --git a/shared/util.go b/shared/util.go index 9860a474..3feff32c 100644 --- a/shared/util.go +++ b/shared/util.go @@ -54,10 +54,10 @@ func CaseInsensitive(s string) (pattern string) { b := s2[i : i+1] if a != b { pattern += "[" + a + b + "]" - } else if a != "/" { - pattern += "\\" + a + } else if strings.Contains("?*[]/", a) { + pattern += a } else { - pattern += "/" + pattern += "\\" + a } } return @@ -70,13 +70,14 @@ func FindFirstMatch(dir string, elem ...string) (found string, err error) { names = append(names, CaseInsensitive(name)) } - matches, err := filepath.Glob(filepath.Join(names...)) + pattern := filepath.Join(names...) + matches, err := filepath.Glob(pattern) if err != nil { return } if len(matches) == 0 { - err = fmt.Errorf("No match found") + err = fmt.Errorf("No match found %s", pattern) return }