ironplans.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. // +build ee
  2. package billing
  3. import (
  4. "crypto/hmac"
  5. "crypto/sha256"
  6. "encoding/base64"
  7. "encoding/hex"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io/ioutil"
  12. "net/http"
  13. "net/url"
  14. "strings"
  15. "time"
  16. "github.com/porter-dev/porter/api/types"
  17. "github.com/porter-dev/porter/ee/models"
  18. "github.com/porter-dev/porter/ee/repository"
  19. "gorm.io/gorm"
  20. cemodels "github.com/porter-dev/porter/internal/models"
  21. )
  22. // Client contains an API client for IronPlans
  23. type Client struct {
  24. apiKey string
  25. serverURL string
  26. repo repository.EERepository
  27. httpClient *http.Client
  28. defaultPlanID string
  29. customPlanID string
  30. }
  31. // NewClient creates a new billing API client
  32. func NewClient(serverURL, apiKey string, repo repository.EERepository) (*Client, error) {
  33. httpClient := &http.Client{
  34. Timeout: time.Minute,
  35. }
  36. client := &Client{apiKey, serverURL, repo, httpClient, "", ""}
  37. // get the default plans from the IronPlans API server
  38. defPlanID, err := client.GetExistingPublicPlan("Free")
  39. if err != nil {
  40. return nil, err
  41. }
  42. customPlanID, err := client.GetExistingPublicPlan("Enterprise")
  43. if err != nil {
  44. return nil, err
  45. }
  46. client.defaultPlanID = defPlanID
  47. client.customPlanID = customPlanID
  48. return client, nil
  49. }
  50. func (c *Client) CreateTeam(proj *cemodels.Project) (string, error) {
  51. resp := &Team{}
  52. err := c.postRequest("/teams/v1", &CreateTeamRequest{
  53. Name: proj.Name,
  54. }, resp)
  55. if err != nil {
  56. return "", err
  57. }
  58. // put the user on the free plan, as the default behavior, if there is a default plan
  59. if c.defaultPlanID != "" {
  60. err = c.CreateOrUpdateSubscription(resp.ID, c.defaultPlanID)
  61. if err != nil {
  62. return "", fmt.Errorf("subscription creation failed: %s", err)
  63. }
  64. }
  65. _, err = c.repo.ProjectBilling().CreateProjectBilling(&models.ProjectBilling{
  66. ProjectID: proj.ID,
  67. BillingTeamID: resp.ID,
  68. })
  69. if err != nil {
  70. return "", err
  71. }
  72. return resp.ID, err
  73. }
  74. func (c *Client) DeleteTeam(proj *cemodels.Project) error {
  75. projBilling, err := c.repo.ProjectBilling().ReadProjectBillingByProjectID(proj.ID)
  76. if err != nil {
  77. return err
  78. }
  79. return c.deleteRequest(fmt.Sprintf("/teams/v1/%s", projBilling.BillingTeamID), nil, nil)
  80. }
  81. func (c *Client) GetTeamID(proj *cemodels.Project) (teamID string, err error) {
  82. projBilling, err := c.repo.ProjectBilling().ReadProjectBillingByProjectID(proj.ID)
  83. if err != nil {
  84. return "", err
  85. }
  86. return projBilling.BillingTeamID, nil
  87. }
  88. func (c *Client) CreatePlan(teamID string, proj *cemodels.Project, planSpec *types.AddProjectBillingRequest) (string, error) {
  89. // construct basic plan object
  90. planFeatures := make([]*CreatePlanFeature, 0)
  91. userDisplay := fmt.Sprintf("Up to %d users", planSpec.Users)
  92. if planSpec.Users == 0 {
  93. userDisplay = fmt.Sprintf("Unlimited users")
  94. }
  95. clusterDisplay := fmt.Sprintf("Up to %d clusters", planSpec.Clusters)
  96. if planSpec.Clusters == 0 {
  97. clusterDisplay = fmt.Sprintf("Unlimited clusters")
  98. }
  99. cpuDisplay := fmt.Sprintf("Up to %d CPUs", planSpec.CPU)
  100. if planSpec.CPU == 0 {
  101. cpuDisplay = fmt.Sprintf("Unlimited CPU")
  102. }
  103. ramDisplay := fmt.Sprintf("Up to %d GB RAM", planSpec.Memory)
  104. if planSpec.Memory == 0 {
  105. ramDisplay = fmt.Sprintf("Unlimited RAM")
  106. }
  107. planFeatures = append(planFeatures, &CreatePlanFeature{
  108. Display: userDisplay,
  109. })
  110. planFeatures = append(planFeatures, &CreatePlanFeature{
  111. Display: clusterDisplay,
  112. })
  113. planFeatures = append(planFeatures, &CreatePlanFeature{
  114. Display: cpuDisplay,
  115. })
  116. planFeatures = append(planFeatures, &CreatePlanFeature{
  117. Display: ramDisplay,
  118. })
  119. var customPlanID *string
  120. if c.customPlanID != "" {
  121. customPlanID = &c.customPlanID
  122. }
  123. createPlanReq := &CreatePlanRequest{
  124. Name: proj.Name,
  125. IsActive: true,
  126. IsPublic: false,
  127. IsTrialAllowed: true,
  128. ReplacePlanID: customPlanID,
  129. PerMonthPriceCents: planSpec.Price,
  130. PerYearPriceCents: 12 * planSpec.Price,
  131. Features: planFeatures,
  132. TeamsAccess: []*CreatePlanTeamsAccess{
  133. {
  134. TeamID: teamID,
  135. Revoke: false,
  136. },
  137. },
  138. }
  139. // find all relevant feature IDs
  140. listResp := &ListFeaturesResponse{}
  141. err := c.getRequest("/features/v1", listResp)
  142. if err != nil {
  143. return "", err
  144. }
  145. // create a feature spec per feature ID, and add to features array for plan
  146. for _, feature := range listResp.Results {
  147. featureSpec := &CreateFeatureSpecRequest{
  148. Name: "unnamed",
  149. RecordPeriod: "monthly",
  150. Aggregation: "sum",
  151. UnitPrice: 0,
  152. }
  153. switch feature.Slug {
  154. case FeatureSlugUsers:
  155. featureSpec.MaxLimit = planSpec.Users
  156. featureSpec.UnitsIncluded = planSpec.Users
  157. case FeatureSlugClusters:
  158. featureSpec.MaxLimit = planSpec.Clusters
  159. featureSpec.UnitsIncluded = planSpec.Clusters
  160. case FeatureSlugCPU:
  161. featureSpec.MaxLimit = planSpec.CPU
  162. featureSpec.UnitsIncluded = planSpec.CPU
  163. case FeatureSlugMemory:
  164. featureSpec.MaxLimit = planSpec.Memory
  165. featureSpec.UnitsIncluded = planSpec.Memory
  166. // continue on default behavior so that feature spec is not created for
  167. // features that don't match a slug
  168. default:
  169. continue
  170. }
  171. // create the feature spec
  172. resp := &CreateFeaturespecResponse{}
  173. err = c.postRequest("/featurespecs/v1/", featureSpec, resp)
  174. if err != nil {
  175. return "", err
  176. }
  177. var index int
  178. switch feature.Slug {
  179. case FeatureSlugUsers:
  180. index = 0
  181. case FeatureSlugClusters:
  182. index = 1
  183. case FeatureSlugCPU:
  184. index = 2
  185. case FeatureSlugMemory:
  186. index = 3
  187. }
  188. createPlanReq.Features[index].FeatureID = feature.ID
  189. createPlanReq.Features[index].SpecID = resp.ID
  190. }
  191. // create the plan and return the plan ID
  192. planResp := &Plan{}
  193. err = c.postRequest("/plans/v1/", createPlanReq, planResp)
  194. if err != nil {
  195. return "", err
  196. }
  197. return planResp.ID, nil
  198. }
  199. func (c *Client) CreateOrUpdateSubscription(teamID, planID string) error {
  200. // determine if subscription already exists by reading the team ID and seeing if the subscription
  201. // field has an ID attached
  202. teamResp := &Team{}
  203. err := c.getRequest(fmt.Sprintf("/teams/v1/%s", teamID), teamResp)
  204. if err != nil {
  205. return err
  206. }
  207. subReq := &CreateSubscriptionRequest{
  208. PlanID: planID,
  209. NextPlanID: c.defaultPlanID,
  210. TeamID: teamID,
  211. IsPaused: false,
  212. }
  213. // if subscription ID is not empty, perform a PUT request to update the subscription
  214. if teamResp.Subscription.ID != "" {
  215. // delete the subscription
  216. err = c.deleteRequest(fmt.Sprintf("/subscriptions/v1/%s/purge/", teamResp.Subscription.ID), nil, nil)
  217. if err != nil {
  218. return err
  219. }
  220. }
  221. return c.postRequest("/subscriptions/v1", subReq, nil)
  222. }
  223. func (c *Client) GetExistingPublicPlan(planName string) (string, error) {
  224. listResp := &ListPlansResponse{}
  225. err := c.getRequest("/plans/v1/", listResp, map[string]string{"is_public": "true"})
  226. if err != nil {
  227. return "", err
  228. }
  229. for _, plan := range listResp.Results {
  230. if plan.Name == planName {
  231. return plan.ID, nil
  232. }
  233. }
  234. return "", fmt.Errorf("plan not found")
  235. }
  236. func (c *Client) AddUserToTeam(teamID string, user *cemodels.User, role *cemodels.Role) error {
  237. // determine if user is already in team/has user billing
  238. userBilling, err := c.repo.UserBilling().ReadUserBilling(role.ProjectID, user.ID)
  239. if userBilling != nil {
  240. return nil
  241. }
  242. roleEnum := RoleEnumMember
  243. // if user's role is admin, add them to the team as an owner
  244. if role.Kind == types.RoleAdmin {
  245. roleEnum = RoleEnumOwner
  246. }
  247. req := &AddTeammateRequest{
  248. TeamID: teamID,
  249. Role: roleEnum,
  250. Email: user.Email,
  251. SourceID: fmt.Sprintf("%d-%d", role.ProjectID, user.ID),
  252. }
  253. resp := &Teammate{}
  254. err = c.postRequest("/team_memberships/v1", req, resp)
  255. if err != nil {
  256. return err
  257. }
  258. _, err = c.repo.UserBilling().CreateUserBilling(&models.UserBilling{
  259. ProjectID: role.ProjectID,
  260. UserID: user.ID,
  261. TeammateID: resp.ID,
  262. Token: []byte(""),
  263. })
  264. return err
  265. }
  266. func (c *Client) UpdateUserInTeam(role *cemodels.Role) error {
  267. // get the user billing information to get the membership id
  268. userBilling, err := c.repo.UserBilling().ReadUserBilling(role.ProjectID, role.UserID)
  269. if err != nil {
  270. return err
  271. }
  272. roleEnum := RoleEnumMember
  273. // if user's role is admin, add them to the team as an owner
  274. if role.Kind == types.RoleAdmin {
  275. roleEnum = RoleEnumOwner
  276. }
  277. req := &UpdateTeammateRequest{
  278. Role: roleEnum,
  279. }
  280. resp := &Teammate{}
  281. return c.putRequest(fmt.Sprintf("/team_memberships/v1/%s", userBilling.TeammateID), req, resp)
  282. }
  283. func (c *Client) RemoveUserFromTeam(role *cemodels.Role) error {
  284. // get the user billing information to get the membership id
  285. userBilling, err := c.repo.UserBilling().ReadUserBilling(role.ProjectID, role.UserID)
  286. if err != nil {
  287. return err
  288. }
  289. return c.deleteRequest(fmt.Sprintf("/team_memberships/v1/%s", userBilling.TeammateID), nil, nil)
  290. }
  291. // GetIDToken gets an id token for a user in a project, creating the ID token if necessary
  292. func (c *Client) GetIDToken(proj *cemodels.Project, user *cemodels.User) (token string, teamID string, err error) {
  293. // attempt to get a team ID for the project
  294. teamID, err = c.GetTeamID(proj)
  295. // attempt to read the user billing data from the project
  296. userBilling, err := c.repo.UserBilling().ReadUserBilling(proj.ID, user.ID)
  297. notFound := errors.Is(err, gorm.ErrRecordNotFound)
  298. if !notFound && err != nil {
  299. return "", "", err
  300. }
  301. if !notFound {
  302. token = string(userBilling.Token)
  303. if token != "" {
  304. // check if the JWT token has expired
  305. isTokExpired := isExpired(token)
  306. // if JWT token has not expired, return the token
  307. if !isTokExpired {
  308. return token, teamID, nil
  309. }
  310. }
  311. }
  312. req := &CreateIDTokenRequest{
  313. Email: user.Email,
  314. UserID: fmt.Sprintf("%d-%d", proj.ID, user.ID),
  315. }
  316. resp := &CreateIDTokenResponse{}
  317. err = c.postRequest("/customers/v1/token", req, resp)
  318. if err != nil {
  319. return "", "", err
  320. }
  321. token = resp.Token
  322. if notFound {
  323. _, err := c.repo.UserBilling().CreateUserBilling(&models.UserBilling{
  324. ProjectID: proj.ID,
  325. UserID: user.ID,
  326. Token: []byte(token),
  327. })
  328. if err != nil {
  329. return "", "", err
  330. }
  331. } else {
  332. _, err := c.repo.UserBilling().UpdateUserBilling(&models.UserBilling{
  333. Model: &gorm.Model{
  334. ID: userBilling.ID,
  335. },
  336. ProjectID: proj.ID,
  337. UserID: user.ID,
  338. Token: []byte(token),
  339. TeammateID: userBilling.TeammateID,
  340. })
  341. if err != nil {
  342. return "", "", err
  343. }
  344. }
  345. return token, teamID, nil
  346. }
  347. // VerifySignature verifies a webhook signature based on hmac protocol
  348. // https://docs.ironplans.com/webhook-events/webhook-events
  349. func (c *Client) VerifySignature(signature string, body []byte) bool {
  350. if len(signature) != 71 || !strings.HasPrefix(signature, "sha256=") {
  351. return false
  352. }
  353. actual := make([]byte, 32)
  354. _, err := hex.Decode(actual, []byte(signature[7:]))
  355. if err != nil {
  356. return false
  357. }
  358. computed := hmac.New(sha256.New, []byte(c.apiKey))
  359. _, err = computed.Write(body)
  360. if err != nil {
  361. return false
  362. }
  363. return hmac.Equal(computed.Sum(nil), actual)
  364. }
  365. func (c *Client) postRequest(path string, data interface{}, dst interface{}) error {
  366. return c.writeRequest("POST", path, data, dst)
  367. }
  368. func (c *Client) putRequest(path string, data interface{}, dst interface{}) error {
  369. return c.writeRequest("PUT", path, data, dst)
  370. }
  371. func (c *Client) deleteRequest(path string, data interface{}, dst interface{}) error {
  372. return c.writeRequest("DELETE", path, data, dst)
  373. }
  374. func (c *Client) getRequest(path string, dst interface{}, query ...map[string]string) error {
  375. reqURL, err := url.Parse(c.serverURL)
  376. if err != nil {
  377. return nil
  378. }
  379. reqURL.Path = path
  380. q := reqURL.Query()
  381. for _, queryGroup := range query {
  382. for key, val := range queryGroup {
  383. q.Add(key, val)
  384. }
  385. }
  386. reqURL.RawQuery = q.Encode()
  387. req, err := http.NewRequest(
  388. "GET",
  389. reqURL.String(),
  390. nil,
  391. )
  392. if err != nil {
  393. return err
  394. }
  395. req.Header.Set("Content-Type", "application/json; charset=utf-8")
  396. req.Header.Set("Accept", "application/json; charset=utf-8")
  397. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  398. res, err := c.httpClient.Do(req)
  399. if err != nil {
  400. return err
  401. }
  402. defer res.Body.Close()
  403. if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
  404. resBytes, err := ioutil.ReadAll(res.Body)
  405. if err != nil {
  406. return fmt.Errorf("request failed with status code %d, but could not read body (%s)\n", res.StatusCode, err.Error())
  407. }
  408. return fmt.Errorf("request failed with status code %d: %s\n", res.StatusCode, string(resBytes))
  409. }
  410. if dst != nil {
  411. return json.NewDecoder(res.Body).Decode(dst)
  412. }
  413. return nil
  414. }
  415. func (c *Client) writeRequest(method, path string, data interface{}, dst interface{}) error {
  416. reqURL, err := url.Parse(c.serverURL)
  417. if err != nil {
  418. return nil
  419. }
  420. reqURL.Path = path
  421. var strData []byte
  422. if data != nil {
  423. strData, err = json.Marshal(data)
  424. if err != nil {
  425. return err
  426. }
  427. }
  428. req, err := http.NewRequest(
  429. method,
  430. reqURL.String(),
  431. strings.NewReader(string(strData)),
  432. )
  433. if err != nil {
  434. return err
  435. }
  436. req.Header.Set("Content-Type", "application/json; charset=utf-8")
  437. req.Header.Set("Accept", "application/json; charset=utf-8")
  438. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  439. res, err := c.httpClient.Do(req)
  440. if err != nil {
  441. return err
  442. }
  443. defer res.Body.Close()
  444. if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
  445. resBytes, err := ioutil.ReadAll(res.Body)
  446. if err != nil {
  447. return fmt.Errorf("request failed with status code %d, but could not read body (%s)\n", res.StatusCode, err.Error())
  448. }
  449. return fmt.Errorf("request failed with status code %d: %s\n", res.StatusCode, string(resBytes))
  450. }
  451. if dst != nil {
  452. return json.NewDecoder(res.Body).Decode(dst)
  453. }
  454. return nil
  455. }
  456. const (
  457. FeatureSlugCPU string = "cpu"
  458. FeatureSlugMemory string = "memory"
  459. FeatureSlugClusters string = "clusters"
  460. FeatureSlugUsers string = "users"
  461. )
  462. func (c *Client) ParseProjectUsageFromWebhook(payload []byte) (*cemodels.ProjectUsage, error) {
  463. subscription := &SubscriptionWebhookRequest{}
  464. err := json.Unmarshal(payload, subscription)
  465. if err != nil {
  466. return nil, err
  467. }
  468. // if event type is not subscription, return wrong webhook event type error
  469. if subscription.EventType != "subscription" {
  470. return nil, nil
  471. }
  472. // get the project id linked to that team
  473. projBilling, err := c.repo.ProjectBilling().ReadProjectBillingByTeamID(subscription.TeamID)
  474. if err != nil {
  475. return nil, err
  476. }
  477. usage := &cemodels.ProjectUsage{
  478. ProjectID: projBilling.ProjectID,
  479. }
  480. for _, feature := range subscription.Plan.Features {
  481. // look for slug of "cpus" and "memory"
  482. maxLimit := uint(feature.FeatureSpec.MaxLimit)
  483. switch feature.Feature.Slug {
  484. case FeatureSlugCPU:
  485. usage.ResourceCPU = maxLimit
  486. case FeatureSlugMemory:
  487. usage.ResourceMemory = 1000 * maxLimit
  488. case FeatureSlugClusters:
  489. usage.Clusters = maxLimit
  490. case FeatureSlugUsers:
  491. usage.Users = maxLimit
  492. }
  493. }
  494. return usage, nil
  495. }
  496. type expiryJWT struct {
  497. ExpiresAt int64 `json:"exp"`
  498. }
  499. func isExpired(token string) bool {
  500. var encoded string
  501. if tokenSplit := strings.Split(token, "."); len(tokenSplit) != 3 {
  502. return true
  503. } else {
  504. encoded = tokenSplit[1]
  505. }
  506. decodedBytes, err := base64.RawStdEncoding.DecodeString(encoded)
  507. if err != nil {
  508. return true
  509. }
  510. expiryData := &expiryJWT{}
  511. err = json.Unmarshal(decodedBytes, expiryData)
  512. if err != nil {
  513. return true
  514. }
  515. expiryTime := time.Unix(expiryData.ExpiresAt, 0)
  516. return expiryTime.Before(time.Now())
  517. }