cache.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. // Copyright (c) Faye Amacker. All rights reserved.
  2. // Licensed under the MIT License. See LICENSE in the project root for license information.
  3. package cbor
  4. import (
  5. "bytes"
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. )
  // encodeFuncs groups the three functions that getEncodeFuncInternal returns
// for a single type, so they can be stored as one value in encodeFuncCache.
type encodeFuncs struct {
	ef  encodeFunc  // encode function for the type (nil if the type is unsupported)
	ief isEmptyFunc // "is empty" check for the type
	izf isZeroFunc  // "is zero" check for the type
}
  // Process-wide caches of per-type metadata, keyed by reflect.Type. Each entry
// is computed the first time its type is requested and reused afterwards; the
// trailing comment on each line names the value type stored in that map.
var (
	decodingStructTypeCache sync.Map // map[reflect.Type]*decodingStructType
	encodingStructTypeCache sync.Map // map[reflect.Type]*encodingStructType
	encodeFuncCache         sync.Map // map[reflect.Type]encodeFuncs
	typeInfoCache           sync.Map // map[reflect.Type]*typeInfo
)
  // specialType classifies a non-pointer type that needs dedicated handling.
// newTypeInfo assigns it from an if/else chain, so the first matching rule
// there wins (e.g. an interface type is never classified as an unmarshaler).
type specialType int

const (
	specialTypeNone                       specialType = iota // no special handling
	specialTypeUnmarshalerIface                              // pointer to the type implements typeUnmarshaler
	specialTypeUnexportedUnmarshalerIface                    // pointer to the type implements typeUnexportedUnmarshaler
	specialTypeEmptyIface                                    // interface type with no methods
	specialTypeIface                                         // interface type with at least one method
	specialTypeTag                                           // the type is typeTag
	specialTypeTime                                          // the type is typeTime
	specialTypeJSONUnmarshalerIface                          // pointer to the type implements typeJSONUnmarshaler
)
  // typeInfo is the per-type information computed by newTypeInfo and cached in
// typeInfoCache.
type typeInfo struct {
	elemTypeInfo *typeInfo    // element info when nonPtrKind is Array, Slice or Map; nil otherwise
	keyTypeInfo  *typeInfo    // key info when nonPtrKind is Map; nil otherwise
	typ          reflect.Type // the type as requested (may be a pointer type)
	kind         reflect.Kind // typ.Kind()
	nonPtrType   reflect.Type // typ with every pointer indirection removed
	nonPtrKind   reflect.Kind // nonPtrType.Kind()
	spclType     specialType  // special-handling classification of nonPtrType
}
  45. func newTypeInfo(t reflect.Type) *typeInfo {
  46. tInfo := typeInfo{typ: t, kind: t.Kind()}
  47. for t.Kind() == reflect.Pointer {
  48. t = t.Elem()
  49. }
  50. k := t.Kind()
  51. tInfo.nonPtrType = t
  52. tInfo.nonPtrKind = k
  53. if k == reflect.Interface {
  54. if t.NumMethod() == 0 {
  55. tInfo.spclType = specialTypeEmptyIface
  56. } else {
  57. tInfo.spclType = specialTypeIface
  58. }
  59. } else if t == typeTag {
  60. tInfo.spclType = specialTypeTag
  61. } else if t == typeTime {
  62. tInfo.spclType = specialTypeTime
  63. } else if reflect.PointerTo(t).Implements(typeUnexportedUnmarshaler) {
  64. tInfo.spclType = specialTypeUnexportedUnmarshalerIface
  65. } else if reflect.PointerTo(t).Implements(typeUnmarshaler) {
  66. tInfo.spclType = specialTypeUnmarshalerIface
  67. } else if reflect.PointerTo(t).Implements(typeJSONUnmarshaler) {
  68. tInfo.spclType = specialTypeJSONUnmarshalerIface
  69. }
  70. switch k {
  71. case reflect.Array, reflect.Slice:
  72. tInfo.elemTypeInfo = getTypeInfo(t.Elem())
  73. case reflect.Map:
  74. tInfo.keyTypeInfo = getTypeInfo(t.Key())
  75. tInfo.elemTypeInfo = getTypeInfo(t.Elem())
  76. }
  77. return &tInfo
  78. }
  // decodingStructType holds the per-struct-type data built by
// getDecodingStructType and cached in decodingStructTypeCache.
type decodingStructType struct {
	fields             fields         // the struct's fields as returned by getFields
	fieldIndicesByName map[string]int // field name -> index in fields of the first field with that name
	err                error          // problems found while analyzing the struct type, or nil
	toArray            bool           // set when the struct options contain "toarray"
}
  // multierror aggregates several errors into one. The stdlib errors.Join was
// introduced in Go 1.20 and this package still supports older Go releases, so
// this is a minimal stand-in for it.
type multierror []error

// Error returns the messages of all aggregated errors, in order, separated by ", ".
func (m multierror) Error() string {
	var sb strings.Builder
	for i, err := range m {
		sb.WriteString(err.Error())
		if i < len(m)-1 {
			sb.WriteString(", ")
		}
	}
	return sb.String()
}

// Unwrap returns the aggregated errors. On Go 1.20 and later this lets
// errors.Is and errors.As inspect each of them, matching errors.Join; older
// Go versions simply ignore this method.
func (m multierror) Unwrap() []error {
	return m
}
  98. func getDecodingStructType(t reflect.Type) *decodingStructType {
  99. if v, _ := decodingStructTypeCache.Load(t); v != nil {
  100. return v.(*decodingStructType)
  101. }
  102. flds, structOptions := getFields(t)
  103. toArray := hasToArrayOption(structOptions)
  104. var errs []error
  105. for i := 0; i < len(flds); i++ {
  106. if flds[i].keyAsInt {
  107. nameAsInt, numErr := strconv.Atoi(flds[i].name)
  108. if numErr != nil {
  109. errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")"))
  110. break
  111. }
  112. flds[i].nameAsInt = int64(nameAsInt)
  113. }
  114. flds[i].typInfo = getTypeInfo(flds[i].typ)
  115. }
  116. fieldIndicesByName := make(map[string]int, len(flds))
  117. for i, fld := range flds {
  118. if _, ok := fieldIndicesByName[fld.name]; ok {
  119. errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name))
  120. continue
  121. }
  122. fieldIndicesByName[fld.name] = i
  123. }
  124. var err error
  125. {
  126. var multi multierror
  127. for _, each := range errs {
  128. if each != nil {
  129. multi = append(multi, each)
  130. }
  131. }
  132. if len(multi) == 1 {
  133. err = multi[0]
  134. } else if len(multi) > 1 {
  135. err = multi
  136. }
  137. }
  138. structType := &decodingStructType{
  139. fields: flds,
  140. fieldIndicesByName: fieldIndicesByName,
  141. err: err,
  142. toArray: toArray,
  143. }
  144. decodingStructTypeCache.Store(t, structType)
  145. return structType
  146. }
  // encodingStructType holds the per-struct-type data built for encoding and
// cached in encodingStructTypeCache. When err is non-nil, the other fields
// are left unset. For "toarray" structs, only fields and toArray are set.
type encodingStructType struct {
	fields             fields // fields in the order returned by getFields
	bytewiseFields     fields // fields sorted by bytewise order of cborName
	lengthFirstFields  fields // fields sorted by cborName length, then bytewise; may be the same slice as bytewiseFields
	omitEmptyFieldsIdx []int  // indices into fields of the fields marked omitEmpty
	err                error  // why the struct type cannot be encoded, or nil
	toArray            bool   // set when the struct options contain "toarray"
}
  155. func (st *encodingStructType) getFields(em *encMode) fields {
  156. switch em.sort {
  157. case SortNone, SortFastShuffle:
  158. return st.fields
  159. case SortLengthFirst:
  160. return st.lengthFirstFields
  161. default:
  162. return st.bytewiseFields
  163. }
  164. }
  165. type bytewiseFieldSorter struct {
  166. fields fields
  167. }
  168. func (x *bytewiseFieldSorter) Len() int {
  169. return len(x.fields)
  170. }
  171. func (x *bytewiseFieldSorter) Swap(i, j int) {
  172. x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
  173. }
  174. func (x *bytewiseFieldSorter) Less(i, j int) bool {
  175. return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
  176. }
  177. type lengthFirstFieldSorter struct {
  178. fields fields
  179. }
  180. func (x *lengthFirstFieldSorter) Len() int {
  181. return len(x.fields)
  182. }
  183. func (x *lengthFirstFieldSorter) Swap(i, j int) {
  184. x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
  185. }
  186. func (x *lengthFirstFieldSorter) Less(i, j int) bool {
  187. if len(x.fields[i].cborName) != len(x.fields[j].cborName) {
  188. return len(x.fields[i].cborName) < len(x.fields[j].cborName)
  189. }
  190. return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
  191. }
  192. func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
  193. if v, _ := encodingStructTypeCache.Load(t); v != nil {
  194. structType := v.(*encodingStructType)
  195. return structType, structType.err
  196. }
  197. flds, structOptions := getFields(t)
  198. if hasToArrayOption(structOptions) {
  199. return getEncodingStructToArrayType(t, flds)
  200. }
  201. var err error
  202. var hasKeyAsInt bool
  203. var hasKeyAsStr bool
  204. var omitEmptyIdx []int
  205. e := getEncodeBuffer()
  206. for i := 0; i < len(flds); i++ {
  207. // Get field's encodeFunc
  208. flds[i].ef, flds[i].ief, flds[i].izf = getEncodeFunc(flds[i].typ)
  209. if flds[i].ef == nil {
  210. err = &UnsupportedTypeError{t}
  211. break
  212. }
  213. // Encode field name
  214. if flds[i].keyAsInt {
  215. nameAsInt, numErr := strconv.Atoi(flds[i].name)
  216. if numErr != nil {
  217. err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
  218. break
  219. }
  220. flds[i].nameAsInt = int64(nameAsInt)
  221. if nameAsInt >= 0 {
  222. encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
  223. } else {
  224. n := nameAsInt*(-1) - 1
  225. encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
  226. }
  227. flds[i].cborName = make([]byte, e.Len())
  228. copy(flds[i].cborName, e.Bytes())
  229. e.Reset()
  230. hasKeyAsInt = true
  231. } else {
  232. encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
  233. flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
  234. n := copy(flds[i].cborName, e.Bytes())
  235. copy(flds[i].cborName[n:], flds[i].name)
  236. e.Reset()
  237. // If cborName contains a text string, then cborNameByteString contains a
  238. // string that has the byte string major type but is otherwise identical to
  239. // cborName.
  240. flds[i].cborNameByteString = make([]byte, len(flds[i].cborName))
  241. copy(flds[i].cborNameByteString, flds[i].cborName)
  242. // Reset encoded CBOR type to byte string, preserving the "additional
  243. // information" bits:
  244. flds[i].cborNameByteString[0] = byte(cborTypeByteString) |
  245. getAdditionalInformation(flds[i].cborNameByteString[0])
  246. hasKeyAsStr = true
  247. }
  248. // Check if field can be omitted when empty
  249. if flds[i].omitEmpty {
  250. omitEmptyIdx = append(omitEmptyIdx, i)
  251. }
  252. }
  253. putEncodeBuffer(e)
  254. if err != nil {
  255. structType := &encodingStructType{err: err}
  256. encodingStructTypeCache.Store(t, structType)
  257. return structType, structType.err
  258. }
  259. // Sort fields by canonical order
  260. bytewiseFields := make(fields, len(flds))
  261. copy(bytewiseFields, flds)
  262. sort.Sort(&bytewiseFieldSorter{bytewiseFields})
  263. lengthFirstFields := bytewiseFields
  264. if hasKeyAsInt && hasKeyAsStr {
  265. lengthFirstFields = make(fields, len(flds))
  266. copy(lengthFirstFields, flds)
  267. sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
  268. }
  269. structType := &encodingStructType{
  270. fields: flds,
  271. bytewiseFields: bytewiseFields,
  272. lengthFirstFields: lengthFirstFields,
  273. omitEmptyFieldsIdx: omitEmptyIdx,
  274. }
  275. encodingStructTypeCache.Store(t, structType)
  276. return structType, structType.err
  277. }
  278. func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
  279. for i := 0; i < len(flds); i++ {
  280. // Get field's encodeFunc
  281. flds[i].ef, flds[i].ief, flds[i].izf = getEncodeFunc(flds[i].typ)
  282. if flds[i].ef == nil {
  283. structType := &encodingStructType{err: &UnsupportedTypeError{t}}
  284. encodingStructTypeCache.Store(t, structType)
  285. return structType, structType.err
  286. }
  287. }
  288. structType := &encodingStructType{
  289. fields: flds,
  290. toArray: true,
  291. }
  292. encodingStructTypeCache.Store(t, structType)
  293. return structType, structType.err
  294. }
  295. func getEncodeFunc(t reflect.Type) (encodeFunc, isEmptyFunc, isZeroFunc) {
  296. if v, _ := encodeFuncCache.Load(t); v != nil {
  297. fs := v.(encodeFuncs)
  298. return fs.ef, fs.ief, fs.izf
  299. }
  300. ef, ief, izf := getEncodeFuncInternal(t)
  301. encodeFuncCache.Store(t, encodeFuncs{ef, ief, izf})
  302. return ef, ief, izf
  303. }
  304. func getTypeInfo(t reflect.Type) *typeInfo {
  305. if v, _ := typeInfoCache.Load(t); v != nil {
  306. return v.(*typeInfo)
  307. }
  308. tInfo := newTypeInfo(t)
  309. typeInfoCache.Store(t, tInfo)
  310. return tInfo
  311. }
  // hasToArrayOption reports whether "toarray" is one of the comma-separated
// options in a struct tag value. The first comma-separated element is the name
// and is never treated as an option.
//
// Every option is checked. The previous approach examined only the first
// occurrence of ",toarray" and so missed tags like ",toarrayx,toarray".
func hasToArrayOption(tag string) bool {
	_, opts, more := strings.Cut(tag, ",")
	for more {
		var opt string
		opt, opts, more = strings.Cut(opts, ",")
		if opt == "toarray" {
			return true
		}
	}
	return false
}