union.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. /*
  2. Copyright 2024 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package validate
  14. import (
  15. "context"
  16. "fmt"
  17. "strings"
  18. "k8s.io/apimachinery/pkg/api/operation"
  19. "k8s.io/apimachinery/pkg/util/validation/field"
  20. )
  21. // ExtractorFn extracts a value from a parent object. Depending on the context,
  22. // that could be the value of a field or just whether that field was set or
  23. // not.
  24. // Note: obj is not guaranteed to be non-nil, need to handle nil obj in the
  25. // extractor.
  26. type ExtractorFn[T, V any] func(obj T) V
  27. // UnionValidationOptions configures how union validation behaves
  28. type UnionValidationOptions struct {
  29. // ErrorForEmpty returns error when no fields are set (nil means no error)
  30. ErrorForEmpty func(fldPath *field.Path, allFields []string) *field.Error
  31. // ErrorForMultiple returns error when multiple fields are set (nil means no error)
  32. ErrorForMultiple func(fldPath *field.Path, specifiedFields []string, allFields []string) *field.Error
  33. }
  34. // Union verifies that exactly one member of a union is specified.
  35. //
  36. // UnionMembership must define all the members of the union.
  37. //
  38. // For example:
  39. //
  40. // var UnionMembershipForABC := validate.NewUnionMembership(
  41. // validate.NewUnionMember("a"),
  42. // validate.NewUnionMember("b"),
  43. // validate.NewUnionMember("c"),
  44. // )
  45. // func ValidateABC(ctx context.Context, op operation.Operation, fldPath *field.Path, in *ABC) (errs field.ErrorList) {
  46. // errs = append(errs, Union(ctx, op, fldPath, in, oldIn, UnionMembershipForABC,
  47. // func(in *ABC) bool { return in.A != nil },
  48. // func(in *ABC) bool { return in.B != "" },
  49. // func(in *ABC) bool { return in.C != 0 },
  50. // )...)
  51. // return errs
  52. // }
  53. func Union[T any](_ context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj T, union *UnionMembership, isSetFns ...ExtractorFn[T, bool]) field.ErrorList {
  54. options := UnionValidationOptions{
  55. ErrorForEmpty: func(fldPath *field.Path, allFields []string) *field.Error {
  56. return field.Invalid(fldPath, "",
  57. fmt.Sprintf("must specify one of: %s", strings.Join(allFields, ", ")))
  58. },
  59. ErrorForMultiple: func(fldPath *field.Path, specifiedFields []string, allFields []string) *field.Error {
  60. return field.Invalid(fldPath, fmt.Sprintf("{%s}", strings.Join(specifiedFields, ", ")),
  61. fmt.Sprintf("must specify exactly one of: %s", strings.Join(allFields, ", ")))
  62. },
  63. }
  64. return unionValidate(op, fldPath, obj, oldObj, union, options, isSetFns...)
  65. }
  66. // DiscriminatedUnion verifies specified union member matches the discriminator.
  67. //
  68. // UnionMembership must define all the members of the union and the discriminator.
  69. //
  70. // For example:
  71. //
  72. // var UnionMembershipForABC = validate.NewDiscriminatedUnionMembership("type",
  73. // validate.NewDiscriminatedUnionMember("a", "A"),
  74. // validate.NewDiscriminatedUnionMember("b", "B"),
  75. // validate.NewDiscriminatedUnionMember("c", "C"),
  76. // )
  77. // func ValidateABC(ctx context.Context, op operation.Operation, fldPath *field.Path, in *ABC) (errs field.ErrorList) {
  78. // errs = append(errs, DiscriminatedUnion(ctx, op, fldPath, in, oldIn, UnionMembershipForABC,
  79. // func(in *ABC) string { return string(in.Type) },
  80. // func(in *ABC) bool { return in.A != nil },
  81. // func(in *ABC) bool { return in.B != "" },
  82. // func(in *ABC) bool { return in.C != 0 },
  83. // )...)
  84. // return errs
  85. // }
  86. //
  87. // It is not an error for the discriminatorValue to be unknown. That must be
  88. // validated on its own.
  89. func DiscriminatedUnion[T any, D ~string](_ context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj T, union *UnionMembership, discriminatorExtractor ExtractorFn[T, D], isSetFns ...ExtractorFn[T, bool]) (errs field.ErrorList) {
  90. if len(union.members) != len(isSetFns) {
  91. return field.ErrorList{
  92. field.InternalError(fldPath,
  93. fmt.Errorf("number of extractors (%d) does not match number of union members (%d)",
  94. len(isSetFns), len(union.members))),
  95. }
  96. }
  97. var changed bool
  98. discriminatorValue := discriminatorExtractor(obj)
  99. if op.Type == operation.Update {
  100. oldDiscriminatorValue := discriminatorExtractor(oldObj)
  101. changed = discriminatorValue != oldDiscriminatorValue
  102. }
  103. for i, fieldIsSet := range isSetFns {
  104. member := union.members[i]
  105. isDiscriminatedMember := string(discriminatorValue) == member.discriminatorValue
  106. newIsSet := fieldIsSet(obj)
  107. if op.Type == operation.Update && !changed {
  108. oldIsSet := fieldIsSet(oldObj)
  109. changed = changed || newIsSet != oldIsSet
  110. }
  111. if newIsSet && !isDiscriminatedMember {
  112. errs = append(errs, field.Invalid(fldPath.Child(member.fieldName), "",
  113. fmt.Sprintf("may only be specified when `%s` is %q", union.discriminatorName, member.discriminatorValue)))
  114. } else if !newIsSet && isDiscriminatedMember {
  115. errs = append(errs, field.Invalid(fldPath.Child(member.fieldName), "",
  116. fmt.Sprintf("must be specified when `%s` is %q", union.discriminatorName, discriminatorValue)))
  117. }
  118. }
  119. // If the union discriminator and membership is unchanged, we don't need to
  120. // re-validate.
  121. if op.Type == operation.Update && !changed {
  122. return nil
  123. }
  124. return errs
  125. }
  126. // UnionMember represents a member of a union.
  127. type UnionMember struct {
  128. fieldName string
  129. discriminatorValue string
  130. }
  131. // NewUnionMember returns a new UnionMember for the given field name.
  132. func NewUnionMember(fieldName string) UnionMember {
  133. return UnionMember{fieldName: fieldName}
  134. }
  135. // NewDiscriminatedUnionMember returns a new UnionMember for the given field
  136. // name and discriminator value.
  137. func NewDiscriminatedUnionMember(fieldName, discriminatorValue string) UnionMember {
  138. return UnionMember{fieldName: fieldName, discriminatorValue: discriminatorValue}
  139. }
  140. // UnionMembership represents an ordered list of field union memberships.
  141. type UnionMembership struct {
  142. discriminatorName string
  143. members []UnionMember
  144. }
  145. // NewUnionMembership returns a new UnionMembership for the given list of members.
  146. // Member names must be unique.
  147. func NewUnionMembership(member ...UnionMember) *UnionMembership {
  148. return NewDiscriminatedUnionMembership("", member...)
  149. }
  150. // NewDiscriminatedUnionMembership returns a new UnionMembership for the given discriminator field and list of members.
  151. // members are provided in the same way as for NewUnionMembership.
  152. func NewDiscriminatedUnionMembership(discriminatorFieldName string, members ...UnionMember) *UnionMembership {
  153. return &UnionMembership{
  154. discriminatorName: discriminatorFieldName,
  155. members: members,
  156. }
  157. }
  158. // allFields returns a string listing all the field names of the member of a union for use in error reporting.
  159. func (u UnionMembership) allFields() []string {
  160. memberNames := make([]string, 0, len(u.members))
  161. for _, f := range u.members {
  162. memberNames = append(memberNames, fmt.Sprintf("`%s`", f.fieldName))
  163. }
  164. return memberNames
  165. }
  166. func unionValidate[T any](op operation.Operation, fldPath *field.Path,
  167. obj, oldObj T, union *UnionMembership, options UnionValidationOptions, isSetFns ...ExtractorFn[T, bool],
  168. ) field.ErrorList {
  169. if len(union.members) != len(isSetFns) {
  170. return field.ErrorList{
  171. field.InternalError(fldPath,
  172. fmt.Errorf("number of extractors (%d) does not match number of union members (%d)",
  173. len(isSetFns), len(union.members))),
  174. }
  175. }
  176. var specifiedFields []string
  177. var changed bool
  178. for i, fieldIsSet := range isSetFns {
  179. newIsSet := fieldIsSet(obj)
  180. if op.Type == operation.Update && !changed {
  181. oldIsSet := fieldIsSet(oldObj)
  182. changed = changed || newIsSet != oldIsSet
  183. }
  184. if newIsSet {
  185. specifiedFields = append(specifiedFields, union.members[i].fieldName)
  186. }
  187. }
  188. // If the union membership is unchanged, we don't need to re-validate.
  189. if op.Type == operation.Update && !changed {
  190. return nil
  191. }
  192. var errs field.ErrorList
  193. if len(specifiedFields) > 1 && options.ErrorForMultiple != nil {
  194. errs = append(errs, options.ErrorForMultiple(fldPath, specifiedFields, union.allFields()))
  195. }
  196. if len(specifiedFields) == 0 && options.ErrorForEmpty != nil {
  197. errs = append(errs, options.ErrorForEmpty(fldPath, union.allFields()))
  198. }
  199. return errs
  200. }