Просмотр исходного кода

always use readFull in case inconsistent bytes are read.

Matt Bolt 1 месяц назад
Родитель
Сommit
90f6fb4abd
3 измененных файлов с 139 добавлено и 42 удалено
  1. 2 20
      core/pkg/util/buffer.go
  2. 82 2
      core/pkg/util/buffer_test.go
  3. 55 20
      core/pkg/util/bufferhelper.go

+ 2 - 20
core/pkg/util/buffer.go

@@ -354,7 +354,7 @@ func (b *Buffer) ReadString() string {
 	bytes := bytePool.Get(int(l))
 	defer bytePool.Put(bytes)
 
-	_, err := readFull(b.b, bytes)
+	_, err := readBuffFull(b.b, bytes)
 	if err != nil {
 		return ""
 	}
@@ -371,7 +371,7 @@ func (b *Buffer) ReadBytes(length int) []byte {
 	bytes := bytePool.Get(length)
 	defer bytePool.Put(bytes)
 
-	_, err := readFull(b.b, bytes)
+	_, err := readBuffFull(b.b, bytes)
 	if err != nil {
 		return bytes
 	}
@@ -379,24 +379,6 @@ func (b *Buffer) ReadBytes(length int) []byte {
 	return bytes
 }
 
-// read full is a bufio.Reader specific implementation of io.ReadFull() which
-// avoids escaping our stack allocated scratch bytes
-func readFull(r *bufio.Reader, buf []byte) (n int, err error) {
-
-	min := len(buf)
-	for n < min && err == nil {
-		var nn int
-		nn, err = r.Read(buf[n:])
-		n += nn
-	}
-	if n >= min {
-		err = nil
-	} else if n > 0 && err == io.EOF {
-		err = io.ErrUnexpectedEOF
-	}
-	return
-}
-
 // Conversion from byte slice to string
 func bytesToString(b []byte) string {
 	// This code will take the passed byte slice and cast it in-place into a string. By doing

+ 82 - 2
core/pkg/util/buffer_test.go

@@ -2,8 +2,9 @@ package util
 
 import (
 	"bytes"
+	"io"
 	"math"
-	"math/rand"
+	"math/rand/v2"
 	"runtime"
 	"strings"
 	"testing"
@@ -224,7 +225,7 @@ const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
 func generateRandomString(ln int) string {
 	b := make([]byte, ln)
 	for i := range b {
-		b[i] = letters[rand.Intn(len(letters))]
+		b[i] = letters[rand.IntN(len(letters))]
 	}
 	return string(b)
 }
@@ -283,6 +284,85 @@ func TestStringBytes(t *testing.T) {
 	}
 }
 
+type randomByteReader struct {
+	bytes []byte
+	pos   int
+}
+
+func newSingleByteReader(bytes []byte) *randomByteReader {
+	return &randomByteReader{
+		bytes: bytes,
+		pos:   0,
+	}
+}
+
+func (sbr *randomByteReader) Read(b []byte) (int, error) {
+	if sbr.pos >= len(sbr.bytes) {
+		return 0, io.EOF
+	}
+
+	toCopy := rand.IntN(4)
+	if toCopy > len(b) {
+		toCopy = len(b)
+	}
+
+	var err error
+	remaining := len(sbr.bytes) - sbr.pos
+	if toCopy > remaining {
+		err = io.EOF
+		toCopy = remaining
+	}
+
+	bytesCopied := copy(b, sbr.bytes[sbr.pos:sbr.pos+toCopy])
+	sbr.pos += bytesCopied
+
+	return bytesCopied, err
+}
+
+func TestBufferReaderSupport(t *testing.T) {
+	buf := NewBuffer()
+	buf.WriteBool(true)
+	buf.WriteInt(42)
+	buf.WriteFloat64(3.14)
+	buf.WriteString("Testing, 1, 2, 3!")
+	buf.WriteUInt64(uint64(123456))
+	buf.WriteInt16(44)
+	buf.WriteFloat32(float32(5.0))
+
+	reader := newSingleByteReader(buf.Bytes())
+	readerBuff := NewBufferFromReader(reader)
+
+	b := readerBuff.ReadBool()
+	i := readerBuff.ReadInt()
+	f := readerBuff.ReadFloat64()
+	s := readerBuff.ReadString()
+	ui64 := readerBuff.ReadUInt64()
+	i16 := readerBuff.ReadInt16()
+	f32 := readerBuff.ReadFloat32()
+
+	if !b {
+		t.Errorf("expected true, got: false")
+	}
+	if i != 42 {
+		t.Errorf("expected 42, got: %d", i)
+	}
+	if f != 3.14 {
+		t.Errorf("expected 3.14, got: %f", f)
+	}
+	if s != "Testing, 1, 2, 3!" {
+		t.Errorf("expected 'Testing, 1, 2, 3!', got: '%s'", s)
+	}
+	if ui64 != uint64(123456) {
+		t.Errorf("expected 123456, got: %d", ui64)
+	}
+	if i16 != int16(44) {
+		t.Errorf("expected 44, got: %d", i16)
+	}
+	if f32 != float32(5.0) {
+		t.Errorf("expected 5.0, got: %f", f32)
+	}
+}
+
 func TestTooLargeStringTruncate(t *testing.T) {
 	normalStr := generateRandomString(100)
 	bigStr := generateRandomString(math.MaxUint16 + (math.MaxUint16 / 2))

+ 55 - 20
core/pkg/util/bufferhelper.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/binary"
+	"io"
 	"math"
 )
 
@@ -42,7 +43,7 @@ func readInt16(r *bytes.Buffer, data *int16) error {
 	var b [2]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -56,7 +57,7 @@ func readUint16(r *bytes.Buffer, data *uint16) error {
 	var b [2]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -70,7 +71,7 @@ func readInt(r *bytes.Buffer, data *int) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -84,7 +85,7 @@ func readInt32(r *bytes.Buffer, data *int32) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -98,7 +99,7 @@ func readUint(r *bytes.Buffer, data *uint) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -112,7 +113,7 @@ func readUint32(r *bytes.Buffer, data *uint32) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -126,7 +127,7 @@ func readInt64(r *bytes.Buffer, data *int64) error {
 	var b [8]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -140,7 +141,7 @@ func readUint64(r *bytes.Buffer, data *uint64) error {
 	var b [8]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -154,7 +155,7 @@ func readFloat32(r *bytes.Buffer, data *float32) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -168,7 +169,7 @@ func readFloat64(r *bytes.Buffer, data *float64) error {
 	var b [8]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -212,7 +213,7 @@ func readBuffInt16(r *bufio.Reader, data *int16) error {
 	var b [2]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -226,7 +227,7 @@ func readBuffUint16(r *bufio.Reader, data *uint16) error {
 	var b [2]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -240,7 +241,7 @@ func readBuffInt(r *bufio.Reader, data *int) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -254,7 +255,7 @@ func readBuffInt32(r *bufio.Reader, data *int32) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -268,7 +269,7 @@ func readBuffUint(r *bufio.Reader, data *uint) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -282,7 +283,7 @@ func readBuffUint32(r *bufio.Reader, data *uint32) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -296,7 +297,7 @@ func readBuffInt64(r *bufio.Reader, data *int64) error {
 	var b [8]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -310,7 +311,7 @@ func readBuffUint64(r *bufio.Reader, data *uint64) error {
 	var b [8]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -324,7 +325,7 @@ func readBuffFloat32(r *bufio.Reader, data *float32) error {
 	var b [4]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -338,7 +339,7 @@ func readBuffFloat64(r *bufio.Reader, data *float64) error {
 	var b [8]byte
 
 	bs := b[:]
-	_, err := r.Read(bs)
+	_, err := readBuffFull(r, bs)
 	if err != nil {
 		return err
 	}
@@ -347,6 +348,40 @@ func readBuffFloat64(r *bufio.Reader, data *float64) error {
 	return nil
 }
 
+// read full is a bufio.Reader specific implementation of io.ReadFull() which
+// avoids escaping our stack allocated scratch bytes
+func readBuffFull(r *bufio.Reader, buf []byte) (n int, err error) {
+	min := len(buf)
+	for n < min && err == nil {
+		var nn int
+		nn, err = r.Read(buf[n:])
+		n += nn
+	}
+	if n >= min {
+		err = nil
+	} else if n > 0 && err == io.EOF {
+		err = io.ErrUnexpectedEOF
+	}
+	return
+}
+
+// read full is a bytes.Buffer specific implementation of io.ReadFull() which
+// avoids escaping our stack allocated scratch bytes
+func readFull(r *bytes.Buffer, buf []byte) (n int, err error) {
+	min := len(buf)
+	for n < min && err == nil {
+		var nn int
+		nn, err = r.Read(buf[n:])
+		n += nn
+	}
+	if n >= min {
+		err = nil
+	} else if n > 0 && err == io.EOF {
+		err = io.ErrUnexpectedEOF
+	}
+	return
+}
+
 func writeBool(w *bytes.Buffer, data bool) error {
 	if data {
 		w.WriteByte(1)