generate.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. // +build codegen
  2. package main
  3. import (
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "net/url"
  8. "os"
  9. "os/exec"
  10. "reflect"
  11. "regexp"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "text/template"
  16. "github.com/aws/aws-sdk-go/private/model/api"
  17. "github.com/aws/aws-sdk-go/private/util"
  18. )
  19. // TestSuiteTypeInput input test
  20. // TestSuiteTypeInput output test
  21. const (
  22. TestSuiteTypeInput = iota
  23. TestSuiteTypeOutput
  24. )
  25. type testSuite struct {
  26. *api.API
  27. Description string
  28. ClientEndpoint string
  29. Cases []testCase
  30. Type uint
  31. title string
  32. }
  33. func (s *testSuite) UnmarshalJSON(p []byte) error {
  34. type stub testSuite
  35. var v stub
  36. if err := json.Unmarshal(p, &v); err != nil {
  37. return err
  38. }
  39. if len(v.ClientEndpoint) == 0 {
  40. v.ClientEndpoint = "https://test"
  41. }
  42. for i := 0; i < len(v.Cases); i++ {
  43. if len(v.Cases[i].InputTest.Host) == 0 {
  44. v.Cases[i].InputTest.Host = "test"
  45. }
  46. if len(v.Cases[i].InputTest.URI) == 0 {
  47. v.Cases[i].InputTest.URI = "/"
  48. }
  49. }
  50. *s = testSuite(v)
  51. return nil
  52. }
  53. type testCase struct {
  54. TestSuite *testSuite
  55. Given *api.Operation
  56. Params interface{} `json:",omitempty"`
  57. Data interface{} `json:"result,omitempty"`
  58. InputTest testExpectation `json:"serialized"`
  59. OutputTest testExpectation `json:"response"`
  60. }
  61. type testExpectation struct {
  62. Body string
  63. Host string
  64. URI string
  65. Headers map[string]string
  66. JSONValues map[string]string
  67. StatusCode uint `json:"status_code"`
  68. }
  69. const preamble = `
  70. var _ bytes.Buffer // always import bytes
  71. var _ http.Request
  72. var _ json.Marshaler
  73. var _ time.Time
  74. var _ xmlutil.XMLNode
  75. var _ xml.Attr
  76. var _ = ioutil.Discard
  77. var _ = util.Trim("")
  78. var _ = url.Values{}
  79. var _ = io.EOF
  80. var _ = aws.String
  81. var _ = fmt.Println
  82. var _ = reflect.Value{}
  83. func init() {
  84. protocol.RandReader = &awstesting.ZeroReader{}
  85. }
  86. `
  87. var reStripSpace = regexp.MustCompile(`\s(\w)`)
  88. var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
  89. func removeImports(code string) string {
  90. return reImportRemoval.ReplaceAllString(code, "")
  91. }
  92. var extraImports = []string{
  93. "bytes",
  94. "encoding/json",
  95. "encoding/xml",
  96. "fmt",
  97. "io",
  98. "io/ioutil",
  99. "net/http",
  100. "testing",
  101. "time",
  102. "reflect",
  103. "net/url",
  104. "",
  105. "github.com/aws/aws-sdk-go/awstesting",
  106. "github.com/aws/aws-sdk-go/awstesting/unit",
  107. "github.com/aws/aws-sdk-go/private/protocol",
  108. "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
  109. "github.com/aws/aws-sdk-go/private/util",
  110. }
  111. func addImports(code string) string {
  112. importNames := make([]string, len(extraImports))
  113. for i, n := range extraImports {
  114. if n != "" {
  115. importNames[i] = fmt.Sprintf("%q", n)
  116. }
  117. }
  118. str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)")
  119. return str
  120. }
  121. func (t *testSuite) TestSuite() string {
  122. var buf bytes.Buffer
  123. t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
  124. return strings.ToUpper(x[1:])
  125. })
  126. t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
  127. for idx, c := range t.Cases {
  128. c.TestSuite = t
  129. buf.WriteString(c.TestCase(idx) + "\n")
  130. }
  131. return buf.String()
  132. }
  133. var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
  134. func Test{{ .OpName }}(t *testing.T) {
  135. svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("{{ .TestCase.TestSuite.ClientEndpoint }}")})
  136. {{ if ne .ParamsString "" }}input := {{ .ParamsString }}
  137. {{ range $k, $v := .JSONValues -}}
  138. input.{{ $k }} = {{ $v }}
  139. {{ end -}}
  140. req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
  141. r := req.HTTPRequest
  142. // build request
  143. req.Build()
  144. if req.Error != nil {
  145. t.Errorf("expect no error, got %v", req.Error)
  146. }
  147. {{ if ne .TestCase.InputTest.Body "" }}// assert body
  148. if r.Body == nil {
  149. t.Errorf("expect body not to be nil")
  150. }
  151. {{ .BodyAssertions }}{{ end }}
  152. // assert URL
  153. awstesting.AssertURL(t, "https://{{ .TestCase.InputTest.Host }}{{ .TestCase.InputTest.URI }}", r.URL.String())
  154. // assert headers
  155. {{ range $k, $v := .TestCase.InputTest.Headers -}}
  156. if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a {
  157. t.Errorf("expect %v to be %v", e, a)
  158. }
  159. {{ end }}
  160. }
  161. `))
  162. type tplInputTestCaseData struct {
  163. TestCase *testCase
  164. JSONValues map[string]string
  165. OpName, ParamsString string
  166. }
  167. func (t tplInputTestCaseData) BodyAssertions() string {
  168. code := &bytes.Buffer{}
  169. protocol := t.TestCase.TestSuite.API.Metadata.Protocol
  170. // Extract the body bytes
  171. switch protocol {
  172. case "rest-xml":
  173. fmt.Fprintln(code, "body := util.SortXML(r.Body)")
  174. default:
  175. fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
  176. }
  177. // Generate the body verification code
  178. expectedBody := util.Trim(t.TestCase.InputTest.Body)
  179. switch protocol {
  180. case "ec2", "query":
  181. fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
  182. expectedBody)
  183. case "rest-xml":
  184. if strings.HasPrefix(expectedBody, "<") {
  185. fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})",
  186. expectedBody, t.TestCase.Given.InputRef.ShapeName)
  187. } else {
  188. code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))"))
  189. }
  190. case "json", "jsonrpc", "rest-json":
  191. if strings.HasPrefix(expectedBody, "{") {
  192. fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
  193. expectedBody)
  194. } else {
  195. code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))"))
  196. }
  197. default:
  198. code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))"))
  199. }
  200. return code.String()
  201. }
  202. func fmtAssertEqual(e, a string) string {
  203. const format = `if e, a := %s, %s; e != a {
  204. t.Errorf("expect %%v, got %%v", e, a)
  205. }
  206. `
  207. return fmt.Sprintf(format, e, a)
  208. }
  209. func fmtAssertNil(v string) string {
  210. const format = `if e := %s; e != nil {
  211. t.Errorf("expect nil, got %%v", e)
  212. }
  213. `
  214. return fmt.Sprintf(format, v)
  215. }
  216. var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
  217. func Test{{ .OpName }}(t *testing.T) {
  218. svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
  219. buf := bytes.NewReader([]byte({{ .Body }}))
  220. req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
  221. req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
  222. // set headers
  223. {{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
  224. {{ end }}
  225. // unmarshal response
  226. req.Handlers.UnmarshalMeta.Run(req)
  227. req.Handlers.Unmarshal.Run(req)
  228. if req.Error != nil {
  229. t.Errorf("expect not error, got %v", req.Error)
  230. }
  231. // assert response
  232. if out == nil {
  233. t.Errorf("expect not to be nil")
  234. }
  235. {{ .Assertions }}
  236. }
  237. `))
  238. type tplOutputTestCaseData struct {
  239. TestCase *testCase
  240. Body, OpName, Assertions string
  241. }
  242. func (i *testCase) TestCase(idx int) string {
  243. var buf bytes.Buffer
  244. opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
  245. if i.TestSuite.Type == TestSuiteTypeInput { // input test
  246. // query test should sort body as form encoded values
  247. switch i.TestSuite.API.Metadata.Protocol {
  248. case "query", "ec2":
  249. m, _ := url.ParseQuery(i.InputTest.Body)
  250. i.InputTest.Body = m.Encode()
  251. case "rest-xml":
  252. i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
  253. case "json", "rest-json":
  254. // Nothing to do
  255. }
  256. jsonValues := buildJSONValues(i.Given.InputRef.Shape)
  257. var params interface{}
  258. if m, ok := i.Params.(map[string]interface{}); ok {
  259. paramsMap := map[string]interface{}{}
  260. for k, v := range m {
  261. if _, ok := jsonValues[k]; !ok {
  262. paramsMap[k] = v
  263. } else {
  264. if i.InputTest.JSONValues == nil {
  265. i.InputTest.JSONValues = map[string]string{}
  266. }
  267. i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{}))
  268. }
  269. }
  270. params = paramsMap
  271. } else {
  272. params = i.Params
  273. }
  274. input := tplInputTestCaseData{
  275. TestCase: i,
  276. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  277. ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false),
  278. JSONValues: i.InputTest.JSONValues,
  279. }
  280. if err := tplInputTestCase.Execute(&buf, input); err != nil {
  281. panic(err)
  282. }
  283. } else if i.TestSuite.Type == TestSuiteTypeOutput {
  284. output := tplOutputTestCaseData{
  285. TestCase: i,
  286. Body: fmt.Sprintf("%q", i.OutputTest.Body),
  287. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  288. Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
  289. }
  290. if err := tplOutputTestCase.Execute(&buf, output); err != nil {
  291. panic(err)
  292. }
  293. }
  294. return buf.String()
  295. }
  296. func serializeJSONValue(m map[string]interface{}) string {
  297. str := "aws.JSONValue"
  298. str += walkMap(m)
  299. return str
  300. }
  301. func walkMap(m map[string]interface{}) string {
  302. str := "{"
  303. for k, v := range m {
  304. str += fmt.Sprintf("%q:", k)
  305. switch v.(type) {
  306. case bool:
  307. str += fmt.Sprintf("%t,\n", v.(bool))
  308. case string:
  309. str += fmt.Sprintf("%q,\n", v.(string))
  310. case int:
  311. str += fmt.Sprintf("%d,\n", v.(int))
  312. case float64:
  313. str += fmt.Sprintf("%f,\n", v.(float64))
  314. case map[string]interface{}:
  315. str += walkMap(v.(map[string]interface{}))
  316. }
  317. }
  318. str += "}"
  319. return str
  320. }
  321. func buildJSONValues(shape *api.Shape) map[string]struct{} {
  322. keys := map[string]struct{}{}
  323. for key, field := range shape.MemberRefs {
  324. if field.JSONValue {
  325. keys[key] = struct{}{}
  326. }
  327. }
  328. return keys
  329. }
  330. // generateTestSuite generates a protocol test suite for a given configuration
  331. // JSON protocol test file.
  332. func generateTestSuite(filename string) string {
  333. inout := "Input"
  334. if strings.Contains(filename, "output/") {
  335. inout = "Output"
  336. }
  337. var suites []testSuite
  338. f, err := os.Open(filename)
  339. if err != nil {
  340. panic(err)
  341. }
  342. err = json.NewDecoder(f).Decode(&suites)
  343. if err != nil {
  344. panic(err)
  345. }
  346. var buf bytes.Buffer
  347. buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
  348. var innerBuf bytes.Buffer
  349. innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
  350. for i, suite := range suites {
  351. svcPrefix := inout + "Service" + strconv.Itoa(i+1)
  352. suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
  353. suite.API.Operations = map[string]*api.Operation{}
  354. for idx, c := range suite.Cases {
  355. c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
  356. suite.API.Operations[c.Given.ExportedName] = c.Given
  357. }
  358. suite.Type = getType(inout)
  359. suite.API.NoInitMethods = true // don't generate init methods
  360. suite.API.NoStringerMethods = true // don't generate stringer methods
  361. suite.API.NoConstServiceNames = true // don't generate service names
  362. suite.API.Setup()
  363. suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
  364. suite.API.Metadata.EndpointsID = suite.API.Metadata.EndpointPrefix
  365. // Sort in order for deterministic test generation
  366. names := make([]string, 0, len(suite.API.Shapes))
  367. for n := range suite.API.Shapes {
  368. names = append(names, n)
  369. }
  370. sort.Strings(names)
  371. for _, name := range names {
  372. s := suite.API.Shapes[name]
  373. s.Rename(svcPrefix + "TestShape" + name)
  374. }
  375. svcCode := addImports(suite.API.ServiceGoCode())
  376. if i == 0 {
  377. importMatch := reImportRemoval.FindStringSubmatch(svcCode)
  378. buf.WriteString(importMatch[0] + "\n\n")
  379. buf.WriteString(preamble + "\n\n")
  380. }
  381. svcCode = removeImports(svcCode)
  382. svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
  383. svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
  384. svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
  385. buf.WriteString(svcCode + "\n\n")
  386. apiCode := removeImports(suite.API.APIGoCode())
  387. apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
  388. apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
  389. apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
  390. buf.WriteString(apiCode + "\n\n")
  391. innerBuf.WriteString(suite.TestSuite() + "\n")
  392. }
  393. return buf.String() + innerBuf.String()
  394. }
  395. // findMember searches the shape for the member with the matching key name.
  396. func findMember(shape *api.Shape, key string) string {
  397. for actualKey := range shape.MemberRefs {
  398. if strings.ToLower(key) == strings.ToLower(actualKey) {
  399. return actualKey
  400. }
  401. }
  402. return ""
  403. }
  404. // GenerateAssertions builds assertions for a shape based on its type.
  405. //
  406. // The shape's recursive values also will have assertions generated for them.
  407. func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
  408. if shape == nil {
  409. return ""
  410. }
  411. switch t := out.(type) {
  412. case map[string]interface{}:
  413. keys := util.SortedKeys(t)
  414. code := ""
  415. if shape.Type == "map" {
  416. for _, k := range keys {
  417. v := t[k]
  418. s := shape.ValueRef.Shape
  419. code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]")
  420. }
  421. } else if shape.Type == "jsonvalue" {
  422. code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)\n", prefix, walkMap(out.(map[string]interface{})))
  423. } else {
  424. for _, k := range keys {
  425. v := t[k]
  426. m := findMember(shape, k)
  427. s := shape.MemberRefs[m].Shape
  428. code += GenerateAssertions(v, s, prefix+"."+m+"")
  429. }
  430. }
  431. return code
  432. case []interface{}:
  433. code := ""
  434. for i, v := range t {
  435. s := shape.MemberRef.Shape
  436. code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
  437. }
  438. return code
  439. default:
  440. switch shape.Type {
  441. case "timestamp":
  442. return fmtAssertEqual(
  443. fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out),
  444. fmt.Sprintf("%s.UTC().String()", prefix),
  445. )
  446. case "blob":
  447. return fmtAssertEqual(
  448. fmt.Sprintf("%#v", out),
  449. fmt.Sprintf("string(%s)", prefix),
  450. )
  451. case "integer", "long":
  452. return fmtAssertEqual(
  453. fmt.Sprintf("int64(%#v)", out),
  454. fmt.Sprintf("*%s", prefix),
  455. )
  456. default:
  457. if !reflect.ValueOf(out).IsValid() {
  458. return fmtAssertNil(prefix)
  459. }
  460. return fmtAssertEqual(
  461. fmt.Sprintf("%#v", out),
  462. fmt.Sprintf("*%s", prefix),
  463. )
  464. }
  465. }
  466. }
  467. func getType(t string) uint {
  468. switch t {
  469. case "Input":
  470. return TestSuiteTypeInput
  471. case "Output":
  472. return TestSuiteTypeOutput
  473. default:
  474. panic("Invalid type for test suite")
  475. }
  476. }
  477. func main() {
  478. if len(os.Getenv("AWS_SDK_CODEGEN_DEBUG")) != 0 {
  479. api.LogDebug(os.Stdout)
  480. }
  481. fmt.Println("Generating test suite", os.Args[1:])
  482. out := generateTestSuite(os.Args[1])
  483. if len(os.Args) == 3 {
  484. f, err := os.Create(os.Args[2])
  485. defer f.Close()
  486. if err != nil {
  487. panic(err)
  488. }
  489. f.WriteString(util.GoFmt(out))
  490. f.Close()
  491. c := exec.Command("gofmt", "-s", "-w", os.Args[2])
  492. if err := c.Run(); err != nil {
  493. panic(err)
  494. }
  495. } else {
  496. fmt.Println(out)
  497. }
  498. }