Exercise: Network Protocol
Difficulty: Advanced
Learning Objectives
- Implement custom binary protocol
- Handle framing and serialization
- Implement connection pooling
- Support request-response pattern
- Handle backpressure and flow control
- Implement protocol versioning
Problem Statement
Create a custom network protocol with binary serialization, connection management, and support for streaming data.
Core Components
1package netproto
2
3import (
4 "context"
5 "net"
6)
7
// MessageType identifies the kind of a Message.
type MessageType uint8

// Message kinds, numbered explicitly so the value of each kind is visible at a
// glance.
const (
	Request  MessageType = 0
	Response MessageType = 1
	Stream   MessageType = 2
	Ping     MessageType = 3
	Pong     MessageType = 4
)
17
18type Message struct {
19 Type MessageType
20 ID uint32
21 Payload []byte
22}
23
// Server is the listening side of the protocol.
type Server struct {
	addr     string       // address to listen on
	listener net.Listener // the open listener, if any
}
28
// Client is the dialing side of the protocol and wraps one connection.
type Client struct {
	conn net.Conn // underlying network connection
}
32
33func NewServer(addr string) *Server
34func Start(ctx context.Context) error
35func HandleFunc(fn func(*Message) *Message)
36func NewClient(addr string)
37func Send(msg *Message)
Solution
Click to see the solution
1package netproto
2
3import (
4 "bufio"
5 "context"
6 "encoding/binary"
7 "errors"
8 "fmt"
9 "io"
10 "net"
11 "sync"
12 "sync/atomic"
13 "time"
14)
15
// Protocol-wide limits.
const (
	// ProtocolVersion is written by Encode when a Message leaves Version unset,
	// and is the highest version DecodeMessage accepts.
	ProtocolVersion = 1
	// MaxPayloadSize caps one frame's payload at 1 MiB so a peer cannot make
	// the decoder allocate an arbitrarily large buffer.
	MaxPayloadSize = 1 << 20
)

// headerSize is the length in bytes of the fixed frame header (see Message).
const headerSize = 10

var (
	// ErrPayloadTooLarge is returned by Encode and DecodeMessage when a payload
	// exceeds MaxPayloadSize.
	ErrPayloadTooLarge = errors.New("payload too large")
	// ErrInvalidMessage is wrapped by DecodeMessage when a frame carries an
	// unsupported version or an unknown message type.
	ErrInvalidMessage = errors.New("invalid message")
)

// MessageType identifies the kind of a frame. It is sent as a single byte, so
// new kinds must only ever be appended to the list below.
type MessageType uint8

const (
	Request MessageType = iota
	Response
	Stream
	Ping
	Pong
	Error
)

// Message is one protocol frame. Wire format (integers are big-endian):
//
//	[Version:1][Type:1][ID:4][PayloadSize:4][Payload:N]
type Message struct {
	Version uint8       // wire-format version; 0 means ProtocolVersion when encoding
	Type    MessageType // kind of frame
	ID      uint32      // correlates a response with its request
	Payload []byte      // opaque body, at most MaxPayloadSize bytes
}

// Encode writes m to w as one frame, handing the header and payload to w in a
// single Write call. A zero Version is encoded as ProtocolVersion without
// modifying m, so a shared Message can be encoded from several goroutines.
//
// It returns ErrPayloadTooLarge, before writing anything, if the payload
// exceeds MaxPayloadSize; otherwise it returns any error from w.
func (m *Message) Encode(w io.Writer) error {
	if len(m.Payload) > MaxPayloadSize {
		return ErrPayloadTooLarge
	}

	version := m.Version
	if version == 0 {
		version = ProtocolVersion
	}

	frame := make([]byte, headerSize+len(m.Payload))
	frame[0] = version
	frame[1] = uint8(m.Type)
	binary.BigEndian.PutUint32(frame[2:6], m.ID)
	binary.BigEndian.PutUint32(frame[6:10], uint32(len(m.Payload)))
	copy(frame[headerSize:], m.Payload)

	_, err := w.Write(frame)
	return err
}

// DecodeMessage reads exactly one frame from r.
//
// It returns io.EOF only when r ends cleanly at a frame boundary. A frame cut
// short yields io.ErrUnexpectedEOF. A declared payload larger than
// MaxPayloadSize is rejected with ErrPayloadTooLarge before anything is
// allocated. A frame with an unsupported version or unknown type is rejected
// with an error wrapping ErrInvalidMessage. That check happens after the
// payload is consumed, so r stays aligned on the next frame.
func DecodeMessage(r io.Reader) (*Message, error) {
	var header [headerSize]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, err
	}

	size := binary.BigEndian.Uint32(header[6:10])
	if size > MaxPayloadSize {
		return nil, ErrPayloadTooLarge
	}

	msg := &Message{
		Version: header[0],
		Type:    MessageType(header[1]),
		ID:      binary.BigEndian.Uint32(header[2:6]),
	}

	if size > 0 {
		msg.Payload = make([]byte, size)
		if _, err := io.ReadFull(r, msg.Payload); err != nil {
			if err == io.EOF {
				// The header promised a payload, so running out here is
				// truncation, not a clean end of stream.
				err = io.ErrUnexpectedEOF
			}
			return nil, err
		}
	}

	if msg.Version == 0 || msg.Version > ProtocolVersion {
		return nil, fmt.Errorf("%w: unsupported protocol version %d", ErrInvalidMessage, msg.Version)
	}
	if msg.Type > Error {
		return nil, fmt.Errorf("%w: unknown message type %d", ErrInvalidMessage, msg.Type)
	}
	return msg, nil
}
100
101type Handler func(*Message) *Message
102
103type Server struct {
104 addr string
105 listener net.Listener
106 handler Handler
107 mu sync.RWMutex
108 conns map[net.Conn]struct{}
109}
110
111func NewServer(addr string) *Server {
112 return &Server{
113 addr: addr,
114 conns: make(map[net.Conn]struct{}),
115 }
116}
117
118func HandleFunc(fn Handler) {
119 s.mu.Lock()
120 s.handler = fn
121 s.mu.Unlock()
122}
123
124func Start(ctx context.Context) error {
125 listener, err := net.Listen("tcp", s.addr)
126 if err != nil {
127 return err
128 }
129
130 s.listener = listener
131
132 go s.acceptLoop(ctx)
133 return nil
134}
135
136func acceptLoop(ctx context.Context) {
137 for {
138 select {
139 case <-ctx.Done():
140 return
141 default:
142 }
143
144 conn, err := s.listener.Accept()
145 if err != nil {
146 continue
147 }
148
149 s.mu.Lock()
150 s.conns[conn] = struct{}{}
151 s.mu.Unlock()
152
153 go s.handleConnection(ctx, conn)
154 }
155}
156
157func handleConnection(ctx context.Context, conn net.Conn) {
158 defer func() {
159 conn.Close()
160 s.mu.Lock()
161 delete(s.conns, conn)
162 s.mu.Unlock()
163 }()
164
165 reader := bufio.NewReader(conn)
166 writer := bufio.NewWriter(conn)
167
168 for {
169 select {
170 case <-ctx.Done():
171 return
172 default:
173 }
174
175 conn.SetReadDeadline(time.Now().Add(30 * time.Second))
176 msg, err := DecodeMessage(reader)
177 if err != nil {
178 if err == io.EOF {
179 return
180 }
181 continue
182 }
183
184 // Handle ping
185 if msg.Type == Ping {
186 pong := &Message{
187 Type: Pong,
188 ID: msg.ID,
189 }
190 pong.Encode(writer)
191 writer.Flush()
192 continue
193 }
194
195 // Handle request
196 s.mu.RLock()
197 handler := s.handler
198 s.mu.RUnlock()
199
200 if handler != nil {
201 response := handler(msg)
202 if response != nil {
203 response.ID = msg.ID
204 response.Encode(writer)
205 writer.Flush()
206 }
207 }
208 }
209}
210
211func Stop() error {
212 if s.listener != nil {
213 s.listener.Close()
214 }
215
216 s.mu.Lock()
217 for conn := range s.conns {
218 conn.Close()
219 }
220 s.conns = make(map[net.Conn]struct{})
221 s.mu.Unlock()
222
223 return nil
224}
225
226type Client struct {
227 conn net.Conn
228 reader *bufio.Reader
229 writer *bufio.Writer
230 mu sync.Mutex
231 nextID uint32
232 pending map[uint32]chan *Message
233 pendingMu sync.RWMutex
234}
235
236func NewClient(addr string) {
237 conn, err := net.Dial("tcp", addr)
238 if err != nil {
239 return nil, err
240 }
241
242 client := &Client{
243 conn: conn,
244 reader: bufio.NewReader(conn),
245 writer: bufio.NewWriter(conn),
246 pending: make(map[uint32]chan *Message),
247 }
248
249 go client.readLoop()
250 return client, nil
251}
252
253func Send(msg *Message) {
254 msg.ID = atomic.AddUint32(&c.nextID, 1)
255
256 responseCh := make(chan *Message, 1)
257 c.pendingMu.Lock()
258 c.pending[msg.ID] = responseCh
259 c.pendingMu.Unlock()
260
261 defer func() {
262 c.pendingMu.Lock()
263 delete(c.pending, msg.ID)
264 c.pendingMu.Unlock()
265 }()
266
267 c.mu.Lock()
268 err := msg.Encode(c.writer)
269 if err == nil {
270 err = c.writer.Flush()
271 }
272 c.mu.Unlock()
273
274 if err != nil {
275 return nil, err
276 }
277
278 select {
279 case response := <-responseCh:
280 return response, nil
281 case <-time.After(5 * time.Second):
282 return nil, errors.New("timeout")
283 }
284}
285
286func readLoop() {
287 for {
288 msg, err := DecodeMessage(c.reader)
289 if err != nil {
290 return
291 }
292
293 c.pendingMu.RLock()
294 ch, exists := c.pending[msg.ID]
295 c.pendingMu.RUnlock()
296
297 if exists {
298 ch <- msg
299 }
300 }
301}
302
303func Ping() error {
304 msg := &Message{
305 Type: Ping,
306 }
307
308 response, err := c.Send(msg)
309 if err != nil {
310 return err
311 }
312
313 if response.Type != Pong {
314 return errors.New("invalid ping response")
315 }
316
317 return nil
318}
319
320func Close() error {
321 return c.conn.Close()
322}
323
324// Connection pool for clients
325type Pool struct {
326 addr string
327 clients []*Client
328 mu sync.Mutex
329 maxSize int
330}
331
332func NewPool(addr string, maxSize int) *Pool {
333 return &Pool{
334 addr: addr,
335 clients: make([]*Client, 0, maxSize),
336 maxSize: maxSize,
337 }
338}
339
340func Get() {
341 p.mu.Lock()
342 if len(p.clients) > 0 {
343 client := p.clients[len(p.clients)-1]
344 p.clients = p.clients[:len(p.clients)-1]
345 p.mu.Unlock()
346 return client, nil
347 }
348 p.mu.Unlock()
349
350 return NewClient(p.addr)
351}
352
353func Put(client *Client) {
354 p.mu.Lock()
355 defer p.mu.Unlock()
356
357 if len(p.clients) < p.maxSize {
358 p.clients = append(p.clients, client)
359 } else {
360 client.Close()
361 }
362}
363
364func Close() {
365 p.mu.Lock()
366 defer p.mu.Unlock()
367
368 for _, client := range p.clients {
369 client.Close()
370 }
371 p.clients = nil
372}
Usage Example
Server:
1package main
2
3import (
4 "context"
5 "fmt"
6 "log"
7)
8
9func main() {
10 server := netproto.NewServer(":8080")
11
12 // Register handler
13 server.HandleFunc(func(msg *netproto.Message) *netproto.Message {
14 fmt.Printf("Received: %s\n", string(msg.Payload))
15
16 return &netproto.Message{
17 Type: netproto.Response,
18 Payload: []byte("Response: " + string(msg.Payload)),
19 }
20 })
21
22 ctx := context.Background()
23 if err := server.Start(ctx); err != nil {
24 log.Fatal(err)
25 }
26
27 fmt.Println("Server started on :8080")
28 select {} // Block forever
29}
Client:
1package main
2
3import (
4 "fmt"
5 "log"
6)
7
8func main() {
9 client, err := netproto.NewClient("localhost:8080")
10 if err != nil {
11 log.Fatal(err)
12 }
13 defer client.Close()
14
15 // Send request
16 request := &netproto.Message{
17 Type: netproto.Request,
18 Payload: []byte("Hello, Server!"),
19 }
20
21 response, err := client.Send(request)
22 if err != nil {
23 log.Fatal(err)
24 }
25
26 fmt.Printf("Response: %s\n", string(response.Payload))
27
28 // Ping
29 if err := client.Ping(); err != nil {
30 log.Fatal(err)
31 }
32 fmt.Println("Ping successful")
33}
Connection pool:
1package main
2
3import (
4 "fmt"
5 "log"
6)
7
8func main() {
9 pool := netproto.NewPool("localhost:8080", 10)
10 defer pool.Close()
11
12 // Use connection from pool
13 client, err := pool.Get()
14 if err != nil {
15 log.Fatal(err)
16 }
17
18 response, err := client.Send(&netproto.Message{
19 Type: netproto.Request,
20 Payload: []byte("Hello from pool!"),
21 })
22 if err != nil {
23 log.Fatal(err)
24 }
25
26 fmt.Printf("Response: %s\n", string(response.Payload))
27
28 // Return to pool
29 pool.Put(client)
30}
Key Takeaways
- Binary protocols are typically more compact and cheaper to parse than text protocols
- Fixed-size headers simplify parsing
- Length-prefixed framing lets the receiver find message boundaries in a TCP byte stream
- Request-response correlation uses message IDs
- Connection pooling reduces overhead
- Timeouts prevent indefinite blocking
- Ping/pong enables health checking
- A version byte in every frame lets peers detect (and reject or adapt to) incompatible wire formats, which is the basis for backward compatibility