WebSocket Handler

Exercise: WebSocket Handler

Difficulty - Intermediate

Learning Objectives

  • Master WebSocket protocol for real-time communication
  • Build WebSocket servers with gorilla/websocket
  • Handle concurrent connections safely
  • Implement broadcast messaging
  • Practice graceful connection management

Problem Statement

Create a real-time chat server using WebSockets.

Implementation

  1package main
  2
  3import (
  4	"log"
  5	"net/http"
  6	"sync"
  7
  8	"github.com/gorilla/websocket"
  9)
 10
 11var upgrader = websocket.Upgrader{
 12	ReadBufferSize:  1024,
 13	WriteBufferSize: 1024,
 14	CheckOrigin: func(r *http.Request) bool {
 15		return true // Allow all origins in development
 16	},
 17}
 18
 19type Message struct {
 20	Username string `json:"username"`
 21	Content  string `json:"content"`
 22}
 23
 24type Hub struct {
 25	clients    map[*Client]bool
 26	broadcast  chan Message
 27	register   chan *Client
 28	unregister chan *Client
 29	mu         sync.RWMutex
 30}
 31
 32func NewHub() *Hub {
 33	return &Hub{
 34		clients:    make(map[*Client]bool),
 35		broadcast:  make(chan Message),
 36		register:   make(chan *Client),
 37		unregister: make(chan *Client),
 38	}
 39}
 40
 41func Run() {
 42	for {
 43		select {
 44		case client := <-h.register:
 45			h.mu.Lock()
 46			h.clients[client] = true
 47			h.mu.Unlock()
 48
 49		case client := <-h.unregister:
 50			h.mu.Lock()
 51			if _, ok := h.clients[client]; ok {
 52				delete(h.clients, client)
 53				close(client.send)
 54			}
 55			h.mu.Unlock()
 56
 57		case message := <-h.broadcast:
 58			h.mu.RLock()
 59			for client := range h.clients {
 60				select {
 61				case client.send <- message:
 62				default:
 63					close(client.send)
 64					delete(h.clients, client)
 65				}
 66			}
 67			h.mu.RUnlock()
 68		}
 69	}
 70}
 71
 72type Client struct {
 73	hub      *Hub
 74	conn     *websocket.Conn
 75	send     chan Message
 76	username string
 77}
 78
 79func readPump() {
 80	defer func() {
 81		c.hub.unregister <- c
 82		c.conn.Close()
 83	}()
 84
 85	for {
 86		var msg Message
 87		err := c.conn.ReadJSON(&msg)
 88		if err != nil {
 89			break
 90		}
 91		msg.Username = c.username
 92		c.hub.broadcast <- msg
 93	}
 94}
 95
 96func writePump() {
 97	defer c.conn.Close()
 98
 99	for message := range c.send {
100		err := c.conn.WriteJSON(message)
101		if err != nil {
102			return
103		}
104	}
105}
106
107func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
108	conn, err := upgrader.Upgrade(w, r, nil)
109	if err != nil {
110		log.Println(err)
111		return
112	}
113
114	username := r.URL.Query().Get("username")
115	if username == "" {
116		username = "Anonymous"
117	}
118
119	client := &Client{
120		hub:      hub,
121		conn:     conn,
122		send:     make(chan Message, 256),
123		username: username,
124	}
125
126	client.hub.register <- client
127
128	go client.writePump()
129	go client.readPump()
130}
131
132func main() {
133	hub := NewHub()
134	go hub.Run()
135
136	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
137		serveWs(hub, w, r)
138	})
139
140	log.Println("Server starting on :8080")
141	log.Fatal(http.ListenAndServe(":8080", nil))
142}

Solution

Click to see the complete solution with explanations

Complete WebSocket Chat Server

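The listing below extends the implementation above with keepalive pings, read limits, write deadlines, and join/leave notifications, still built around the Hub pattern. It is a teaching example rather than a production-ready server; for instance, it accepts WebSocket connections from any origin. Here's a detailed breakdown: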

  1package main
  2
import (
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
 10
 11var upgrader = websocket.Upgrader{
 12	ReadBufferSize:  1024,
 13	WriteBufferSize: 1024,
 14	CheckOrigin: func(r *http.Request) bool {
 15		// In production, check origin properly
 16		// return r.Header.Get("Origin") == "https://yourdomain.com"
 17		return true
 18	},
 19}
 20
 21type Message struct {
 22	Username string `json:"username"`
 23	Content  string `json:"content"`
 24	Type     string `json:"type"` // "message", "join", "leave"
 25}
 26
 27type Hub struct {
 28	clients    map[*Client]bool
 29	broadcast  chan Message
 30	register   chan *Client
 31	unregister chan *Client
 32	mu         sync.RWMutex
 33}
 34
 35func NewHub() *Hub {
 36	return &Hub{
 37		clients:    make(map[*Client]bool),
 38		broadcast:  make(chan Message, 256),
 39		register:   make(chan *Client, 10),
 40		unregister: make(chan *Client, 10),
 41	}
 42}
 43
 44func Run() {
 45	for {
 46		select {
 47		case client := <-h.register:
 48			h.mu.Lock()
 49			h.clients[client] = true
 50			h.mu.Unlock()
 51			log.Printf("Client registered: %s", client.username, len(h.clients))
 52
 53			// Broadcast join message
 54			h.broadcast <- Message{
 55				Username: client.username,
 56				Content:  "joined the chat",
 57				Type:     "join",
 58			}
 59
 60		case client := <-h.unregister:
 61			h.mu.Lock()
 62			if _, ok := h.clients[client]; ok {
 63				delete(h.clients, client)
 64				close(client.send)
 65				log.Printf("Client unregistered: %s", client.username, len(h.clients))
 66			}
 67			h.mu.Unlock()
 68
 69			// Broadcast leave message
 70			h.broadcast <- Message{
 71				Username: client.username,
 72				Content:  "left the chat",
 73				Type:     "leave",
 74			}
 75
 76		case message := <-h.broadcast:
 77			h.mu.RLock()
 78			for client := range h.clients {
 79				select {
 80				case client.send <- message:
 81					// Message sent successfully
 82				default:
 83					// Client send buffer full, disconnect
 84					close(client.send)
 85					delete(h.clients, client)
 86					log.Printf("Disconnected slow client: %s", client.username)
 87				}
 88			}
 89			h.mu.RUnlock()
 90		}
 91	}
 92}
 93
 94type Client struct {
 95	hub      *Hub
 96	conn     *websocket.Conn
 97	send     chan Message
 98	username string
 99}
100
// Connection tuning.
const (
	// maxMessageSize is the largest incoming message, in bytes, readPump accepts.
	maxMessageSize = 512
	// writeWait is the deadline for each individual write to the peer.
	writeWait = 10 * time.Second
	// pongWait is how long readPump waits for the next pong (or message)
	// before treating the peer as gone.
	pongWait = 60 * time.Second
	// pingPeriod must be shorter than pongWait so a ping (and its pong)
	// arrives before the read deadline expires: 90% of 60s = 54s.
	pingPeriod = (pongWait * 9) / 10
)
107
108func readPump() {
109	defer func() {
110		c.hub.unregister <- c
111		c.conn.Close()
112	}()
113
114	c.conn.SetReadLimit(maxMessageSize)
115	c.conn.SetReadDeadline(time.Now().Add(pongWait))
116	c.conn.SetPongHandler(func(string) error {
117		c.conn.SetReadDeadline(time.Now().Add(pongWait))
118		return nil
119	})
120
121	for {
122		var msg Message
123		err := c.conn.ReadJSON(&msg)
124		if err != nil {
125			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
126				log.Printf("WebSocket error: %v", err)
127			}
128			break
129		}
130
131		// Set username and type
132		msg.Username = c.username
133		msg.Type = "message"
134
135		// Broadcast to all clients
136		c.hub.broadcast <- msg
137	}
138}
139
140func writePump() {
141	ticker := time.NewTicker(pingPeriod)
142	defer func() {
143		ticker.Stop()
144		c.conn.Close()
145	}()
146
147	for {
148		select {
149		case message, ok := <-c.send:
150			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
151			if !ok {
152				// Hub closed the channel
153				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
154				return
155			}
156
157			if err := c.conn.WriteJSON(message); err != nil {
158				log.Printf("Write error: %v", err)
159				return
160			}
161
162		case <-ticker.C:
163			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
164			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
165				return
166			}
167		}
168	}
169}
170
171func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
172	conn, err := upgrader.Upgrade(w, r, nil)
173	if err != nil {
174		log.Printf("Upgrade error: %v", err)
175		return
176	}
177
178	username := r.URL.Query().Get("username")
179	if username == "" {
180		username = "Anonymous"
181	}
182
183	client := &Client{
184		hub:      hub,
185		conn:     conn,
186		send:     make(chan Message, 256),
187		username: username,
188	}
189
190	client.hub.register <- client
191
192	// Start goroutines for this client
193	go client.writePump()
194	go client.readPump()
195}
196
197func main() {
198	hub := NewHub()
199	go hub.Run()
200
201	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
202		serveWs(hub, w, r)
203	})
204
205	// Serve static files for chat UI
206	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
207		http.ServeFile(w, r, "index.html")
208	})
209
210	log.Println("WebSocket server starting on :8080")
211	log.Fatal(http.ListenAndServe(":8080", nil))
212}

Explanation

Hub Pattern:

  • Central hub manages all connected clients
  • Three channels: register, unregister, broadcast
  • Single goroutine processes all hub operations
  • A mutex guards the client map; because a broadcast may remove slow clients, it must take the write lock (Lock), not the read lock (RLock)

Client Structure:

  • Each client has two goroutines: readPump and writePump
  • send channel buffers messages to prevent blocking
  • conn is the WebSocket connection

readPump:

  • Continuously reads messages from WebSocket
  • Sets read deadline and handles pong messages
  • Sends messages to hub's broadcast channel
  • Cleans up on disconnect

writePump:

  • Reads from client's send channel
  • Sends messages to WebSocket
  • Periodically sends ping messages for keepalive
  • Handles write timeouts

Connection Upgrade:

  • upgrader.Upgrade() converts HTTP to WebSocket
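  • CheckOrigin decides whether to accept the request's Origin header; the example accepts every origin, which is only suitable for local development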
  • Returns WebSocket connection

Best Practices

1. Ping/Pong for Keepalive:

 1// Send ping every 54 seconds, expect pong within 60 seconds
 2const (
 3	pongWait   = 60 * time.Second
 4	pingPeriod = / 10
 5)
 6
 7c.conn.SetPongHandler(func(string) error {
 8	c.conn.SetReadDeadline(time.Now().Add(pongWait))
 9	return nil
10})

2. Buffered Send Channel:

 1// Buffer prevents slow clients from blocking broadcasts
 2send: make(chan Message, 256)
 3
 4// Disconnect slow clients instead of blocking
 5select {
 6case client.send <- message:
 7default:
 8	close(client.send)
 9	delete(h.clients, client)
10}

3. Separate Read/Write Goroutines:

1// One goroutine for reading, one for writing
2// Prevents deadlocks and allows concurrent operations
3go client.writePump()
4go client.readPump()

4. Proper Cleanup:

1defer func() {
2	c.hub.unregister <- c
3	c.conn.Close()
4}()

5. Message Size Limits:

1const maxMessageSize = 512
2c.conn.SetReadLimit(maxMessageSize)

Security Considerations

Origin Validation:

1CheckOrigin: func(r *http.Request) bool {
2	origin := r.Header.Get("Origin")
3	return origin == "https://yourdomain.com"
4}

Rate Limiting:

1// Limit messages per client
2type Client struct {
3	// ...
4	limiter *rate.Limiter
5}
6
7if !c.limiter.Allow() {
8	continue // Drop message
9}

Input Validation:

1// Sanitize message content
2msg.Content = html.EscapeString(msg.Content)
3
4// Limit message length
5if len(msg.Content) > maxMessageSize {
6	msg.Content = msg.Content[:maxMessageSize]
7}

Client-Side HTML/JavaScript

 1<!DOCTYPE html>
 2<html>
 3<head>
 4	<title>WebSocket Chat</title>
 5</head>
 6<body>
 7	<div id="messages"></div>
 8	<input id="messageInput" type="text" placeholder="Type a message...">
 9	<button onclick="sendMessage()">Send</button>
10
11	<script>
12		const username = prompt("Enter your username:") || "Anonymous";
13		const ws = new WebSocket(`ws://localhost:8080/ws?username=${username}`);
14
15		ws.onmessage = function(event) {
16			const msg = JSON.parse(event.data);
17			const div = document.getElementById('messages');
18			const msgEl = document.createElement('div');
19			msgEl.textContent = `${msg.username}: ${msg.content}`;
20			div.appendChild(msgEl);
21		};
22
23		function sendMessage() {
24			const input = document.getElementById('messageInput');
25			const message = {
26				content: input.value
27			};
28			ws.send(JSON.stringify(message));
29			input.value = '';
30		}
31
32		ws.onclose = function() {
33			console.log('WebSocket connection closed');
34		};
35
36		ws.onerror = function(error) {
37			console.error('WebSocket error:', error);
38		};
39	</script>
40</body>
41</html>

Performance Considerations

Broadcast Efficiency:

  • Use buffered channels to prevent blocking
  • Disconnect slow clients instead of waiting
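  • All hub operations run on the single Run goroutine, so the client map is never contended; dropping slow clients during a broadcast mutates the map and therefore needs the write lock, not RLock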

Memory Management:

  • Limit message buffer size per client
  • Set max message size to prevent memory exhaustion
  • Clean up disconnected clients promptly

Scalability:

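  • A single hub is often quoted as handling around 10K concurrent connections, but the real limit depends on message rate and hardware, so load-test before relying on it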
  • For more, use Redis pub/sub for multi-server setup
  • Consider connection pooling for database operations

Key Takeaways

  1. Upgrade Connection: Use upgrader to convert HTTP to WebSocket
  2. Goroutine per Client: Handle each client in separate goroutines
  3. Hub Pattern: Central hub manages all clients and broadcasts
  4. Channel Communication: Use channels for thread-safe messaging
  5. Clean Disconnects: Always close connections and channels properly (each client's send channel exactly once, by the hub)
  6. Keepalive: Use ping/pong for connection health monitoring
  7. Buffer Management: Use buffered channels to handle slow clients
  8. Security: Validate origins and sanitize input