Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent snooping around filesystem outside of bounds passed in Add(). #3

Merged
merged 1 commit into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ type Plugin interface {
Store(envelope *Envelope) error
Sha1ADG(map[string]string)
Sha256ADG(map[string]string)
SetAllowList([]string)
}
31 changes: 30 additions & 1 deletion directory_plugin.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package omnitrail

import (
"github.com/omnibor/omnibor-go"
"fmt"
"os"
"path/filepath"
"sort"
"strings"

"github.com/omnibor/omnibor-go"
)

type DirectoryPlugin struct {
Expand All @@ -13,6 +16,16 @@ type DirectoryPlugin struct {
directories map[string]bool
sha1adgs map[string]omnibor.ArtifactTree
sha256adgs map[string]omnibor.ArtifactTree
AllowList []string
}

func (plug *DirectoryPlugin) isAllowedDirectory(path string) bool {
for _, allowedPath := range plug.AllowList {
if strings.HasPrefix(path, allowedPath) {
return true
}
}
return false
}

func (plug *DirectoryPlugin) Sha1ADG(m map[string]string) {
Expand All @@ -28,20 +41,32 @@ func (plug *DirectoryPlugin) Sha256ADG(m map[string]string) {
}

func (plug *DirectoryPlugin) Add(path string) error {

// if this is a broken symlink, ignore
fileInfo, err := os.Lstat(path)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
return err
}
if fileInfo.Mode()&os.ModeSymlink != 0 {
// path is a symlink
targetPath, err := os.Readlink(path)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
return err
}
if !filepath.IsAbs(targetPath) {
targetPath = filepath.Join(filepath.Dir(path), targetPath)
}
if !plug.isAllowedDirectory(targetPath) {
return fmt.Errorf("path %s is not in the allow list", path)
}
if _, err := os.Stat(targetPath); err != nil {
return nil
}
Expand Down Expand Up @@ -168,6 +193,10 @@ func (plug *DirectoryPlugin) addKeysToTree(keys []string, tree map[string]omnibo
return nil
}

func (plug *DirectoryPlugin) SetAllowList(allowList []string) {
plug.AllowList = allowList
}

func NewDirectoryPlugin() Plugin {
algorithms := []string{"gitoid:sha1", "gitoid:sha256"}
sort.Strings(algorithms)
Expand Down
18 changes: 14 additions & 4 deletions factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,27 @@ import (
)

type factoryImpl struct {
Options *Options
envelope *Envelope
Plugins []Plugin
Options *Options
envelope *Envelope
Plugins []Plugin
AllowList []string
}

func (factory *factoryImpl) Add(originalPath string) error {
originalPath, err := filepath.Abs(originalPath)
// Convert the path to an absolute path
absPath, err := filepath.Abs(originalPath)
if err != nil {
return err
}

// Add the absolute path to the allow list
factory.AllowList = append(factory.AllowList, absPath)

// For each plugin, add the allow list
for _, plugin := range factory.Plugins {
plugin.SetAllowList(factory.AllowList)
}

// check if path already exists in the envelope, if so, return
if _, ok := factory.envelope.Mapping[originalPath]; ok {
return nil
Expand Down
28 changes: 27 additions & 1 deletion file_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ import (
type FilePlugin struct {
algorithms []string
files map[string]map[string]string
AllowList []string
}

func (plug *FilePlugin) isAllowedDirectory(path string) bool {
for _, allowedPath := range plug.AllowList {
if strings.HasPrefix(path, allowedPath) {
return true
}
}
return false
}

func (plug *FilePlugin) Sha1ADG(m map[string]string) {
Expand All @@ -38,6 +48,10 @@ func (plug *FilePlugin) Sha256ADG(m map[string]string) {
}
}

func (plug *FilePlugin) SetAllowList(allowList []string) {
plug.AllowList = allowList
}

func NewFilePlugin() Plugin {
algorithms := []string{"sha1", "sha256", "gitoid:sha1", "gitoid:sha256"}
sort.Strings(algorithms)
Expand All @@ -52,20 +66,32 @@ func NewFilePlugin() Plugin {
}

func (plug *FilePlugin) Add(filePath string) error {

// ignore broken symlink
localFileInfo, err := os.Lstat(filePath)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
return err
}
if localFileInfo.Mode()&os.ModeSymlink != 0 {
targetPath, err := os.Readlink(filePath)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
fmt.Println("returning err: ", err)
return err
}
if !filepath.IsAbs(targetPath) {

targetPath = filepath.Join(filepath.Dir(filePath), targetPath)
}
if !plug.isAllowedDirectory(targetPath) {
return fmt.Errorf("path %s is not in the allow list", filePath)
}
if _, err = os.Stat(targetPath); err != nil {
return nil
}
Expand Down
7 changes: 6 additions & 1 deletion omnitrail.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ func NewTrail(option ...Option) Factory {
if o.Sha1Enabled == false && o.Sha256Enabled == false {
o.Sha1Enabled = true
}
allowList := []string{}
plugins := make([]Plugin, 0)
plugins = append(plugins, NewFilePlugin())
plugins = append(plugins, NewDirectoryPlugin())
plugins = append(plugins, NewPosixPlugin())
return &factoryImpl{

factory := &factoryImpl{
Options: o,
Plugins: plugins,
envelope: &Envelope{
Expand All @@ -26,7 +28,10 @@ func NewTrail(option ...Option) Factory {
},
Mapping: make(map[string]*Element),
},
AllowList: allowList,
}

return factory
}

func FormatADGString(mapping Factory) string {
Expand Down
90 changes: 78 additions & 12 deletions omnitrail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package omnitrail

import (
"encoding/json"
"github.com/stretchr/testify/assert"
"fmt"
"os"
"os/user"
"reflect"
"sort"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestEmpty(t *testing.T) {
Expand All @@ -17,36 +20,88 @@ func TestEmpty(t *testing.T) {
assert.NoError(t, err)
}
name := "empty"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestEmpty failed: %v", err)
}
}

func TestOneFiles(t *testing.T) {
name := "one-file"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestOneFiles failed: %v", err)
}
}

func TestTwoFiles(t *testing.T) {
name := "two-files"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestTwoFiles failed: %v", err)
}
}

func TestDeepStructure(t *testing.T) {
name := "deep"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestDeepStructure failed: %v", err)
}
}

func TestSymlinkGood(t *testing.T) {
name := "symlink-good"
if err := testAdd(t, name); err != nil {
t.Fatalf("TestSymlinkGood failed: %v", err)
}
}

func TestSymlinkBroken(t *testing.T) {
name := "symlink-broken"
if err := testAdd(t, name); err != nil {
t.Fatalf("should ignore a bad symlink: %v", err)
}
}

func TestSymlinkOutOfBounds(t *testing.T) {
name := "symlink-out-of-bounds"
err := os.WriteFile("/tmp/omnitrail-well-known-file", []byte("hello"), 0644)
if err != nil {
t.Fatalf("unable to write temporary file: %v", err)
}
defer os.Remove("/tmp/omnitrail-well-known-file")
err = testAdd(t, name)
if !strings.Contains(err.Error(), "not in the allow list") {
t.Fatalf("unexpected error: %v", err)

}
if err == nil {
t.Fatalf("TestSymlinkOutOfBounds failed: should report a symlik out of bounds")
}
}

func testAdd(t *testing.T, name string) {
func testAdd(t *testing.T, name string) error {
mapping := NewTrail()

err := mapping.Add("./test/" + name)
assert.NoError(t, err)
if err != nil {
return err
}

// WARNING: these are only for generating new test cases easily
// file, err := json.MarshalIndent(mapping.Envelope(), "", " ")
// os.WriteFile("./test/"+name+".json", file, 0644)
// res := FormatADGString(mapping)
// os.WriteFile("./test/"+name+".adg", []byte(res), 0644)
// END WARNING

expectedBytes, err := os.ReadFile("./test/" + name + ".json")
assert.NoError(t, err)
if err != nil {
return err
}

var expectedEnvelope Envelope
err = json.Unmarshal(expectedBytes, &expectedEnvelope)
assert.NoError(t, err)
if err != nil {
return err
}

shortestExpectedKey := getShortestKey(&expectedEnvelope)
shortestActualKey := getShortestKey(mapping.Envelope())
Expand All @@ -59,7 +114,9 @@ func testAdd(t *testing.T, name string) {

// get current username
currentUser, err := user.Current()
assert.NoError(t, err)
if err != nil {
return err
}
uid := currentUser.Uid
gid := currentUser.Gid

Expand All @@ -70,11 +127,20 @@ func testAdd(t *testing.T, name string) {

assert.Equal(t, &expectedEnvelope, mapping.Envelope())

if !reflect.DeepEqual(&expectedEnvelope, mapping.Envelope()) {
return fmt.Errorf("expected envelope does not match actual envelope")
}

res := FormatADGString(mapping)

expectedBytes, err = os.ReadFile("./test/" + name + ".adg")
assert.NoError(t, err)
assert.Equal(t, string(expectedBytes), res)
if err != nil {
return err
}
if string(expectedBytes) != res {
return fmt.Errorf("expected ADG string does not match actual ADG string")
}
return nil
}

func getShortestKey(expectedEnvelope *Envelope) string {
Expand Down
Loading
Loading