Procházet zdrojové kódy

Merge branch 'develop' into fix-spot-prices-with-default-config

Ajay Tripathy před 2 roky
rodič
revize
8b38c12098

+ 11 - 6
pkg/storage/filestorage.go

@@ -1,6 +1,7 @@
 package storage
 
 import (
+	"fmt"
 	gofs "io/fs"
 	"os"
 	gopath "path"
@@ -90,15 +91,17 @@ func (fs *FileStorage) ListDirectories(path string) ([]*StorageInfo, error) {
 
 // Read uses the relative path of the storage combined with the provided path to
 // read the contents.
+//
+// It takes advantage of flock() based locking to improve safety.
 func (fs *FileStorage) Read(path string) ([]byte, error) {
 	f := gopath.Join(fs.baseDir, path)
 
-	b, err := os.ReadFile(f)
+	b, err := fileutil.ReadLocked(f)
 	if err != nil {
-		if os.IsNotExist(err) {
+		if errors.Is(err, os.ErrNotExist) {
 			return nil, DoesNotExistError
 		}
-		return nil, errors.Wrap(err, "Failed to read file")
+		return nil, fmt.Errorf("reading %s: %w", f, err)
 	}
 
 	return b, nil
@@ -106,14 +109,16 @@ func (fs *FileStorage) Read(path string) ([]byte, error) {
 
 // Write uses the relative path of the storage combined with the provided path
 // to write a new file or overwrite an existing file.
+//
+// It takes advantage of flock() based locking to improve safety.
 func (fs *FileStorage) Write(path string, data []byte) error {
 	f, err := fs.prepare(path)
 	if err != nil {
 		return errors.Wrap(err, "Failed to prepare path")
 	}
-	err = os.WriteFile(f, data, os.ModePerm)
-	if err != nil {
-		return errors.Wrap(err, "Failed to write file")
+
+	if _, err := fileutil.WriteLocked(f, data); err != nil {
+		return fmt.Errorf("writing %s: %w", f, err)
 	}
 
 	return nil

+ 3 - 1
pkg/util/fileutil/fileutil.go

@@ -1,6 +1,8 @@
 package fileutil
 
-import "os"
+import (
+	"os"
+)
 
 // File exists has three different return cases that should be handled:
 //  1. File exists and is not a directory (true, nil)

+ 108 - 0
pkg/util/fileutil/locks_test.go

@@ -0,0 +1,108 @@
+package fileutil
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+// Make sure read works on file created without locking logic
+func TestReadLocked(t *testing.T) {
+	toWrite := "hello world"
+	dir := t.TempDir()
+	filename := filepath.Join(dir, "test.txt")
+	if err := os.WriteFile(filename, []byte(toWrite), 0600); err != nil {
+		t.Fatalf("failed to write test data: %s", err)
+	}
+
+	read, err := ReadLocked(filename)
+	if err != nil {
+		t.Fatalf("Failed to read: %s", err)
+	}
+	sread := string(read)
+
+	if toWrite != sread {
+		t.Errorf("Expected read data to be '%s' but was '%s'", toWrite, sread)
+	}
+}
+
+// Does not test concurrency, just makes sure the basic read write functionality
+// works
+func TestRWLocked(t *testing.T) {
+	toWrite := "hello world"
+	dir := t.TempDir()
+	filename := filepath.Join(dir, "test.txt")
+
+	if _, err := WriteLocked(filename, []byte(toWrite)); err != nil {
+		t.Fatalf("Failed to write: %s", err)
+	}
+
+	read, err := ReadLocked(filename)
+	if err != nil {
+		t.Fatalf("Failed to read: %s", err)
+	}
+	sread := string(read)
+
+	if toWrite != sread {
+		t.Errorf("Expected read data to be '%s' but was '%s'", toWrite, sread)
+	}
+}
+
+func TestReadLockedFDMiddlePosition(t *testing.T) {
+	toWrite := "hello world"
+	dir := t.TempDir()
+	filename := filepath.Join(dir, "test.txt")
+	if err := os.WriteFile(filename, []byte(toWrite), 0600); err != nil {
+		t.Fatalf("failed to write test data: %s", err)
+	}
+
+	f, err := os.Open(filename)
+	if err != nil {
+		t.Fatalf("opening after write: %s", err)
+	}
+	if _, err := f.Seek(3, 0); err != nil {
+		t.Fatalf("seeking: %s", err)
+	}
+
+	read, err := ReadLockedFD(f)
+	if err != nil {
+		t.Fatalf("Failed to read: %s", err)
+	}
+	sread := string(read)
+
+	if toWrite != sread {
+		t.Errorf("Expected read data to be '%s' but was '%s'", toWrite, sread)
+	}
+}
+
+func TestWriteLockedFDMiddlePosition(t *testing.T) {
+	toWrite := "hello world"
+	toWriteOver := "goodbye land"
+	dir := t.TempDir()
+	filename := filepath.Join(dir, "test.txt")
+	if err := os.WriteFile(filename, []byte(toWrite), 0600); err != nil {
+		t.Fatalf("failed to write test data: %s", err)
+	}
+
+	f, err := os.OpenFile(filename, os.O_RDWR, 0600)
+	if err != nil {
+		t.Fatalf("opening after write: %s", err)
+	}
+	if _, err := f.Seek(3, 0); err != nil {
+		t.Fatalf("seeking: %s", err)
+	}
+
+	if _, err := WriteLockedFD(f, []byte(toWriteOver)); err != nil {
+		t.Fatalf("writing over: %s", err)
+	}
+
+	read, err := ReadLockedFD(f)
+	if err != nil {
+		t.Fatalf("Failed to read: %s", err)
+	}
+	sread := string(read)
+
+	if toWriteOver != sread {
+		t.Errorf("Expected read data to be '%s' but was '%s'", toWriteOver, sread)
+	}
+}

+ 124 - 0
pkg/util/fileutil/locks_unix.go

@@ -0,0 +1,124 @@
+//go:build darwin || dragonfly || freebsd || illumos || linux || netbsd || openbsd
+
+// The above platforms support flock()
+
+package fileutil
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"syscall"
+
+	"github.com/opencost/opencost/pkg/log"
+)
+
+// WriteLockedFD uses the flock() syscall to safely write to an open file as
+// long as other users of the file are also using flock()-based access.
+//
+// WriteLocked will block until it gets lock access.
+//
+// The file will be truncated before writing and at the end of writing the
+// FD will be reset to position 0.
+//
+// For the reasons outlined best in https://lwn.net/Articles/586904/ this uses
+// flock() instead of fcntl(). The ability to lock byte ranges is not necessary
+// and flock() has better behavior.
+func WriteLockedFD(f *os.File, data []byte) (int, error) {
+	// For the reasons outlined best in https://lwn.net/Articles/586904/ we're
+	// going to use flock() instead of fcntl() because we want a whole-file lock.
+	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
+		return 0, fmt.Errorf("unexpected error flock()-ing with EX: %w", err)
+	}
+	defer func() {
+		if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
+			log.Errorf("unexpected error flock()-ing FD %d with UN after writing: %s", f.Fd(), err)
+		}
+	}()
+
+	if err := f.Truncate(0); err != nil {
+		return 0, fmt.Errorf("truncating: %w", err)
+	}
+
+	if _, err := f.Seek(0, 0); err != nil {
+		return 0, fmt.Errorf("seeking to 0 before write: %w", err)
+	}
+	defer func() {
+		if _, err := f.Seek(0, 0); err != nil {
+			log.Errorf("unexpected error seeking to 0 after write on FD %d: %s", f.Fd(), err)
+		}
+	}()
+
+	n, err := f.Write(data)
+	if err != nil {
+		return n, fmt.Errorf("writing data: %w", err)
+	}
+
+	return n, nil
+}
+
+// WriteLocked opens the file and then calls WriteLockedFD.
+func WriteLocked(filename string, data []byte) (int, error) {
+	file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0600)
+	if err != nil {
+		return 0, fmt.Errorf("opening %s: %w", filename, err)
+	}
+	defer file.Close()
+
+	return WriteLockedFD(file, data)
+}
+
+// ReadLockedFD uses the flock() syscall to safely read from an open file as
+// long as other users of the file are also using flock()-based access.
+//
+// ReadLockedFD will block until it gets lock access.
+//
+// This will read the file in full, from 0, regardless of the current
+// position and then reset the position to 0.
+//
+// For the reasons outlined best in https://lwn.net/Articles/586904/ this uses
+// flock() instead of fcntl(). The ability to lock byte ranges is not necessary
+// and flock() has better behavior.
+func ReadLockedFD(f *os.File) ([]byte, error) {
+	// For the reasons outlined best in https://lwn.net/Articles/586904/ we're
+	// going to use flock() instead of fcntl() because we want a whole-file lock.
+	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_SH); err != nil {
+		return nil, fmt.Errorf("unexpected error flock()-ing with SH: %w", err)
+	}
+	defer func() {
+		if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
+			log.Errorf("unexpected error flock()-ing FD %d with UN reading: %s", f.Fd(), err)
+		}
+	}()
+
+	if _, err := f.Seek(0, 0); err != nil {
+		return nil, fmt.Errorf("seeking to 0 before read: %w", err)
+	}
+	defer func() {
+		if _, err := f.Seek(0, 0); err != nil {
+			log.Errorf("unexpected error seeking to 0 after read on FD %d: %s", f.Fd(), err)
+		}
+	}()
+
+	buf := bytes.NewBuffer(nil)
+	if _, err := io.Copy(buf, f); err != nil {
+		if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
+			log.Errorf("unexpected error flock()-ing with UN after error reading: %s", err)
+		}
+		return nil, fmt.Errorf("copying data out of file: %w", err)
+	}
+
+	return buf.Bytes(), nil
+}
+
+// ReadLocked opens the given file and then calls ReadLockedFD.
+func ReadLocked(filename string) ([]byte, error) {
+	file, err := os.OpenFile(filename, os.O_RDONLY, 0600)
+	if err != nil {
+		return nil, fmt.Errorf("opening %s: %w", filename, err)
+	}
+	defer file.Close()
+
+	return ReadLockedFD(file)
+}

+ 24 - 0
pkg/util/fileutil/locks_windows.go

@@ -0,0 +1,24 @@
+//go:build windows
+
+package fileutil
+
+import (
+	"fmt"
+	"os"
+)
+
+func WriteLockedFD(f *os.File, data []byte) (int, error) {
+	return 0, fmt.Errorf("WriteLockedFD is not implemented on Windows. Please open an issue.")
+}
+
+func WriteLocked(filename string, data []byte) (int, error) {
+	return 0, fmt.Errorf("WriteLocked is not implemented on Windows. Please open an issue.")
+}
+
+func ReadLockedFD(f *os.File) ([]byte, error) {
+	return nil, fmt.Errorf("ReadLockedFD is not implemented on Windows. Please open an issue.")
+}
+
+func ReadLocked(filename string) ([]byte, error) {
+	return nil, fmt.Errorf("ReadLocked is not implemented on Windows. Please open an issue.")
+}