ironplans.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. // +build ee
  2. package billing
  3. import (
  4. "crypto/hmac"
  5. "crypto/sha256"
  6. "encoding/base64"
  7. "encoding/hex"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io/ioutil"
  12. "net/http"
  13. "net/url"
  14. "strings"
  15. "time"
  16. "github.com/porter-dev/porter/api/types"
  17. "github.com/porter-dev/porter/ee/models"
  18. "github.com/porter-dev/porter/ee/repository"
  19. "gorm.io/gorm"
  20. cemodels "github.com/porter-dev/porter/internal/models"
  21. )
  22. // Client contains an API client for IronPlans
  23. type Client struct {
  24. apiKey string
  25. serverURL string
  26. repo repository.EERepository
  27. httpClient *http.Client
  28. defaultPlan *Plan
  29. }
  30. // NewClient creates a new billing API client
  31. func NewClient(serverURL, apiKey string, repo repository.EERepository) (*Client, error) {
  32. httpClient := &http.Client{
  33. Timeout: time.Minute,
  34. }
  35. client := &Client{apiKey, serverURL, repo, httpClient, nil}
  36. // get the default plans from the IronPlans API server
  37. listResp := &ListPlansResponse{}
  38. err := client.getRequest("/plans/v1", listResp)
  39. if err != nil {
  40. return nil, err
  41. }
  42. for _, plan := range listResp.Results {
  43. if plan.Name == "Free" {
  44. copyPlan := plan
  45. client.defaultPlan = &copyPlan
  46. }
  47. }
  48. return client, nil
  49. }
  50. func (c *Client) CreateTeam(proj *cemodels.Project) (string, error) {
  51. resp := &Team{}
  52. err := c.postRequest("/teams/v1", &CreateTeamRequest{
  53. Name: proj.Name,
  54. }, resp)
  55. if err != nil {
  56. return "", err
  57. }
  58. // put the user on the free plan, as the default behavior, if there is a default plan
  59. if c.defaultPlan != nil {
  60. err := c.postRequest("/subscriptions/v1", &CreateSubscriptionRequest{
  61. PlanID: c.defaultPlan.ID,
  62. NextPlanID: c.defaultPlan.ID,
  63. TeamID: resp.ID,
  64. IsPaused: false,
  65. }, nil)
  66. if err != nil {
  67. return "", fmt.Errorf("subscription creation failed: %s", err)
  68. }
  69. }
  70. _, err = c.repo.ProjectBilling().CreateProjectBilling(&models.ProjectBilling{
  71. ProjectID: proj.ID,
  72. BillingTeamID: resp.ID,
  73. })
  74. if err != nil {
  75. return "", err
  76. }
  77. return resp.ID, err
  78. }
  79. func (c *Client) DeleteTeam(proj *cemodels.Project) error {
  80. projBilling, err := c.repo.ProjectBilling().ReadProjectBillingByProjectID(proj.ID)
  81. if err != nil {
  82. return err
  83. }
  84. return c.deleteRequest(fmt.Sprintf("/teams/v1/%s", projBilling.BillingTeamID), nil, nil)
  85. }
  86. func (c *Client) GetTeamID(proj *cemodels.Project) (teamID string, err error) {
  87. projBilling, err := c.repo.ProjectBilling().ReadProjectBillingByProjectID(proj.ID)
  88. if err != nil {
  89. return "", err
  90. }
  91. return projBilling.BillingTeamID, nil
  92. }
  93. func (c *Client) AddUserToTeam(teamID string, user *cemodels.User, role *cemodels.Role) error {
  94. roleEnum := RoleEnumMember
  95. // if user's role is admin, add them to the team as an owner
  96. if role.Kind == types.RoleAdmin {
  97. roleEnum = RoleEnumOwner
  98. }
  99. req := &AddTeammateRequest{
  100. TeamID: teamID,
  101. Role: roleEnum,
  102. Email: user.Email,
  103. SourceID: fmt.Sprintf("%d-%d", role.ProjectID, user.ID),
  104. }
  105. resp := &Teammate{}
  106. err := c.postRequest("/team_memberships/v1", req, resp)
  107. if err != nil {
  108. return err
  109. }
  110. _, err = c.repo.UserBilling().CreateUserBilling(&models.UserBilling{
  111. ProjectID: role.ProjectID,
  112. UserID: user.ID,
  113. TeammateID: resp.ID,
  114. Token: []byte(""),
  115. })
  116. return err
  117. }
  118. func (c *Client) UpdateUserInTeam(role *cemodels.Role) error {
  119. // get the user billing information to get the membership id
  120. userBilling, err := c.repo.UserBilling().ReadUserBilling(role.ProjectID, role.UserID)
  121. if err != nil {
  122. return err
  123. }
  124. roleEnum := RoleEnumMember
  125. // if user's role is admin, add them to the team as an owner
  126. if role.Kind == types.RoleAdmin {
  127. roleEnum = RoleEnumOwner
  128. }
  129. req := &UpdateTeammateRequest{
  130. Role: roleEnum,
  131. }
  132. resp := &Teammate{}
  133. return c.putRequest(fmt.Sprintf("/team_memberships/v1/%s", userBilling.TeammateID), req, resp)
  134. }
  135. func (c *Client) RemoveUserFromTeam(role *cemodels.Role) error {
  136. // get the user billing information to get the membership id
  137. userBilling, err := c.repo.UserBilling().ReadUserBilling(role.ProjectID, role.UserID)
  138. if err != nil {
  139. return err
  140. }
  141. return c.deleteRequest(fmt.Sprintf("/team_memberships/v1/%s", userBilling.TeammateID), nil, nil)
  142. }
  143. // GetIDToken gets an id token for a user in a project, creating the ID token if necessary
  144. func (c *Client) GetIDToken(proj *cemodels.Project, user *cemodels.User) (token string, teamID string, err error) {
  145. // attempt to get a team ID for the project
  146. teamID, err = c.GetTeamID(proj)
  147. // attempt to read the user billing data from the project
  148. userBilling, err := c.repo.UserBilling().ReadUserBilling(proj.ID, user.ID)
  149. notFound := errors.Is(err, gorm.ErrRecordNotFound)
  150. if !notFound && err != nil {
  151. return "", "", err
  152. }
  153. if !notFound {
  154. token = string(userBilling.Token)
  155. if token != "" {
  156. // check if the JWT token has expired
  157. isTokExpired := isExpired(token)
  158. // if JWT token has not expired, return the token
  159. if !isTokExpired {
  160. return token, teamID, nil
  161. }
  162. }
  163. }
  164. req := &CreateIDTokenRequest{
  165. Email: user.Email,
  166. UserID: fmt.Sprintf("%d-%d", proj.ID, user.ID),
  167. }
  168. resp := &CreateIDTokenResponse{}
  169. err = c.postRequest("/customers/v1/token", req, resp)
  170. if err != nil {
  171. return "", "", err
  172. }
  173. token = resp.Token
  174. if notFound {
  175. _, err := c.repo.UserBilling().CreateUserBilling(&models.UserBilling{
  176. ProjectID: proj.ID,
  177. UserID: user.ID,
  178. Token: []byte(token),
  179. })
  180. if err != nil {
  181. return "", "", err
  182. }
  183. } else {
  184. _, err := c.repo.UserBilling().UpdateUserBilling(&models.UserBilling{
  185. Model: &gorm.Model{
  186. ID: userBilling.ID,
  187. },
  188. ProjectID: proj.ID,
  189. UserID: user.ID,
  190. Token: []byte(token),
  191. TeammateID: userBilling.TeammateID,
  192. })
  193. if err != nil {
  194. return "", "", err
  195. }
  196. }
  197. return token, teamID, nil
  198. }
  199. // VerifySignature verifies a webhook signature based on hmac protocol
  200. // https://docs.ironplans.com/webhook-events/webhook-events
  201. func (c *Client) VerifySignature(signature string, body []byte) bool {
  202. if len(signature) != 71 || !strings.HasPrefix(signature, "sha256=") {
  203. return false
  204. }
  205. actual := make([]byte, 32)
  206. _, err := hex.Decode(actual, []byte(signature[7:]))
  207. if err != nil {
  208. return false
  209. }
  210. computed := hmac.New(sha256.New, []byte(c.apiKey))
  211. _, err = computed.Write(body)
  212. if err != nil {
  213. return false
  214. }
  215. return hmac.Equal(computed.Sum(nil), actual)
  216. }
  217. func (c *Client) postRequest(path string, data interface{}, dst interface{}) error {
  218. return c.writeRequest("POST", path, data, dst)
  219. }
  220. func (c *Client) putRequest(path string, data interface{}, dst interface{}) error {
  221. return c.writeRequest("PUT", path, data, dst)
  222. }
  223. func (c *Client) deleteRequest(path string, data interface{}, dst interface{}) error {
  224. return c.writeRequest("DELETE", path, data, dst)
  225. }
  226. func (c *Client) getRequest(path string, dst interface{}) error {
  227. reqURL, err := url.Parse(c.serverURL)
  228. if err != nil {
  229. return nil
  230. }
  231. reqURL.Path = path
  232. req, err := http.NewRequest(
  233. "GET",
  234. reqURL.String(),
  235. nil,
  236. )
  237. if err != nil {
  238. return err
  239. }
  240. req.Header.Set("Content-Type", "application/json; charset=utf-8")
  241. req.Header.Set("Accept", "application/json; charset=utf-8")
  242. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  243. res, err := c.httpClient.Do(req)
  244. if err != nil {
  245. return err
  246. }
  247. defer res.Body.Close()
  248. if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
  249. resBytes, err := ioutil.ReadAll(res.Body)
  250. if err != nil {
  251. return fmt.Errorf("request failed with status code %d, but could not read body (%s)\n", res.StatusCode, err.Error())
  252. }
  253. return fmt.Errorf("request failed with status code %d: %s\n", res.StatusCode, string(resBytes))
  254. }
  255. if dst != nil {
  256. return json.NewDecoder(res.Body).Decode(dst)
  257. }
  258. return nil
  259. }
  260. func (c *Client) writeRequest(method, path string, data interface{}, dst interface{}) error {
  261. reqURL, err := url.Parse(c.serverURL)
  262. if err != nil {
  263. return nil
  264. }
  265. reqURL.Path = path
  266. var strData []byte
  267. if data != nil {
  268. strData, err = json.Marshal(data)
  269. if err != nil {
  270. return err
  271. }
  272. }
  273. req, err := http.NewRequest(
  274. method,
  275. reqURL.String(),
  276. strings.NewReader(string(strData)),
  277. )
  278. if err != nil {
  279. return err
  280. }
  281. req.Header.Set("Content-Type", "application/json; charset=utf-8")
  282. req.Header.Set("Accept", "application/json; charset=utf-8")
  283. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  284. res, err := c.httpClient.Do(req)
  285. if err != nil {
  286. return err
  287. }
  288. defer res.Body.Close()
  289. if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
  290. resBytes, err := ioutil.ReadAll(res.Body)
  291. if err != nil {
  292. return fmt.Errorf("request failed with status code %d, but could not read body (%s)\n", res.StatusCode, err.Error())
  293. }
  294. return fmt.Errorf("request failed with status code %d: %s\n", res.StatusCode, string(resBytes))
  295. }
  296. if dst != nil {
  297. return json.NewDecoder(res.Body).Decode(dst)
  298. }
  299. return nil
  300. }
  // Slugs of the IronPlans plan features that ParseProjectUsageFromWebhook maps
  // onto project usage limits.
  const (
  	// FeatureSlugCPU identifies the CPU limit feature.
  	FeatureSlugCPU string = "cpu"

  	// FeatureSlugMemory identifies the memory limit feature.
  	FeatureSlugMemory string = "memory"

  	// FeatureSlugClusters identifies the cluster-count limit feature.
  	FeatureSlugClusters string = "clusters"

  	// FeatureSlugUsers identifies the user-count limit feature.
  	FeatureSlugUsers string = "users"
  )
  307. func (c *Client) ParseProjectUsageFromWebhook(payload []byte) (*cemodels.ProjectUsage, error) {
  308. subscription := &SubscriptionWebhookRequest{}
  309. err := json.Unmarshal(payload, subscription)
  310. if err != nil {
  311. return nil, err
  312. }
  313. // if event type is not subscription, return wrong webhook event type error
  314. if subscription.EventType != "subscription" {
  315. return nil, nil
  316. }
  317. // get the project id linked to that team
  318. projBilling, err := c.repo.ProjectBilling().ReadProjectBillingByTeamID(subscription.TeamID)
  319. if err != nil {
  320. return nil, err
  321. }
  322. usage := &cemodels.ProjectUsage{
  323. ProjectID: projBilling.ProjectID,
  324. }
  325. for _, feature := range subscription.Plan.Features {
  326. // look for slug of "cpus" and "memory"
  327. maxLimit := uint(feature.FeatureSpec.MaxLimit)
  328. switch feature.Feature.Slug {
  329. case FeatureSlugCPU:
  330. usage.ResourceCPU = maxLimit
  331. case FeatureSlugMemory:
  332. usage.ResourceMemory = 1000 * maxLimit
  333. case FeatureSlugClusters:
  334. usage.Clusters = maxLimit
  335. case FeatureSlugUsers:
  336. usage.Users = maxLimit
  337. }
  338. }
  339. return usage, nil
  340. }
  // expiryJWT captures the registered "exp" claim of a JWT payload: the
  // expiration time in seconds since the Unix epoch.
  type expiryJWT struct {
  	ExpiresAt int64 `json:"exp"`
  }

  // isExpired reports whether token, a JWT, has expired according to the "exp"
  // claim of its payload. The signature is NOT verified; this is only a cache
  // freshness check. A token that is not three dot-separated segments, whose
  // payload cannot be decoded, or that has no "exp" claim is treated as expired,
  // so the caller fetches a fresh one.
  func isExpired(token string) bool {
  	segments := strings.Split(token, ".")
  	if len(segments) != 3 {
  		return true
  	}

  	// JWT segments use the URL-safe base64 alphabet without padding
  	// (RFC 7515 §2). The standard alphabet rejects any payload containing '-'
  	// or '_', which previously made valid tokens look expired. Trailing '=' is
  	// trimmed to tolerate encoders that pad anyway.
  	payload, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(segments[1], "="))
  	if err != nil {
  		return true
  	}

  	claims := &expiryJWT{}
  	if err := json.Unmarshal(payload, claims); err != nil {
  		return true
  	}

  	// A missing "exp" decodes as 0 (the Unix epoch), which is always in the past.
  	return time.Unix(claims.ExpiresAt, 0).Before(time.Now())
  }