-
Notifications
You must be signed in to change notification settings - Fork 87
/
websocket_server.go
146 lines (126 loc) · 3.99 KB
/
websocket_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package turnpike
import (
"fmt"
"net/http"
"github.com/gorilla/websocket"
)
// Websocket subprotocol identifiers understood by this server.
// newWebsocketServer registers msgpackWebsocketProtocol for binary frames and
// jsonWebsocketProtocol for text frames.
const (
	// msgpackWebsocketProtocol is paired with MessagePackSerializer.
	msgpackWebsocketProtocol = "wamp.2.msgpack"
	// jsonWebsocketProtocol is paired with JSONSerializer.
	jsonWebsocketProtocol = "wamp.2.json"
)
// invalidPayload is returned by RegisterProtocol when the requested payload
// type is neither websocket.TextMessage nor websocket.BinaryMessage. Its value
// is the rejected payload type.
//
// The underlying type is int (matching RegisterProtocol's payloadType
// parameter) so that converting an out-of-range value does not silently
// truncate it and misreport the offending number in the error message.
type invalidPayload int

// Error reports the rejected payload type.
func (e invalidPayload) Error() string {
	// %d does not trigger fmt's Error()-method handling, so this cannot recurse.
	return fmt.Sprintf("Invalid payloadType: %d", int(e))
}
// protocolExists is returned by RegisterProtocol when the subprotocol name it
// is given is already present; the value is that duplicate name.
type protocolExists string

// Error names the subprotocol that was registered twice.
func (dup protocolExists) Error() string {
	name := string(dup)
	return "This protocol has already been registered: " + name
}
// protocol records, for one registered websocket subprotocol, the frame type
// and the serializer that handleWebsocket gives to each websocketPeer whose
// connection negotiated that subprotocol.
type protocol struct {
// payloadType is the websocket message type used for this subprotocol;
// RegisterProtocol only accepts websocket.TextMessage or
// websocket.BinaryMessage.
payloadType int
// serializer is the Serializer registered for this subprotocol.
serializer Serializer
}
// WebsocketServer handles websocket connections.
//
// It embeds a Router; its GetLocalPeer and Accept methods are used to attach
// local clients and accepted websocket peers. Instances are normally created
// by NewWebsocketServer or NewBasicWebsocketServer, which initialise Upgrader
// and the protocols map and register the built-in JSON and MessagePack
// subprotocols.
type WebsocketServer struct {
Router
// Upgrader performs the HTTP-to-websocket upgrade in ServeHTTP; its
// Subprotocols list is extended by RegisterProtocol.
Upgrader *websocket.Upgrader
// protocols maps a subprotocol name to its frame type and serializer.
// It is filled by RegisterProtocol and read by handleWebsocket.
protocols map[string]protocol
// The serializer to use for text frames.
// NOTE(review): previously documented as defaulting to JSONSerializer, but
// nothing in this file sets or reads this field (handleWebsocket picks the
// serializer from protocols) — confirm whether it is used elsewhere.
TextSerializer Serializer
// The serializer to use for binary frames.
// NOTE(review): previously documented as defaulting to JSONSerializer (odd
// for binary frames, where MessagePack is registered), but nothing in this
// file sets or reads this field — confirm whether it is used elsewhere.
BinarySerializer Serializer
}
// NewWebsocketServer creates a new WebsocketServer from a map of realms.
//
// A fresh default router is created and every entry of realms is registered on
// it, using the map key as the realm URI. If any registration fails, that
// error is returned together with a nil server.
func NewWebsocketServer(realms map[string]Realm) (*WebsocketServer, error) {
	log.Println("NewWebsocketServer")
	router := NewDefaultRouter()
	for name, realm := range realms {
		err := router.RegisterRealm(URI(name), realm)
		if err != nil {
			return nil, err
		}
	}
	return newWebsocketServer(router), nil
}
// NewBasicWebsocketServer creates a new WebsocketServer with a single basic
// realm registered at uri.
//
// The signature has no error return, so if registering the realm fails the
// error is logged rather than silently dropped, and nil is returned.
func NewBasicWebsocketServer(uri string) *WebsocketServer {
	log.Println("NewBasicWebsocketServer")
	s, err := NewWebsocketServer(map[string]Realm{uri: {}})
	if err != nil {
		log.Println("NewBasicWebsocketServer: error registering realm", uri, err)
		return nil
	}
	return s
}
// newWebsocketServer wraps r in a WebsocketServer with a default Upgrader and
// registers the built-in subprotocols: JSON over text frames first, then
// MessagePack over binary frames.
func newWebsocketServer(r Router) *WebsocketServer {
	srv := &WebsocketServer{
		Router:    r,
		Upgrader:  &websocket.Upgrader{},
		protocols: make(map[string]protocol),
	}
	// These registrations cannot fail: both payload types are valid and the
	// protocol map starts empty with two distinct names, so errors are ignored.
	_ = srv.RegisterProtocol(jsonWebsocketProtocol, websocket.TextMessage, new(JSONSerializer))
	_ = srv.RegisterProtocol(msgpackWebsocketProtocol, websocket.BinaryMessage, new(MessagePackSerializer))
	return srv
}
// RegisterProtocol registers a serializer that should be used for a given
// protocol string and payload type.
//
// proto is also appended to s.Upgrader.Subprotocols so the upgrade handshake
// can negotiate it; gorilla/websocket treats that list as the server's order
// of preference, so earlier registrations win. Errors:
//   - invalidPayload if payloadType is neither websocket.TextMessage nor
//     websocket.BinaryMessage;
//   - protocolExists if proto was already registered.
//
// The protocols map and Upgrader are created on first use, so this does not
// panic on a WebsocketServer that was not built by newWebsocketServer.
func (s *WebsocketServer) RegisterProtocol(proto string, payloadType int, serializer Serializer) error {
	log.Println("RegisterProtocol:", proto)
	if payloadType != websocket.TextMessage && payloadType != websocket.BinaryMessage {
		return invalidPayload(payloadType)
	}
	// Reading a nil map is safe, so the duplicate check needs no init first.
	if _, ok := s.protocols[proto]; ok {
		return protocolExists(proto)
	}
	if s.protocols == nil {
		s.protocols = make(map[string]protocol)
	}
	if s.Upgrader == nil {
		s.Upgrader = &websocket.Upgrader{}
	}
	s.protocols[proto] = protocol{payloadType: payloadType, serializer: serializer}
	s.Upgrader.Subprotocols = append(s.Upgrader.Subprotocols, proto)
	return nil
}
// GetLocalClient returns a client connected to the specified realm.
//
// The in-process peer is obtained from the embedded Router's GetLocalPeer with
// the given details; the returned Client already has its Receive method
// running in a new goroutine. Any GetLocalPeer error is returned unchanged.
func (s *WebsocketServer) GetLocalClient(realm string, details map[string]interface{}) (*Client, error) {
	localPeer, err := s.Router.GetLocalPeer(URI(realm), details)
	if err != nil {
		return nil, err
	}
	client := NewClient(localPeer)
	go client.Receive()
	return client, nil
}
// ServeHTTP handles a new HTTP connection by upgrading it to a websocket and
// passing the resulting connection to handleWebsocket.
//
// Subprotocol negotiation is performed by s.Upgrader: because its Subprotocols
// list is populated by RegisterProtocol, passing a nil response header lets
// the Upgrader select the first of those protocols that the client offered.
func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log.Println("WebsocketServer.ServeHTTP", r.Method, r.RequestURI)
	conn, err := s.Upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error (or
		// hijacked the connection), so writing another response here would be
		// a superfluous WriteHeader or a write on a hijacked connection.
		log.Println("Error upgrading to websocket connection:", err)
		return
	}
	s.handleWebsocket(conn)
}
// handleWebsocket picks the serializer and websocket frame type for conn from
// the subprotocol negotiated during the upgrade, starts the peer's run method
// in a new goroutine, and hands the peer to the router's Accept (whose result
// is passed to logErr).
//
// gorilla/websocket does not reject a client whose offered subprotocols match
// none of s.Upgrader.Subprotocols: the upgrade succeeds and conn.Subprotocol()
// returns "". Such connections, and any other subprotocol this server does not
// know, are logged and closed here without reaching the router.
func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) {
	subprotocol := conn.Subprotocol()

	var serializer Serializer
	var payloadType int
	if proto, ok := s.protocols[subprotocol]; ok {
		serializer = proto.serializer
		payloadType = proto.payloadType
	} else {
		// Fallback to the built-in protocols; only reachable when the
		// protocols map has no entry for them (e.g. it was never populated
		// via newWebsocketServer/RegisterProtocol).
		switch subprotocol {
		case jsonWebsocketProtocol:
			serializer = new(JSONSerializer)
			payloadType = websocket.TextMessage
		case msgpackWebsocketProtocol:
			serializer = new(MessagePackSerializer)
			payloadType = websocket.BinaryMessage
		default:
			log.Println("Closing websocket connection with unsupported subprotocol:", fmt.Sprintf("%q", subprotocol))
			conn.Close()
			return
		}
	}

	peer := websocketPeer{
		conn:        conn,
		serializer:  serializer,
		messages:    make(chan Message, 10),
		payloadType: payloadType,
	}
	go peer.run()
	logErr(s.Router.Accept(&peer))
}