builder.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package rdsutils
  2. import (
  3. "fmt"
  4. "net/url"
  5. "github.com/aws/aws-sdk-go/aws/awserr"
  6. "github.com/aws/aws-sdk-go/aws/credentials"
  7. )
  8. // ConnectionFormat is the type of connection that will be
  9. // used to connect to the database
  10. type ConnectionFormat string
  11. // ConnectionFormat enums
  12. const (
  13. NoConnectionFormat ConnectionFormat = ""
  14. TCPFormat ConnectionFormat = "tcp"
  15. )
  16. // ErrNoConnectionFormat will be returned during build if no format had been
  17. // specified
  18. var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil)
  19. // ConnectionStringBuilder is a builder that will construct a connection
  20. // string with the provided parameters. params field is required to have
  21. // a tls specification and allowCleartextPasswords must be set to true.
  22. type ConnectionStringBuilder struct {
  23. dbName string
  24. endpoint string
  25. region string
  26. user string
  27. creds *credentials.Credentials
  28. connectFormat ConnectionFormat
  29. params url.Values
  30. }
  31. // NewConnectionStringBuilder will return an ConnectionStringBuilder
  32. func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, creds *credentials.Credentials) ConnectionStringBuilder {
  33. return ConnectionStringBuilder{
  34. dbName: dbName,
  35. endpoint: endpoint,
  36. region: region,
  37. user: dbUser,
  38. creds: creds,
  39. }
  40. }
  41. // WithEndpoint will return a builder with the given endpoint
  42. func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder {
  43. b.endpoint = endpoint
  44. return b
  45. }
  46. // WithRegion will return a builder with the given region
  47. func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder {
  48. b.region = region
  49. return b
  50. }
  51. // WithUser will return a builder with the given user
  52. func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder {
  53. b.user = user
  54. return b
  55. }
  56. // WithDBName will return a builder with the given database name
  57. func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder {
  58. b.dbName = dbName
  59. return b
  60. }
  61. // WithParams will return a builder with the given params. The parameters
  62. // will be included in the connection query string
  63. //
  64. // Example:
  65. // v := url.Values{}
  66. // v.Add("tls", "rds")
  67. // b := rdsutils.NewConnectionBuilder(endpoint, region, user, dbname, creds)
  68. // connectStr, err := b.WithParams(v).WithTCPFormat().Build()
  69. func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder {
  70. b.params = params
  71. return b
  72. }
  73. // WithFormat will return a builder with the given connection format
  74. func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder {
  75. b.connectFormat = f
  76. return b
  77. }
  78. // WithTCPFormat will set the format to TCP and return the modified builder
  79. func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
  80. return b.WithFormat(TCPFormat)
  81. }
  82. // Build will return a new connection string that can be used to open a connection
  83. // to the desired database.
  84. //
  85. // Example:
  86. // b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
  87. // connectStr, err := b.WithTCPFormat().Build()
  88. // if err != nil {
  89. // panic(err)
  90. // }
  91. // const dbType = "mysql"
  92. // db, err := sql.Open(dbType, connectStr)
  93. func (b ConnectionStringBuilder) Build() (string, error) {
  94. if b.connectFormat == NoConnectionFormat {
  95. return "", ErrNoConnectionFormat
  96. }
  97. authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.creds)
  98. if err != nil {
  99. return "", err
  100. }
  101. connectionStr := fmt.Sprintf("%s:%s@%s(%s)/%s",
  102. b.user, authToken, string(b.connectFormat), b.endpoint, b.dbName,
  103. )
  104. if len(b.params) > 0 {
  105. connectionStr = fmt.Sprintf("%s?%s", connectionStr, b.params.Encode())
  106. }
  107. return connectionStr, nil
  108. }