Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code cleanup #32

Merged
merged 5 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 29 additions & 45 deletions core/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package albius
import (
"encoding/json"
"fmt"
"os/exec"
"strconv"
)

Expand All @@ -19,10 +18,12 @@ type Sector struct {
}

type Disk struct {
Path, Size, Model, Transport string
Label DiskLabel
LogicalSectorSize, PhysicalSectorSize, MaxPartitions int
Partitions []Partition
Path, Size, Model, Transport string
Label DiskLabel
LogicalSectorSize int `json:"logical-sector-size"`
PhysicalSectorSize int `json:"physical-sector-size"`
MaxPartitions int `json:"max-partitions"`
Partitions []Partition
}

func (disk *Disk) AvailableSectors() ([]Sector, error) {
Expand All @@ -31,14 +32,14 @@ func (disk *Disk) AvailableSectors() ([]Sector, error) {
for i, part := range disk.Partitions {
endInt, err := strconv.Atoi(part.End[:len(part.End)-3])
if err != nil {
return []Sector{}, fmt.Errorf("Failed to retrieve end position of partition: %s", err)
return []Sector{}, fmt.Errorf("failed to retrieve end position of partition: %s", err)
}

if i < len(disk.Partitions)-1 {
nextStart := disk.Partitions[i+1].Start
nextStartInt, err := strconv.Atoi(nextStart[:len(nextStart)-3])
if err != nil {
return []Sector{}, fmt.Errorf("Failed to retrieve start position of next partition: %s", err)
return []Sector{}, fmt.Errorf("failed to retrieve start position of next partition: %s", err)
}

if endInt != nextStartInt {
Expand All @@ -51,11 +52,11 @@ func (disk *Disk) AvailableSectors() ([]Sector, error) {
lastPartitionEndStr := disk.Partitions[len(disk.Partitions)-1].End
lastPartitionEnd, err := strconv.Atoi(lastPartitionEndStr[:len(lastPartitionEndStr)-3])
if err != nil {
return []Sector{}, fmt.Errorf("Failed to retrieve end position of last partition: %s", err)
return []Sector{}, fmt.Errorf("failed to retrieve end position of last partition: %s", err)
}
diskEnd, err := strconv.Atoi(disk.Size[:len(disk.Size)-3])
if err != nil {
return []Sector{}, fmt.Errorf("Failed to retrieve disk end")
return []Sector{}, fmt.Errorf("failed to retrieve disk end")
}
if lastPartitionEnd < diskEnd {
sectors = append(sectors, Sector{lastPartitionEnd, diskEnd})
Expand All @@ -64,45 +65,28 @@ func (disk *Disk) AvailableSectors() ([]Sector, error) {
return sectors, nil
}

type LocateDiskOutput struct {
Disk Disk
}

func LocateDisk(diskname string) (*Disk, error) {
findPartitionCmd := "parted -sj %s unit MiB print | sed -r 's/^(\\s*)\"(.)/\\1\"\\U\\2/g' | sed -r 's/(\\S)-(\\S)/\\1\\U\\2/g'"
cmd := exec.Command("sh", "-c", fmt.Sprintf(findPartitionCmd, diskname))
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("Failed to list disk: %s", err)
findPartitionCmd := "parted -sj %s unit MiB print"
output, err := OutputCommand(fmt.Sprintf(findPartitionCmd, diskname))
// If disk is unformatted, parted returns the expected json but also throws an error.
// We can assume we have all the necessary information if output isn't empty.
if err != nil && output == "" {
return nil, fmt.Errorf("failed to list disk: %s", err)
}

var device *Disk
var decoded *LocateDiskOutput
err = json.Unmarshal(output, &decoded)
if err != nil {
// Try a different approach suitable for when the disk is unformatted
var decodedMap map[string]map[string]interface{}
err = json.Unmarshal(output, &decodedMap)
device := new(Disk)
for k, v := range decodedMap["Disk"] {
err := setField(device, k, v)
if err != nil {
return nil, fmt.Errorf("Failed to decode parted output: %s", err)
}
}
} else {
device = &decoded.Disk
var decoded struct {
Disk Disk
}

if device == nil {
return nil, fmt.Errorf("Could not find device %s", diskname)
err = json.Unmarshal([]byte(output), &decoded)
if err != nil {
return nil, fmt.Errorf("could not find device %s", diskname)
}

for i := 0; i < len(device.Partitions); i++ {
device.Partitions[i].FillPath(device.Path)
for i := 0; i < len(decoded.Disk.Partitions); i++ {
decoded.Disk.Partitions[i].FillPath(decoded.Disk.Path)
}

return device, nil
return &decoded.Disk, nil
}

func (disk *Disk) Update() error {
Expand All @@ -127,14 +111,14 @@ func (disk *Disk) LabelDisk(label DiskLabel) error {
labelDiskCmd := "parted -s %s mklabel %s"

for _, part := range disk.Partitions {
if err := part.UmountPartition(); err != nil {
return fmt.Errorf("Failed to unmount partition %s: %s", part.Path, err)
if err := part.UnmountPartition(); err != nil {
return fmt.Errorf("failed to unmount partition %s: %s", part.Path, err)
}
}

err := RunCommand(fmt.Sprintf(labelDiskCmd, disk.Path, label))
if err != nil {
return fmt.Errorf("Failed to label disk: %s", err)
return fmt.Errorf("failed to label disk: %s", err)
}

return nil
Expand Down Expand Up @@ -170,13 +154,13 @@ func (target *Disk) NewPartition(name string, fsType PartitionFs, start, end int

err := RunCommand(fmt.Sprintf(createPartCmd, target.Path, partType, partName, fsType, start, endStr))
if err != nil {
return nil, fmt.Errorf("Failed to create partition: %s", err)
return nil, fmt.Errorf("failed to create partition: %s", err)
}

// Update partition list because we made changes to the disk
err = target.Update()
if err != nil {
return nil, fmt.Errorf("Failed to create partition: %s", err)
return nil, fmt.Errorf("failed to create partition: %s", err)
}

newPartition := &target.Partitions[len(target.Partitions)-1]
Expand Down
84 changes: 28 additions & 56 deletions core/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package albius
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"

Expand All @@ -20,12 +19,9 @@ func Unsquashfs(filesystem, destination string, force bool) error {
forceFlag = ""
}

cmd := exec.Command("sh", "-c", fmt.Sprintf(unsquashfsCmd, forceFlag, destination, filesystem))
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
err := RunCommand(fmt.Sprintf(unsquashfsCmd, forceFlag, destination, filesystem))
if err != nil {
return fmt.Errorf("Failed to run unsquashfs: %s", err)
return fmt.Errorf("failed to run unsquashfs: %s", err)
}

return nil
Expand All @@ -47,14 +43,14 @@ func MakeFs(part *Partition) error {
makefsCmd := "mkswap -f %s"
err = RunCommand(fmt.Sprintf(makefsCmd, part.Path))
case HFS, HFS_PLUS, UDF:
return fmt.Errorf("Unsupported filesystem: %s", part.Filesystem)
return fmt.Errorf("unsupported filesystem: %s", part.Filesystem)
default:
makefsCmd := "mkfs.%s -f %s"
err = RunCommand(fmt.Sprintf(makefsCmd, part.Filesystem, part.Path))
}

if err != nil {
return fmt.Errorf("Failed to make %s filesystem for %s: %s", part.Filesystem, part.Path, err)
return fmt.Errorf("failed to make %s filesystem for %s: %s", part.Filesystem, part.Path, err)
}

return nil
Expand Down Expand Up @@ -94,49 +90,25 @@ func GenFstab(targetRoot string, entries [][]string) error {

func UpdateInitramfs(root string) error {
// Setup mountpoints
if err := RunCommand(fmt.Sprintf("mount --bind /dev %s/dev", root)); err != nil {
return fmt.Errorf("Error mounting /dev to chroot: %s", err)
}
if err := RunCommand(fmt.Sprintf("mount --bind /dev/pts %s/dev/pts", root)); err != nil {
return fmt.Errorf("Error mounting /dev/pts to chroot: %s", err)
}
if err := RunCommand(fmt.Sprintf("mount --bind /proc %s/proc", root)); err != nil {
return fmt.Errorf("Error mounting /proc to chroot: %s", err)
}
if err := RunCommand(fmt.Sprintf("mount --bind /sys %s/sys", root)); err != nil {
return fmt.Errorf("Error mounting /sys to chroot: %s", err)
mountOrder := []string{"/dev", "/dev/pts", "/proc", "/sys"}
for _, mount := range mountOrder {
if err := RunCommand(fmt.Sprintf("mount --bind %s %s%s", mount, root, mount)); err != nil {
return fmt.Errorf("error mounting %s to chroot: %s", mount, err)
}
}

updInitramfsCmd := "update-initramfs -c -k all"

err := RunInChroot(root, updInitramfsCmd)
if err != nil {
return fmt.Errorf("Failed to run update-initramfs command: %s", err)
}

if err := RunCommand(fmt.Sprintf("umount %s/dev/pts", root)); err != nil {
return fmt.Errorf("Error unmounting /dev/pts fron chroot: %s", err)
}
if err := RunCommand(fmt.Sprintf("umount %s/dev", root)); err != nil {
return fmt.Errorf("Error unmounting /dev from chroot: %s", err)
return fmt.Errorf("failed to run update-initramfs command: %s", err)
}
if err := RunCommand(fmt.Sprintf("umount %s/proc", root)); err != nil {
return fmt.Errorf("Error unmounting /proc from chroot: %s", err)
}
if err := RunCommand(fmt.Sprintf("umount %s/sys", root)); err != nil {
return fmt.Errorf("Error unmounting /sys from chroot: %s", err)
}

return nil
}

func RunInChroot(root, command string) error {
cmd := exec.Command("chroot", root, "sh", "-c", command)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return err
// Cleanup mountpoints
unmountOrder := []string{"/dev/pts", "/dev", "/proc", "/sys"}
for _, mount := range unmountOrder {
if err := RunCommand(fmt.Sprintf("umount %s%s", root, mount)); err != nil {
return fmt.Errorf("error unmounting %s fron chroot: %s", mount, err)
}
}

return nil
Expand All @@ -145,37 +117,37 @@ func RunInChroot(root, command string) error {
func OCISetup(imageSource, storagePath, destination string, verbose bool) error {
pmt, err := prometheus.NewPrometheus(filepath.Join(storagePath, "storage"), "overlay", 0)
if err != nil {
return fmt.Errorf("Failed to create Prometheus instance: %s", err)
return fmt.Errorf("failed to create Prometheus instance: %s", err)
}

// Create tmp directory in root's /var to store podman's temp files, since /var/tmp in
// the ISO is tied to the user's RAM and can run out of space pretty quickly
storageTmpDir := filepath.Join(storagePath, "tmp")
err = os.Mkdir(storageTmpDir, 0644)
if err != nil {
return fmt.Errorf("Failed to create storage tmp dir: %s", err)
return fmt.Errorf("failed to create storage tmp dir: %s", err)
}
err = RunCommand(fmt.Sprintf("mount --bind %s %s", storageTmpDir, "/var/tmp"))
if err != nil {
return fmt.Errorf("Failed to mount bind storage tmp dir: %s", err)
return fmt.Errorf("failed to mount bind storage tmp dir: %s", err)
}

storedImageName := strings.ReplaceAll(imageSource, "/", "-")
manifest, err := pmt.PullImage(imageSource, storedImageName)
if err != nil {
return fmt.Errorf("Failed to pull OCI image: %s", err)
return fmt.Errorf("failed to pull OCI image: %s", err)
}

fmt.Printf("Image pulled with digest %s\n", manifest.Config.Digest)

image, err := pmt.GetImageByDigest(manifest.Config.Digest)
if err != nil {
return fmt.Errorf("Failed to get image from digest: %s", err)
return fmt.Errorf("failed to get image from digest: %s", err)
}

mountPoint, err := pmt.MountImage(image.TopLayer)
if err != nil {
return fmt.Errorf("Failed to mount image at %s: %s", image.TopLayer, err)
return fmt.Errorf("failed to mount image at %s: %s", image.TopLayer, err)
}

fmt.Printf("Image mounted at %s\n", mountPoint)
Expand All @@ -191,35 +163,35 @@ func OCISetup(imageSource, storagePath, destination string, verbose bool) error
}
err = RunCommand(fmt.Sprintf("rsync -a%sxHAX --numeric-ids %s/ %s/", verboseFlag, mountPoint, destination))
if err != nil {
return fmt.Errorf("Failed to sync image contents to %s: %s", destination, err)
return fmt.Errorf("failed to sync image contents to %s: %s", destination, err)
}

// Remove storage from destination
err = RunCommand(fmt.Sprintf("umount -l %s/storage/graph/overlay", storagePath))
if err != nil {
return fmt.Errorf("Failed to unmount image: %s", err)
return fmt.Errorf("failed to unmount image: %s", err)
}

// Unmount tmp storage directory
err = RunCommand("umount -l /var/tmp")
if err != nil {
return fmt.Errorf("Failed to unmount storage tmp dir: %s", err)
return fmt.Errorf("failed to unmount storage tmp dir: %s", err)
}
entries, err := os.ReadDir(storageTmpDir)
if err != nil {
return fmt.Errorf("Failed to read from storage tmp dir: %s", err)
return fmt.Errorf("failed to read from storage tmp dir: %s", err)
}
for _, entry := range entries {
err = os.RemoveAll(filepath.Join(storageTmpDir, entry.Name()))
if err != nil {
return fmt.Errorf("Failed to remove %s from storage tmp dir: %s", entry.Name(), err)
return fmt.Errorf("failed to remove %s from storage tmp dir: %s", entry.Name(), err)
}
}

// Store the digest in destination as it may be used by the update manager
err = os.WriteFile(filepath.Join(destination, ".oci_digest"), []byte(manifest.Config.Digest), 0644)
if err != nil {
return fmt.Errorf("Failed to save digest in %s: %s", destination, err)
return fmt.Errorf("failed to save digest in %s: %s", destination, err)
}

return nil
Expand Down
Loading