| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370 |
- // Copyright (c) Faye Amacker. All rights reserved.
- // Licensed under the MIT License. See LICENSE in the project root for license information.
- package cbor
- import (
- "bytes"
- "errors"
- "fmt"
- "reflect"
- "sort"
- "strconv"
- "strings"
- "sync"
- )
// encodeFuncs bundles the three functions returned by getEncodeFuncInternal
// for one type so they can be stored together as a single value in
// encodeFuncCache (see getEncodeFunc). Field order matters: getEncodeFunc
// builds this struct with a positional composite literal.
type encodeFuncs struct {
	// ef encodes a value of the type.
	ef encodeFunc
	// ief reports whether a value is "empty" (presumably for omitempty — confirm).
	ief isEmptyFunc
	// izf reports whether a value is "zero" (presumably for omitzero — confirm).
	izf isZeroFunc
}
// Package-level, lazily populated caches keyed by reflect.Type. Each one is
// filled on first request by its accessor: getDecodingStructType,
// getEncodingStructType / getEncodingStructToArrayType, getEncodeFunc and
// getTypeInfo respectively.
var (
	decodingStructTypeCache sync.Map // map[reflect.Type]*decodingStructType
	encodingStructTypeCache sync.Map // map[reflect.Type]*encodingStructType
	encodeFuncCache         sync.Map // map[reflect.Type]encodeFuncs
	typeInfoCache           sync.Map // map[reflect.Type]*typeInfo
)
// specialType classifies a (pointer-stripped) type that needs special
// handling. newTypeInfo assigns it; the first matching rule wins.
type specialType int

const (
	// specialTypeNone: none of the rules below matched.
	specialTypeNone specialType = iota
	// specialTypeUnmarshalerIface: a pointer to the type implements typeUnmarshaler.
	specialTypeUnmarshalerIface
	// specialTypeUnexportedUnmarshalerIface: a pointer to the type implements typeUnexportedUnmarshaler.
	specialTypeUnexportedUnmarshalerIface
	// specialTypeEmptyIface: an interface type with no methods.
	specialTypeEmptyIface
	// specialTypeIface: an interface type with at least one method.
	specialTypeIface
	// specialTypeTag: the type is typeTag.
	specialTypeTag
	// specialTypeTime: the type is typeTime.
	specialTypeTime
	// specialTypeJSONUnmarshalerIface: a pointer to the type implements typeJSONUnmarshaler.
	specialTypeJSONUnmarshalerIface
)
// typeInfo is the per-type metadata computed by newTypeInfo and cached by
// getTypeInfo.
type typeInfo struct {
	// elemTypeInfo is set for arrays, slices and maps (element type of nonPtrType).
	elemTypeInfo *typeInfo
	// keyTypeInfo is set only for maps (key type of nonPtrType).
	keyTypeInfo *typeInfo
	// typ and kind describe the type exactly as requested, pointers included.
	typ  reflect.Type
	kind reflect.Kind
	// nonPtrType and nonPtrKind describe typ after stripping every pointer level.
	nonPtrType reflect.Type
	nonPtrKind reflect.Kind
	// spclType is the special-handling class of nonPtrType.
	spclType specialType
}
- func newTypeInfo(t reflect.Type) *typeInfo {
- tInfo := typeInfo{typ: t, kind: t.Kind()}
- for t.Kind() == reflect.Pointer {
- t = t.Elem()
- }
- k := t.Kind()
- tInfo.nonPtrType = t
- tInfo.nonPtrKind = k
- if k == reflect.Interface {
- if t.NumMethod() == 0 {
- tInfo.spclType = specialTypeEmptyIface
- } else {
- tInfo.spclType = specialTypeIface
- }
- } else if t == typeTag {
- tInfo.spclType = specialTypeTag
- } else if t == typeTime {
- tInfo.spclType = specialTypeTime
- } else if reflect.PointerTo(t).Implements(typeUnexportedUnmarshaler) {
- tInfo.spclType = specialTypeUnexportedUnmarshalerIface
- } else if reflect.PointerTo(t).Implements(typeUnmarshaler) {
- tInfo.spclType = specialTypeUnmarshalerIface
- } else if reflect.PointerTo(t).Implements(typeJSONUnmarshaler) {
- tInfo.spclType = specialTypeJSONUnmarshalerIface
- }
- switch k {
- case reflect.Array, reflect.Slice:
- tInfo.elemTypeInfo = getTypeInfo(t.Elem())
- case reflect.Map:
- tInfo.keyTypeInfo = getTypeInfo(t.Key())
- tInfo.elemTypeInfo = getTypeInfo(t.Elem())
- }
- return &tInfo
- }
// decodingStructType is the per-struct-type decoding metadata built and
// cached by getDecodingStructType.
type decodingStructType struct {
	// fields is the struct's field list from getFields, with nameAsInt
	// (keyasint fields) and typInfo filled in.
	fields fields
	// fieldIndicesByName maps a field name to its index in fields; on
	// duplicate names the first occurrence is kept.
	fieldIndicesByName map[string]int
	// err is nil, a single error, or a multierror aggregating every problem
	// found (unparsable keyasint name, duplicate field names).
	err error
	// toArray is true when the struct's options include "toarray".
	toArray bool
}
// multierror aggregates several errors into one.
//
// The standard library's errors.Join only arrived in Go 1.20. This package
// still supports Go 1.17, so it carries this minimal stand-in.
type multierror []error

// Error returns the messages of all aggregated errors, in order, separated by
// ", ". An empty multierror yields "".
func (m multierror) Error() string {
	var sb strings.Builder
	for i := range m {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(m[i].Error())
	}
	return sb.String()
}
- func getDecodingStructType(t reflect.Type) *decodingStructType {
- if v, _ := decodingStructTypeCache.Load(t); v != nil {
- return v.(*decodingStructType)
- }
- flds, structOptions := getFields(t)
- toArray := hasToArrayOption(structOptions)
- var errs []error
- for i := 0; i < len(flds); i++ {
- if flds[i].keyAsInt {
- nameAsInt, numErr := strconv.Atoi(flds[i].name)
- if numErr != nil {
- errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")"))
- break
- }
- flds[i].nameAsInt = int64(nameAsInt)
- }
- flds[i].typInfo = getTypeInfo(flds[i].typ)
- }
- fieldIndicesByName := make(map[string]int, len(flds))
- for i, fld := range flds {
- if _, ok := fieldIndicesByName[fld.name]; ok {
- errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name))
- continue
- }
- fieldIndicesByName[fld.name] = i
- }
- var err error
- {
- var multi multierror
- for _, each := range errs {
- if each != nil {
- multi = append(multi, each)
- }
- }
- if len(multi) == 1 {
- err = multi[0]
- } else if len(multi) > 1 {
- err = multi
- }
- }
- structType := &decodingStructType{
- fields: flds,
- fieldIndicesByName: fieldIndicesByName,
- err: err,
- toArray: toArray,
- }
- decodingStructTypeCache.Store(t, structType)
- return structType
- }
// encodingStructType is the per-struct-type encoding metadata built and
// cached by getEncodingStructType / getEncodingStructToArrayType.
type encodingStructType struct {
	// fields is the field list in the order returned by getFields.
	fields fields
	// bytewiseFields is fields sorted bytewise by encoded key (cborName).
	bytewiseFields fields
	// lengthFirstFields is fields sorted by encoded key length, then bytewise.
	// It may share the same slice as bytewiseFields when both orders coincide.
	lengthFirstFields fields
	// omitEmptyFieldsIdx holds indices into fields of fields marked omitEmpty.
	omitEmptyFieldsIdx []int
	// err is non-nil when the type cannot be encoded. In that case the other
	// fields are left zero.
	err error
	// toArray is true for structs with the "toarray" option. For those, only
	// fields is populated (bytewiseFields/lengthFirstFields stay nil).
	toArray bool
}
- func (st *encodingStructType) getFields(em *encMode) fields {
- switch em.sort {
- case SortNone, SortFastShuffle:
- return st.fields
- case SortLengthFirst:
- return st.lengthFirstFields
- default:
- return st.bytewiseFields
- }
- }
- type bytewiseFieldSorter struct {
- fields fields
- }
- func (x *bytewiseFieldSorter) Len() int {
- return len(x.fields)
- }
- func (x *bytewiseFieldSorter) Swap(i, j int) {
- x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
- }
- func (x *bytewiseFieldSorter) Less(i, j int) bool {
- return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
- }
- type lengthFirstFieldSorter struct {
- fields fields
- }
- func (x *lengthFirstFieldSorter) Len() int {
- return len(x.fields)
- }
- func (x *lengthFirstFieldSorter) Swap(i, j int) {
- x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
- }
- func (x *lengthFirstFieldSorter) Less(i, j int) bool {
- if len(x.fields[i].cborName) != len(x.fields[j].cborName) {
- return len(x.fields[i].cborName) < len(x.fields[j].cborName)
- }
- return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
- }
- func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
- if v, _ := encodingStructTypeCache.Load(t); v != nil {
- structType := v.(*encodingStructType)
- return structType, structType.err
- }
- flds, structOptions := getFields(t)
- if hasToArrayOption(structOptions) {
- return getEncodingStructToArrayType(t, flds)
- }
- var err error
- var hasKeyAsInt bool
- var hasKeyAsStr bool
- var omitEmptyIdx []int
- e := getEncodeBuffer()
- for i := 0; i < len(flds); i++ {
- // Get field's encodeFunc
- flds[i].ef, flds[i].ief, flds[i].izf = getEncodeFunc(flds[i].typ)
- if flds[i].ef == nil {
- err = &UnsupportedTypeError{t}
- break
- }
- // Encode field name
- if flds[i].keyAsInt {
- nameAsInt, numErr := strconv.Atoi(flds[i].name)
- if numErr != nil {
- err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
- break
- }
- flds[i].nameAsInt = int64(nameAsInt)
- if nameAsInt >= 0 {
- encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
- } else {
- n := nameAsInt*(-1) - 1
- encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
- }
- flds[i].cborName = make([]byte, e.Len())
- copy(flds[i].cborName, e.Bytes())
- e.Reset()
- hasKeyAsInt = true
- } else {
- encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
- flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
- n := copy(flds[i].cborName, e.Bytes())
- copy(flds[i].cborName[n:], flds[i].name)
- e.Reset()
- // If cborName contains a text string, then cborNameByteString contains a
- // string that has the byte string major type but is otherwise identical to
- // cborName.
- flds[i].cborNameByteString = make([]byte, len(flds[i].cborName))
- copy(flds[i].cborNameByteString, flds[i].cborName)
- // Reset encoded CBOR type to byte string, preserving the "additional
- // information" bits:
- flds[i].cborNameByteString[0] = byte(cborTypeByteString) |
- getAdditionalInformation(flds[i].cborNameByteString[0])
- hasKeyAsStr = true
- }
- // Check if field can be omitted when empty
- if flds[i].omitEmpty {
- omitEmptyIdx = append(omitEmptyIdx, i)
- }
- }
- putEncodeBuffer(e)
- if err != nil {
- structType := &encodingStructType{err: err}
- encodingStructTypeCache.Store(t, structType)
- return structType, structType.err
- }
- // Sort fields by canonical order
- bytewiseFields := make(fields, len(flds))
- copy(bytewiseFields, flds)
- sort.Sort(&bytewiseFieldSorter{bytewiseFields})
- lengthFirstFields := bytewiseFields
- if hasKeyAsInt && hasKeyAsStr {
- lengthFirstFields = make(fields, len(flds))
- copy(lengthFirstFields, flds)
- sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
- }
- structType := &encodingStructType{
- fields: flds,
- bytewiseFields: bytewiseFields,
- lengthFirstFields: lengthFirstFields,
- omitEmptyFieldsIdx: omitEmptyIdx,
- }
- encodingStructTypeCache.Store(t, structType)
- return structType, structType.err
- }
- func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
- for i := 0; i < len(flds); i++ {
- // Get field's encodeFunc
- flds[i].ef, flds[i].ief, flds[i].izf = getEncodeFunc(flds[i].typ)
- if flds[i].ef == nil {
- structType := &encodingStructType{err: &UnsupportedTypeError{t}}
- encodingStructTypeCache.Store(t, structType)
- return structType, structType.err
- }
- }
- structType := &encodingStructType{
- fields: flds,
- toArray: true,
- }
- encodingStructTypeCache.Store(t, structType)
- return structType, structType.err
- }
- func getEncodeFunc(t reflect.Type) (encodeFunc, isEmptyFunc, isZeroFunc) {
- if v, _ := encodeFuncCache.Load(t); v != nil {
- fs := v.(encodeFuncs)
- return fs.ef, fs.ief, fs.izf
- }
- ef, ief, izf := getEncodeFuncInternal(t)
- encodeFuncCache.Store(t, encodeFuncs{ef, ief, izf})
- return ef, ief, izf
- }
- func getTypeInfo(t reflect.Type) *typeInfo {
- if v, _ := typeInfoCache.Load(t); v != nil {
- return v.(*typeInfo)
- }
- tInfo := newTypeInfo(t)
- typeInfoCache.Store(t, tInfo)
- return tInfo
- }
// hasToArrayOption reports whether tag's option list contains "toarray".
//
// tag is expected to look like "name,opt1,opt2,...". The text before the first
// comma is the name and is never treated as an option. An option matches only
// when it is exactly "toarray", so ",toarrayx" does not match.
//
// Every option is examined. A non-matching option that merely begins with
// "toarray" therefore cannot hide a later exact "toarray" (the previous
// implementation only looked at the first occurrence of ",toarray").
func hasToArrayOption(tag string) bool {
	const opt = "toarray"
	for {
		comma := strings.Index(tag, ",")
		if comma < 0 {
			return false
		}
		tag = tag[comma+1:]

		current := tag
		if end := strings.Index(tag, ","); end >= 0 {
			current = tag[:end]
		}
		if current == opt {
			return true
		}
	}
}
|