Network Protocol

Exercise: Network Protocol

Difficulty: Advanced

Learning Objectives

  • Implement a custom binary protocol
  • Handle message framing and serialization
  • Implement connection pooling
  • Support the request-response pattern
  • Handle backpressure and flow control
  • Implement protocol versioning

Problem Statement

Create a custom network protocol with binary serialization, connection management, and support for streaming data.

Core Components

 1package netproto
 2
 3import (
 4    "context"
 5    "net"
 6)
 7
 8type MessageType uint8
 9
10const (
11    Request MessageType = iota
12    Response
13    Stream
14    Ping
15    Pong
16)
17
18type Message struct {
19    Type    MessageType
20    ID      uint32
21    Payload []byte
22}
23
24type Server struct {
25    addr     string
26    listener net.Listener
27}
28
29type Client struct {
30    conn net.Conn
31}
32
33func NewServer(addr string) *Server
34func Start(ctx context.Context) error
35func HandleFunc(fn func(*Message) *Message)
36func NewClient(addr string)
37func Send(msg *Message)

Solution

Click to see the solution
  1package netproto
  2
  3import (
  4    "bufio"
  5    "context"
  6    "encoding/binary"
  7    "errors"
  8    "fmt"
  9    "io"
 10    "net"
 11    "sync"
 12    "sync/atomic"
 13    "time"
 14)
 15
 16const (
 17    ProtocolVersion = 1
 18    MaxPayloadSize  = 1 << 20 // 1MB
 19)
 20
 21var (
 22    ErrPayloadTooLarge = errors.New("payload too large")
 23    ErrInvalidMessage  = errors.New("invalid message")
 24)
 25
 26type MessageType uint8
 27
 28const (
 29    Request MessageType = iota
 30    Response
 31    Stream
 32    Ping
 33    Pong
 34    Error
 35)
 36
 37// Wire format:
 38// [Version:1][Type:1][ID:4][PayloadSize:4][Payload:N]
 39type Message struct {
 40    Version uint8
 41    Type    MessageType
 42    ID      uint32
 43    Payload []byte
 44}
 45
 46func Encode(w io.Writer) error {
 47    if len(m.Payload) > MaxPayloadSize {
 48        return ErrPayloadTooLarge
 49    }
 50
 51    if m.Version == 0 {
 52        m.Version = ProtocolVersion
 53    }
 54
 55    buf := make([]byte, 10)
 56    buf[0] = m.Version
 57    buf[1] = uint8(m.Type)
 58    binary.BigEndian.PutUint32(buf[2:6], m.ID)
 59    binary.BigEndian.PutUint32(buf[6:10], uint32(len(m.Payload)))
 60
 61    if _, err := w.Write(buf); err != nil {
 62        return err
 63    }
 64
 65    if len(m.Payload) > 0 {
 66        if _, err := w.Write(m.Payload); err != nil {
 67            return err
 68        }
 69    }
 70
 71    return nil
 72}
 73
 74func DecodeMessage(r io.Reader) {
 75    header := make([]byte, 10)
 76    if _, err := io.ReadFull(r, header); err != nil {
 77        return nil, err
 78    }
 79
 80    msg := &Message{
 81        Version: header[0],
 82        Type:    MessageType(header[1]),
 83        ID:      binary.BigEndian.Uint32(header[2:6]),
 84    }
 85
 86    payloadSize := binary.BigEndian.Uint32(header[6:10])
 87    if payloadSize > MaxPayloadSize {
 88        return nil, ErrPayloadTooLarge
 89    }
 90
 91    if payloadSize > 0 {
 92        msg.Payload = make([]byte, payloadSize)
 93        if _, err := io.ReadFull(r, msg.Payload); err != nil {
 94            return nil, err
 95        }
 96    }
 97
 98    return msg, nil
 99}
100
101type Handler func(*Message) *Message
102
103type Server struct {
104    addr     string
105    listener net.Listener
106    handler  Handler
107    mu       sync.RWMutex
108    conns    map[net.Conn]struct{}
109}
110
111func NewServer(addr string) *Server {
112    return &Server{
113        addr:  addr,
114        conns: make(map[net.Conn]struct{}),
115    }
116}
117
118func HandleFunc(fn Handler) {
119    s.mu.Lock()
120    s.handler = fn
121    s.mu.Unlock()
122}
123
124func Start(ctx context.Context) error {
125    listener, err := net.Listen("tcp", s.addr)
126    if err != nil {
127        return err
128    }
129
130    s.listener = listener
131
132    go s.acceptLoop(ctx)
133    return nil
134}
135
136func acceptLoop(ctx context.Context) {
137    for {
138        select {
139        case <-ctx.Done():
140            return
141        default:
142        }
143
144        conn, err := s.listener.Accept()
145        if err != nil {
146            continue
147        }
148
149        s.mu.Lock()
150        s.conns[conn] = struct{}{}
151        s.mu.Unlock()
152
153        go s.handleConnection(ctx, conn)
154    }
155}
156
157func handleConnection(ctx context.Context, conn net.Conn) {
158    defer func() {
159        conn.Close()
160        s.mu.Lock()
161        delete(s.conns, conn)
162        s.mu.Unlock()
163    }()
164
165    reader := bufio.NewReader(conn)
166    writer := bufio.NewWriter(conn)
167
168    for {
169        select {
170        case <-ctx.Done():
171            return
172        default:
173        }
174
175        conn.SetReadDeadline(time.Now().Add(30 * time.Second))
176        msg, err := DecodeMessage(reader)
177        if err != nil {
178            if err == io.EOF {
179                return
180            }
181            continue
182        }
183
184        // Handle ping
185        if msg.Type == Ping {
186            pong := &Message{
187                Type: Pong,
188                ID:   msg.ID,
189            }
190            pong.Encode(writer)
191            writer.Flush()
192            continue
193        }
194
195        // Handle request
196        s.mu.RLock()
197        handler := s.handler
198        s.mu.RUnlock()
199
200        if handler != nil {
201            response := handler(msg)
202            if response != nil {
203                response.ID = msg.ID
204                response.Encode(writer)
205                writer.Flush()
206            }
207        }
208    }
209}
210
211func Stop() error {
212    if s.listener != nil {
213        s.listener.Close()
214    }
215
216    s.mu.Lock()
217    for conn := range s.conns {
218        conn.Close()
219    }
220    s.conns = make(map[net.Conn]struct{})
221    s.mu.Unlock()
222
223    return nil
224}
225
226type Client struct {
227    conn      net.Conn
228    reader    *bufio.Reader
229    writer    *bufio.Writer
230    mu        sync.Mutex
231    nextID    uint32
232    pending   map[uint32]chan *Message
233    pendingMu sync.RWMutex
234}
235
236func NewClient(addr string) {
237    conn, err := net.Dial("tcp", addr)
238    if err != nil {
239        return nil, err
240    }
241
242    client := &Client{
243        conn:    conn,
244        reader:  bufio.NewReader(conn),
245        writer:  bufio.NewWriter(conn),
246        pending: make(map[uint32]chan *Message),
247    }
248
249    go client.readLoop()
250    return client, nil
251}
252
253func Send(msg *Message) {
254    msg.ID = atomic.AddUint32(&c.nextID, 1)
255
256    responseCh := make(chan *Message, 1)
257    c.pendingMu.Lock()
258    c.pending[msg.ID] = responseCh
259    c.pendingMu.Unlock()
260
261    defer func() {
262        c.pendingMu.Lock()
263        delete(c.pending, msg.ID)
264        c.pendingMu.Unlock()
265    }()
266
267    c.mu.Lock()
268    err := msg.Encode(c.writer)
269    if err == nil {
270        err = c.writer.Flush()
271    }
272    c.mu.Unlock()
273
274    if err != nil {
275        return nil, err
276    }
277
278    select {
279    case response := <-responseCh:
280        return response, nil
281    case <-time.After(5 * time.Second):
282        return nil, errors.New("timeout")
283    }
284}
285
286func readLoop() {
287    for {
288        msg, err := DecodeMessage(c.reader)
289        if err != nil {
290            return
291        }
292
293        c.pendingMu.RLock()
294        ch, exists := c.pending[msg.ID]
295        c.pendingMu.RUnlock()
296
297        if exists {
298            ch <- msg
299        }
300    }
301}
302
303func Ping() error {
304    msg := &Message{
305        Type: Ping,
306    }
307
308    response, err := c.Send(msg)
309    if err != nil {
310        return err
311    }
312
313    if response.Type != Pong {
314        return errors.New("invalid ping response")
315    }
316
317    return nil
318}
319
320func Close() error {
321    return c.conn.Close()
322}
323
324// Connection pool for clients
325type Pool struct {
326    addr    string
327    clients []*Client
328    mu      sync.Mutex
329    maxSize int
330}
331
332func NewPool(addr string, maxSize int) *Pool {
333    return &Pool{
334        addr:    addr,
335        clients: make([]*Client, 0, maxSize),
336        maxSize: maxSize,
337    }
338}
339
340func Get() {
341    p.mu.Lock()
342    if len(p.clients) > 0 {
343        client := p.clients[len(p.clients)-1]
344        p.clients = p.clients[:len(p.clients)-1]
345        p.mu.Unlock()
346        return client, nil
347    }
348    p.mu.Unlock()
349
350    return NewClient(p.addr)
351}
352
353func Put(client *Client) {
354    p.mu.Lock()
355    defer p.mu.Unlock()
356
357    if len(p.clients) < p.maxSize {
358        p.clients = append(p.clients, client)
359    } else {
360        client.Close()
361    }
362}
363
364func Close() {
365    p.mu.Lock()
366    defer p.mu.Unlock()
367
368    for _, client := range p.clients {
369        client.Close()
370    }
371    p.clients = nil
372}

Usage Example

Server:

 1package main
 2
 3import (
 4    "context"
 5    "fmt"
 6    "log"
 7)
 8
 9func main() {
10    server := netproto.NewServer(":8080")
11
12    // Register handler
13    server.HandleFunc(func(msg *netproto.Message) *netproto.Message {
14        fmt.Printf("Received: %s\n", string(msg.Payload))
15
16        return &netproto.Message{
17            Type:    netproto.Response,
18            Payload: []byte("Response: " + string(msg.Payload)),
19        }
20    })
21
22    ctx := context.Background()
23    if err := server.Start(ctx); err != nil {
24        log.Fatal(err)
25    }
26
27    fmt.Println("Server started on :8080")
28    select {} // Block forever
29}

Client:

 1package main
 2
 3import (
 4    "fmt"
 5    "log"
 6)
 7
 8func main() {
 9    client, err := netproto.NewClient("localhost:8080")
10    if err != nil {
11        log.Fatal(err)
12    }
13    defer client.Close()
14
15    // Send request
16    request := &netproto.Message{
17        Type:    netproto.Request,
18        Payload: []byte("Hello, Server!"),
19    }
20
21    response, err := client.Send(request)
22    if err != nil {
23        log.Fatal(err)
24    }
25
26    fmt.Printf("Response: %s\n", string(response.Payload))
27
28    // Ping
29    if err := client.Ping(); err != nil {
30        log.Fatal(err)
31    }
32    fmt.Println("Ping successful")
33}

Connection pool:

 1package main
 2
 3import (
 4    "fmt"
 5    "log"
 6)
 7
 8func main() {
 9    pool := netproto.NewPool("localhost:8080", 10)
10    defer pool.Close()
11
12    // Use connection from pool
13    client, err := pool.Get()
14    if err != nil {
15        log.Fatal(err)
16    }
17
18    response, err := client.Send(&netproto.Message{
19        Type:    netproto.Request,
20        Payload: []byte("Hello from pool!"),
21    })
22    if err != nil {
23        log.Fatal(err)
24    }
25
26    fmt.Printf("Response: %s\n", string(response.Payload))
27
28    // Return to pool
29    pool.Put(client)
30}

Key Takeaways

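  • Binary protocols are typically more compact and cheaper to parse than text protocols
  • Fixed-size headers simplify parsing: the reader always knows how many bytes to read first
  • Length-prefixed framing lets the receiver find message boundaries in a TCP byte stream
  • Request-response correlation uses message IDs
  • Connection pooling avoids the cost of dialing a new connection for every request
  • Timeouts prevent indefinite blocking
  • Ping/pong enables health checking
  • A version byte in every header lets peers detect incompatible messages, which is the basis for evolving the protocol with backward compatibility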