Skip to content

Commit

Permalink
Merge pull request #884 from gibmat/refactor-wim-driver-injection
Browse files Browse the repository at this point in the history
Refactor code for injecting drivers into a library for easier shared use
  • Loading branch information
stgraber authored Oct 28, 2024
2 parents bf4232a + 8b651a1 commit 2c5ab32
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 243 deletions.
251 changes: 8 additions & 243 deletions distrobuilder/main_repack-windows.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"context"
"encoding/hex"
"errors"
Expand All @@ -11,7 +10,6 @@ import (
"os/exec"
"path/filepath"
"slices"
"strconv"
"strings"

"github.com/flosch/pongo2/v4"
Expand Down Expand Up @@ -264,12 +262,14 @@ func (c *cmdRepackWindows) run(cmd *cobra.Command, args []string, overlayDir str
}
}

bootWimInfo, err := c.getWimInfo(bootWim)
repackUtil := windows.NewRepackUtil(c.global.flagCacheDir, c.global.ctx, c.global.logger)

bootWimInfo, err := repackUtil.GetWimInfo(bootWim)
if err != nil {
return fmt.Errorf("Failed to get boot wim info: %w", err)
}

installWimInfo, err := c.getWimInfo(installWim)
installWimInfo, err := repackUtil.GetWimInfo(installWim)
if err != nil {
return fmt.Errorf("Failed to get install wim info: %w", err)
}
Expand All @@ -290,14 +290,16 @@ func (c *cmdRepackWindows) run(cmd *cobra.Command, args []string, overlayDir str
return errors.New("Failed to detect Windows architecture. Please provide the architecture using the --windows-arch flag")
}

repackUtil.SetWindowsVersionArchitecture(c.flagWindowsVersion, c.flagWindowsArchitecture)

// This injects the drivers into the installation process
err = c.modifyWim(bootWim, bootWimInfo)
err = repackUtil.InjectDriversIntoWim(bootWim, bootWimInfo, filepath.Join(c.global.flagCacheDir, "drivers"))
if err != nil {
return fmt.Errorf("Failed to modify wim %q: %w", filepath.Base(bootWim), err)
}

// This injects the drivers into the final OS
err = c.modifyWim(installWim, installWimInfo)
err = repackUtil.InjectDriversIntoWim(installWim, installWimInfo, filepath.Join(c.global.flagCacheDir, "drivers"))
if err != nil {
return fmt.Errorf("Failed to modify wim %q: %w", filepath.Base(installWim), err)
}
Expand Down Expand Up @@ -355,85 +357,6 @@ func (c *cmdRepackWindows) run(cmd *cobra.Command, args []string, overlayDir str
return nil
}

func (c *cmdRepackWindows) getWimInfo(wimFile string) (info windows.WimInfo, err error) {
wimName := filepath.Base(wimFile)
var buf bytes.Buffer
err = shared.RunCommand(c.global.ctx, nil, &buf, "wimlib-imagex", "info", wimFile)
if err != nil {
err = fmt.Errorf("Failed to retrieve wim %q information: %w", wimName, err)
return
}

info, err = windows.ParseWimInfo(&buf)
if err != nil {
err = fmt.Errorf("Failed to parse wim info %s: %w", wimFile, err)
return
}

return
}

func (c *cmdRepackWindows) modifyWim(wimFile string, info windows.WimInfo) (err error) {
wimName := filepath.Base(wimFile)
// Injects the drivers
for idx := 1; idx <= info.ImageCount(); idx++ {
name := info.Name(idx)
err = c.modifyWimIndex(wimFile, idx, name)
if err != nil {
return fmt.Errorf("Failed to modify index %d=%s of %q: %w", idx, name, wimName, err)
}
}
return
}

func (c *cmdRepackWindows) modifyWimIndex(wimFile string, index int, name string) error {
wimIndex := strconv.Itoa(index)
wimPath := filepath.Join(c.global.flagCacheDir, "wim", wimIndex)
wimName := filepath.Base(wimFile)
logger := c.global.logger.WithFields(logrus.Fields{"wim": wimName, "idx": wimIndex + ":" + name})
if !incus.PathExists(wimPath) {
err := os.MkdirAll(wimPath, 0755)
if err != nil {
return fmt.Errorf("Failed to create directory %q: %w", wimPath, err)
}
}

success := false
logger.Info("Mounting")
// Mount wim file
err := shared.RunCommand(c.global.ctx, nil, nil, "wimlib-imagex", "mountrw", wimFile, wimIndex, wimPath, "--allow-other")
if err != nil {
return fmt.Errorf("Failed to mount %q: %w", wimName, err)
}

defer func() {
if !success {
_ = shared.RunCommand(c.global.ctx, nil, nil, "wimlib-imagex", "unmount", wimPath)
}
}()

dirs, err := c.getWindowsDirectories(wimPath)
if err != nil {
return fmt.Errorf("Failed to get required windows directories: %w", err)
}

logger.Info("Modifying")
// Create registry entries and copy files
err = c.injectDrivers(dirs["inf"], dirs["drivers"], dirs["filerepository"], dirs["config"])
if err != nil {
return fmt.Errorf("Failed to inject drivers: %w", err)
}

logger.Info("Unmounting")
err = shared.RunCommand(c.global.ctx, nil, nil, "wimlib-imagex", "unmount", wimPath, "--commit")
if err != nil {
return fmt.Errorf("Failed to unmount WIM image %q: %w", wimName, err)
}

success = true
return nil
}

func (c *cmdRepackWindows) checkDependencies() error {
dependencies := []string{"hivexregedit", "rsync", "wimlib-imagex"}

Expand All @@ -453,164 +376,6 @@ func (c *cmdRepackWindows) checkDependencies() error {
return nil
}

func (c *cmdRepackWindows) getWindowsDirectories(wimPath string) (dirs map[string]string, err error) {
dirs = map[string]string{}
dirs["inf"], err = shared.FindFirstMatch(wimPath, "windows", "inf")
if err != nil {
return nil, fmt.Errorf("Failed to determine windows/inf path: %w", err)
}

dirs["config"], err = shared.FindFirstMatch(wimPath, "windows", "system32", "config")
if err != nil {
return nil, fmt.Errorf("Failed to determine windows/system32/config path: %w", err)
}

dirs["drivers"], err = shared.FindFirstMatch(wimPath, "windows", "system32", "drivers")
if err != nil {
return nil, fmt.Errorf("Failed to determine windows/system32/drivers path: %w", err)
}

dirs["filerepository"], err = shared.FindFirstMatch(wimPath, "windows", "system32", "driverstore", "filerepository")
if err != nil {
return nil, fmt.Errorf("Failed to determine windows/system32/driverstore/filerepository path: %w", err)
}

return
}

func (c *cmdRepackWindows) injectDrivers(infDir, driversDir, filerepositoryDir, configDir string) error {
logger := c.global.logger

driverPath := filepath.Join(c.global.flagCacheDir, "drivers")
i := 0

driversRegistry := "Windows Registry Editor Version 5.00"
systemRegistry := "Windows Registry Editor Version 5.00"
softwareRegistry := "Windows Registry Editor Version 5.00"
for driverName, driverInfo := range windows.Drivers {
logger.WithField("driver", driverName).Debug("Injecting driver")
infFilename := fmt.Sprintf("oem-virtio-incus%d.inf", i)
sourceDir := filepath.Join(driverPath, driverName, c.flagWindowsVersion, c.flagWindowsArchitecture)
targetBaseDir := filepath.Join(filerepositoryDir, driverInfo.PackageName)
if !incus.PathExists(targetBaseDir) {
err := os.MkdirAll(targetBaseDir, 0755)
if err != nil {
logger.Error(err)
return err
}
}

for ext, dir := range map[string]string{"inf": infDir, "cat": driversDir, "dll": driversDir, "exe": driversDir, "sys": driversDir} {
sourceMatches, err := shared.FindAllMatches(sourceDir, fmt.Sprintf("*.%s", ext))
if err != nil {
logger.Debugf("failed to find first match %q %q", driverName, ext)
continue
}

for _, sourcePath := range sourceMatches {
targetName := filepath.Base(sourcePath)
targetPath := filepath.Join(targetBaseDir, targetName)
if err = shared.Copy(sourcePath, targetPath); err != nil {
return err
}

if ext == "cat" || ext == "exe" {
continue
} else if ext == "inf" {
targetName = infFilename
}

targetPath = filepath.Join(dir, targetName)
if err = shared.Copy(sourcePath, targetPath); err != nil {
return err
}
}
}

classGuid, err := windows.ParseDriverClassGuid(driverName, filepath.Join(infDir, infFilename))
if err != nil {
return err
}

ctx := pongo2.Context{
"infFile": infFilename,
"packageName": driverInfo.PackageName,
"driverName": driverName,
"classGuid": classGuid,
}

// Update Windows DRIVERS registry
if driverInfo.DriversRegistry != "" {
tpl, err := pongo2.FromString(driverInfo.DriversRegistry)
if err != nil {
return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err)
}

out, err := tpl.Execute(ctx)
if err != nil {
return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err)
}

driversRegistry = fmt.Sprintf("%s\n\n%s", driversRegistry, out)
}

// Update Windows SYSTEM registry
if driverInfo.SystemRegistry != "" {
tpl, err := pongo2.FromString(driverInfo.SystemRegistry)
if err != nil {
return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err)
}

out, err := tpl.Execute(ctx)
if err != nil {
return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err)
}

systemRegistry = fmt.Sprintf("%s\n\n%s", systemRegistry, out)
}

// Update Windows SOFTWARE registry
if driverInfo.SoftwareRegistry != "" {
tpl, err := pongo2.FromString(driverInfo.SoftwareRegistry)
if err != nil {
return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err)
}

out, err := tpl.Execute(ctx)
if err != nil {
return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err)
}

softwareRegistry = fmt.Sprintf("%s\n\n%s", softwareRegistry, out)
}

i++
}

logger.WithField("hivefile", "DRIVERS").Debug("Updating Windows registry")

err := shared.RunCommand(c.global.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(configDir, "DRIVERS"))
if err != nil {
return fmt.Errorf("Failed to edit Windows DRIVERS registry: %w", err)
}

logger.WithField("hivefile", "SYSTEM").Debug("Updating Windows registry")

err = shared.RunCommand(c.global.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(configDir, "SYSTEM"))
if err != nil {
return fmt.Errorf("Failed to edit Windows SYSTEM registry: %w", err)
}

logger.WithField("hivefile", "SOFTWARE").Debug("Updating Windows registry")

err = shared.RunCommand(c.global.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(configDir, "SOFTWARE"))
if err != nil {
return fmt.Errorf("Failed to edit Windows SOFTWARE registry: %w", err)
}

return nil
}

// toHex is a pongo2 filter which converts the provided value to a hex value understood by the Windows registry.
func toHex(in *pongo2.Value, param *pongo2.Value) (out *pongo2.Value, err *pongo2.Error) {
dst := make([]byte, hex.EncodedLen(len(in.String())))
Expand Down
Loading

0 comments on commit 2c5ab32

Please sign in to comment.