2
0

upgrader.go 804 B

12345678910111213141516171819202122232425262728293031323334353637
  1. package websocket
import (
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/gorilla/websocket"
)
  7. type Upgrader struct {
  8. WSUpgrader *websocket.Upgrader
  9. }
  10. var UpgraderCheckOriginErr = fmt.Errorf("request origin not allowed by Upgrader.CheckOrigin")
  11. func (u *Upgrader) Upgrade(
  12. w http.ResponseWriter,
  13. r *http.Request,
  14. responseHeader http.Header,
  15. ) (*websocket.Conn, http.ResponseWriter, *WebsocketSafeReadWriter, error) {
  16. // we manually call CheckOrigin and pass a specific error to the client in this case
  17. check := u.WSUpgrader.CheckOrigin(r)
  18. if !check {
  19. return nil, nil, nil, UpgraderCheckOriginErr
  20. }
  21. conn, err := u.WSUpgrader.Upgrade(w, r, responseHeader)
  22. safeWriter := &WebsocketSafeReadWriter{
  23. conn: conn,
  24. }
  25. rw := &WebsocketResponseWriter{conn, safeWriter}
  26. return conn, rw, safeWriter, err
  27. }