example_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. // +build example,exclude
  2. package rdsutils_test
  3. import (
  4. "crypto/tls"
  5. "crypto/x509"
  6. "database/sql"
  7. "flag"
  8. "fmt"
  9. "io/ioutil"
  10. "net/http"
  11. "net/url"
  12. "os"
  13. "github.com/go-sql-driver/mysql"
  14. "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
  15. "github.com/aws/aws-sdk-go/aws/session"
  16. "github.com/aws/aws-sdk-go/service/rds/rdsutils"
  17. )
  18. // ExampleConnectionStringBuilder contains usage of assuming a role and using
  19. // that to build the auth token.
  20. // Usage:
  21. // ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306
  22. func ExampleConnectionStringBuilder() {
  23. userPtr := flag.String("user", "", "user of the credentials")
  24. regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds")
  25. roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds")
  26. endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to")
  27. portPtr := flag.Int("port", 3306, "DB port to be connected to")
  28. tablePtr := flag.String("table", "test_table", "DB table to query against")
  29. dbNamePtr := flag.String("dbname", "", "DB name to query against")
  30. flag.Parse()
  31. // Check required flags. Will exit with status code 1 if
  32. // required field isn't set.
  33. if err := requiredFlags(
  34. userPtr,
  35. regionPtr,
  36. roleArnPtr,
  37. endpointPtr,
  38. portPtr,
  39. dbNamePtr,
  40. ); err != nil {
  41. fmt.Printf("Error: %v\n\n", err)
  42. flag.PrintDefaults()
  43. os.Exit(1)
  44. }
  45. err := registerRDSMysqlCerts(http.DefaultClient)
  46. if err != nil {
  47. panic(err)
  48. }
  49. sess := session.Must(session.NewSession())
  50. creds := stscreds.NewCredentials(sess, *roleArnPtr)
  51. v := url.Values{}
  52. // required fields for DB connection
  53. v.Add("tls", "rds")
  54. v.Add("allowCleartextPasswords", "true")
  55. endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr)
  56. b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds)
  57. connectStr, err := b.WithTCPFormat().WithParams(v).Build()
  58. const dbType = "mysql"
  59. db, err := sql.Open(dbType, connectStr)
  60. // if an error is encountered here, then most likely security groups are incorrect
  61. // in the database.
  62. if err != nil {
  63. panic(fmt.Errorf("failed to open connection to the database"))
  64. }
  65. rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr))
  66. if err != nil {
  67. panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err))
  68. }
  69. for rows.Next() {
  70. columns, err := rows.Columns()
  71. if err != nil {
  72. panic(fmt.Errorf("failed to read columns from row: %v", err))
  73. }
  74. fmt.Printf("rows colums:\n%d\n", len(columns))
  75. }
  76. }
  // requiredFlags reports an error if any of the given flag values was left unset.
//
// Each argument is expected to be a pointer returned by flag.String or
// flag.Int (a nil interface is also accepted). A flag counts as unset when the
// interface is nil, the pointer is nil, or the pointee still holds its zero
// value ("" for *string, 0 for *int). Values of any other type are treated as
// set.
//
// The previous implementation only matched a nil interface. flag.String and
// flag.Int never return nil, so that check could never fail.
func requiredFlags(flags ...interface{}) error {
	for _, f := range flags {
		unset := false
		switch v := f.(type) {
		case nil:
			unset = true
		case *string:
			unset = v == nil || *v == ""
		case *int:
			unset = v == nil || *v == 0
		}
		if unset {
			return fmt.Errorf("one or more required flags were not set")
		}
	}
	return nil
}
  86. func registerRDSMysqlCerts(c *http.Client) error {
  87. resp, err := c.Get("https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem")
  88. if err != nil {
  89. return err
  90. }
  91. pem, err := ioutil.ReadAll(resp.Body)
  92. if err != nil {
  93. return err
  94. }
  95. rootCertPool := x509.NewCertPool()
  96. if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
  97. return fmt.Errorf("failed to append cert to cert pool!")
  98. }
  99. return mysql.RegisterTLSConfig("rds", &tls.Config{RootCAs: rootCertPool, InsecureSkipVerify: true})
  100. }