Security Hardening for Go Applications

Exercise Overview

Build a secure web API that implements multiple layers of security protection. You'll defend against common OWASP vulnerabilities including injection attacks, authentication bypass, and data exposure.

Learning Objectives

  • Implement input validation and sanitization
  • Secure password handling and authentication
  • Add per-client rate limiting to throttle brute-force and abusive traffic
  • Implement security headers and CSP
  • Handle secrets securely
  • Prevent common web vulnerabilities

Initial Code

```go
package main

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"

	_ "github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
)

// TODO: Implement secure user model
type User struct {
	ID       int    `json:"id"`
	Username string `json:"username"`
	Email    string `json:"email"`
	Password string `json:"-"` // Never expose password in JSON
	// Add security fields
}

// TODO: Implement input validation
type ValidationRule struct {
	// Define validation rule structure
}

type Validator struct {
	rules map[string][]ValidationRule
}

// TODO: Implement secure password handling
type AuthService struct {
	db *sql.DB
	// Add authentication fields
}

// TODO: Implement rate limiting
type RateLimiter struct {
	// Implement token bucket or sliding window
}

// TODO: Implement security headers middleware
type SecurityHeaders struct {
	// Define security header configurations
}

// TODO: Implement CSRF protection
type CSRFProtection struct {
	// Add CSRF token management
}

// TODO: Implement secure session management
type SessionManager struct {
	// Add secure session handling
}

// API Handlers
type APIHandler struct {
	authService *AuthService
	validator   *Validator
	rateLimiter *RateLimiter
	sessions    *SessionManager
}

// TODO: Implement secure registration endpoint
func (h *APIHandler) Register(w http.ResponseWriter, r *http.Request) {
	// Implement secure user registration
}

// TODO: Implement secure login endpoint
func (h *APIHandler) Login(w http.ResponseWriter, r *http.Request) {
	// Implement secure authentication
}

// TODO: Implement protected profile endpoint
func (h *APIHandler) Profile(w http.ResponseWriter, r *http.Request) {
	// Implement secure profile access
}

// TODO: Implement secure password change
func (h *APIHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
	// Implement secure password update
}

// TODO: Implement secure search endpoint
func (h *APIHandler) Search(w http.ResponseWriter, r *http.Request) {
	// Implement secure search with input validation
}

// TODO: Implement secure file upload
func (h *APIHandler) UploadFile(w http.ResponseWriter, r *http.Request) {
	// Implement secure file upload with validation
}

func main() {
	// Initialize database
	db, err := sql.Open("postgres", "host=localhost user=app dbname=secureapp sslmode=require")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Initialize services
	authService := &AuthService{db: db}
	validator := &Validator{
		rules: make(map[string][]ValidationRule),
	}
	rateLimiter := &RateLimiter{}
	sessions := &SessionManager{}

	handler := &APIHandler{
		authService: authService,
		validator:   validator,
		rateLimiter: rateLimiter,
		sessions:    sessions,
	}

	// TODO: Set up secure middleware chain
	mux := http.NewServeMux()
	mux.HandleFunc("/api/register", handler.Register)
	mux.HandleFunc("/api/login", handler.Login)
	mux.HandleFunc("/api/profile", handler.Profile)
	mux.HandleFunc("/api/change-password", handler.ChangePassword)
	mux.HandleFunc("/api/search", handler.Search)
	mux.HandleFunc("/api/upload", handler.UploadFile)

	fmt.Println("Secure API server starting on :8443")
	// TODO: Configure HTTPS properly
	log.Fatal(http.ListenAndServeTLS(":8443", "server.crt", "server.key", mux))
}
```

Tasks

Task 1: Input Validation System

Implement comprehensive input validation:

```go
type ValidationRule struct {
	Name        string
	Required    bool
	MinLength   int
	MaxLength   int
	Pattern     *regexp.Regexp
	CustomCheck func(interface{}) error
}

type ValidationError struct {
	Field   string `json:"field"`
	Message string `json:"message"`
}

// Validate checks data against the rules registered for targetType.
func (v *Validator) Validate(data interface{}, targetType string) []ValidationError {
	// Implement validation logic
	return nil
}

// Built-in validators
func ValidateEmail(email string) bool {
	emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
	return emailRegex.MatchString(email)
}

func ValidatePassword(password string) []string {
	var errors []string

	if len(password) < 8 {
		errors = append(errors, "Password must be at least 8 characters")
	}

	if !regexp.MustCompile(`[A-Z]`).MatchString(password) {
		errors = append(errors, "Password must contain uppercase letter")
	}

	if !regexp.MustCompile(`[a-z]`).MatchString(password) {
		errors = append(errors, "Password must contain lowercase letter")
	}

	if !regexp.MustCompile(`[0-9]`).MatchString(password) {
		errors = append(errors, "Password must contain number")
	}

	if !regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password) {
		errors = append(errors, "Password must contain special character")
	}

	return errors
}
```

Task 2: Secure Authentication System

Implement secure authentication with bcrypt:

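```go
type AuthService struct {
	db          *sql.DB
	tokenSecret []byte
}

// HashPassword hashes a plaintext password with bcrypt at the default cost.
func (a *AuthService) HashPassword(password string) (string, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(hash), nil
}

// VerifyPassword reports whether password matches the stored bcrypt hash.
func (a *AuthService) VerifyPassword(hashedPassword, password string) bool {
	err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
	return err == nil
}

func (a *AuthService) GenerateJWT(userID int, username string) (string, error) {
	// Implement JWT generation with proper claims (subject, issued-at, expiry)
	return "", nil
}

// The claims return type here is a placeholder; use whatever claims type
// your JWT implementation defines.
func (a *AuthService) ValidateJWT(tokenString string) (map[string]interface{}, error) {
	// Implement JWT validation and return the verified claims
	return nil, nil
}
```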

Task 3: Rate Limiting Implementation

Implement token bucket rate limiting:

```go
// TokenBucket holds the state for one client key. Tokens are fractional so
// that slow refill rates (e.g. 10 per minute) still accumulate between requests.
type TokenBucket struct {
	tokens     float64
	maxTokens  float64
	refillRate float64 // tokens added per second
	lastRefill time.Time
	mutex      sync.Mutex
}

type RateLimiter struct {
	buckets map[string]*TokenBucket
	mutex   sync.RWMutex
	limit   int
	window  time.Duration
}

func (rl *RateLimiter) Allow(key string) bool {
	// Hold the map lock only while looking up or creating the bucket.
	rl.mutex.Lock()
	bucket, exists := rl.buckets[key]
	if !exists {
		bucket = &TokenBucket{
			maxTokens:  float64(rl.limit),
			tokens:     float64(rl.limit),
			refillRate: float64(rl.limit) / rl.window.Seconds(),
			lastRefill: time.Now(),
		}
		rl.buckets[key] = bucket
	}
	rl.mutex.Unlock()

	bucket.mutex.Lock()
	defer bucket.mutex.Unlock()

	// Refill in proportion to the time elapsed. Truncating to whole seconds
	// and then resetting lastRefill would lose the fraction on every call,
	// so clients requesting more than once per second would never refill.
	now := time.Now()
	elapsed := now.Sub(bucket.lastRefill).Seconds()
	bucket.tokens += elapsed * bucket.refillRate
	if bucket.tokens > bucket.maxTokens {
		bucket.tokens = bucket.maxTokens
	}
	bucket.lastRefill = now

	if bucket.tokens >= 1 {
		bucket.tokens--
		return true
	}

	return false
}
```

Task 4: Security Headers Middleware

Implement comprehensive security headers:

```go
type SecurityConfig struct {
	MaxAge            int
	IncludeSubDomains bool
	ReportURI         string
	ReportOnly        bool
}

func SecurityHeaders(config SecurityConfig) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Content Security Policy. 'unsafe-inline' is left out of script-src
			// because it would let injected inline scripts run.
			csp := "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'"
			w.Header().Set("Content-Security-Policy", csp)

			// X-Frame-Options
			w.Header().Set("X-Frame-Options", "DENY")

			// X-Content-Type-Options
			w.Header().Set("X-Content-Type-Options", "nosniff")

			// X-XSS-Protection: modern browsers have removed the XSS auditor,
			// and OWASP recommends disabling it explicitly with "0".
			w.Header().Set("X-XSS-Protection", "0")

			// Strict-Transport-Security
			hstsValue := fmt.Sprintf("max-age=%d", config.MaxAge)
			if config.IncludeSubDomains {
				hstsValue += "; includeSubDomains"
			}
			w.Header().Set("Strict-Transport-Security", hstsValue)

			// Referrer Policy
			w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")

			// Permissions Policy
			w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")

			next.ServeHTTP(w, r)
		})
	}
}
```

Task 5: CSRF Protection

Implement CSRF token validation:

```go
type CSRFProtection struct {
	secret []byte
}

func (c *CSRFProtection) GenerateToken(userID int) (string, error) {
	// Generate secure random token
	token := make([]byte, 32)
	if _, err := rand.Read(token); err != nil {
		return "", err
	}

	// Combine with user ID and timestamp
	data := fmt.Sprintf("%d:%d:%x", userID, time.Now().Unix(), token)

	// Sign the token
	signature := c.sign(data)
	return fmt.Sprintf("%s.%s", data, signature), nil
}

func (c *CSRFProtection) ValidateToken(token string, userID int) bool {
	// Implement token validation logic
	return false // fail closed until validation is implemented
}

func (c *CSRFProtection) sign(data string) string {
	// Implement HMAC signing
	return ""
}
```

Solution Approach


Complete Implementation:

```go
package main

import (
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	_ "github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
)

type User struct {
	ID        int        `json:"id"`
	Username  string     `json:"username"`
	Email     string     `json:"email"`
	Password  string     `json:"-"`
	CreatedAt time.Time  `json:"created_at"`
	UpdatedAt time.Time  `json:"updated_at"`
	LastLogin *time.Time `json:"last_login,omitempty"`
}

type ValidationRule struct {
	Name        string
	Required    bool
	MinLength   int
	MaxLength   int
	Pattern     *regexp.Regexp
	CustomCheck func(interface{}) error
}

type ValidationError struct {
	Field   string `json:"field"`
	Message string `json:"message"`
}

type Validator struct {
	rules map[string][]ValidationRule
}

func NewValidator() *Validator {
	return &Validator{
		rules: make(map[string][]ValidationRule),
	}
}

func (v *Validator) AddRule(targetType string, rule ValidationRule) {
	v.rules[targetType] = append(v.rules[targetType], rule)
}

func (v *Validator) Validate(data interface{}, targetType string) []ValidationError {
	var errors []ValidationError

	// Convert data to map for validation
	dataMap, ok := data.(map[string]interface{})
	if !ok {
		errors = append(errors, ValidationError{
			Field:   "root",
			Message: "Invalid data format",
		})
		return errors
	}

	rules, exists := v.rules[targetType]
	if !exists {
		return errors
	}

	for _, rule := range rules {
		value, fieldExists := dataMap[rule.Name]

		if rule.Required && !fieldExists {
			errors = append(errors, ValidationError{
				Field:   rule.Name,
				Message: fmt.Sprintf("%s is required", rule.Name),
			})
			continue
		}

		if !fieldExists {
			continue
		}

		strValue := fmt.Sprintf("%v", value)

		if rule.MinLength > 0 && len(strValue) < rule.MinLength {
			errors = append(errors, ValidationError{
				Field:   rule.Name,
				Message: fmt.Sprintf("%s must be at least %d characters", rule.Name, rule.MinLength),
			})
		}

		if rule.MaxLength > 0 && len(strValue) > rule.MaxLength {
			errors = append(errors, ValidationError{
				Field:   rule.Name,
				Message: fmt.Sprintf("%s must be no more than %d characters", rule.Name, rule.MaxLength),
			})
		}

		if rule.Pattern != nil && !rule.Pattern.MatchString(strValue) {
			errors = append(errors, ValidationError{
				Field:   rule.Name,
				Message: fmt.Sprintf("%s format is invalid", rule.Name),
			})
		}

		if rule.CustomCheck != nil {
			if err := rule.CustomCheck(value); err != nil {
				errors = append(errors, ValidationError{
					Field:   rule.Name,
					Message: err.Error(),
				})
			}
		}
	}

	return errors
}

// Compile patterns once at startup rather than on every call.
var (
	emailRegex   = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
	upperRegex   = regexp.MustCompile(`[A-Z]`)
	lowerRegex   = regexp.MustCompile(`[a-z]`)
	digitRegex   = regexp.MustCompile(`[0-9]`)
	specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)

func ValidateEmail(email string) bool {
	return len(email) <= 254 && emailRegex.MatchString(email)
}

func ValidatePassword(password string) []string {
	var errors []string

	if len(password) < 8 {
		errors = append(errors, "Password must be at least 8 characters")
	}

	// bcrypt only uses the first 72 bytes, and recent versions of
	// golang.org/x/crypto/bcrypt return an error for longer inputs.
	if len(password) > 72 {
		errors = append(errors, "Password must be no more than 72 bytes")
	}

	if !upperRegex.MatchString(password) {
		errors = append(errors, "Password must contain uppercase letter")
	}

	if !lowerRegex.MatchString(password) {
		errors = append(errors, "Password must contain lowercase letter")
	}

	if !digitRegex.MatchString(password) {
		errors = append(errors, "Password must contain number")
	}

	if !specialRegex.MatchString(password) {
		errors = append(errors, "Password must contain special character")
	}

	// Check for common patterns
	lower := strings.ToLower(password)
	if lower == "password" || lower == "12345678" || strings.Contains(lower, "qwerty") {
		errors = append(errors, "Password is too common")
	}

	return errors
}

// TokenBucket uses fractional tokens so slow refill rates still accumulate.
type TokenBucket struct {
	tokens     float64
	maxTokens  float64
	refillRate float64 // tokens added per second
	lastRefill time.Time
	mutex      sync.Mutex
}

type RateLimiter struct {
	buckets map[string]*TokenBucket
	mutex   sync.RWMutex
	limit   int
	window  time.Duration
}

func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	return &RateLimiter{
		buckets: make(map[string]*TokenBucket),
		limit:   limit,
		window:  window,
	}
}

func (rl *RateLimiter) Allow(key string) bool {
	rl.mutex.Lock()
	bucket, exists := rl.buckets[key]
	if !exists {
		bucket = &TokenBucket{
			maxTokens: float64(rl.limit),
			tokens:    float64(rl.limit),
			// 10 per minute is about 0.17 tokens per second; integer division
			// would round this down to 0 and the bucket would never refill.
			refillRate: float64(rl.limit) / rl.window.Seconds(),
			lastRefill: time.Now(),
		}
		rl.buckets[key] = bucket
	}
	rl.mutex.Unlock()

	bucket.mutex.Lock()
	defer bucket.mutex.Unlock()

	// Refill in proportion to elapsed time, keeping fractional tokens.
	now := time.Now()
	elapsed := now.Sub(bucket.lastRefill).Seconds()
	bucket.tokens += elapsed * bucket.refillRate
	if bucket.tokens > bucket.maxTokens {
		bucket.tokens = bucket.maxTokens
	}
	bucket.lastRefill = now

	if bucket.tokens >= 1 {
		bucket.tokens--
		return true
	}

	return false
}

type SecurityConfig struct {
	MaxAge            int
	IncludeSubDomains bool
	ReportURI         string
	ReportOnly        bool
}

func SecurityHeaders(config SecurityConfig) func(http.Handler) http.Handler {
	// Content Security Policy. No 'unsafe-inline' in script-src, since it would
	// let injected inline scripts run.
	csp := "default-src 'self'; " +
		"script-src 'self'; " +
		"style-src 'self' 'unsafe-inline'; " +
		"img-src 'self' data: https:; " +
		"font-src 'self'; " +
		"connect-src 'self'; " +
		"frame-ancestors 'none'; " +
		"base-uri 'self'; " +
		"form-action 'self'"
	if config.ReportURI != "" {
		csp += "; report-uri " + config.ReportURI
	}
	cspHeader := "Content-Security-Policy"
	if config.ReportOnly {
		cspHeader = "Content-Security-Policy-Report-Only"
	}

	// Strict-Transport-Security. "preload" is only accepted together with
	// includeSubDomains and a max-age of at least one year.
	hstsValue := fmt.Sprintf("max-age=%d", config.MaxAge)
	if config.IncludeSubDomains {
		hstsValue += "; includeSubDomains"
		if config.MaxAge >= 31536000 {
			hstsValue += "; preload"
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set(cspHeader, csp)

			// X-Frame-Options
			w.Header().Set("X-Frame-Options", "DENY")

			// X-Content-Type-Options
			w.Header().Set("X-Content-Type-Options", "nosniff")

			// X-XSS-Protection: the legacy XSS auditor is gone from modern
			// browsers; OWASP recommends disabling it with "0".
			w.Header().Set("X-XSS-Protection", "0")

			w.Header().Set("Strict-Transport-Security", hstsValue)

			// Referrer Policy
			w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")

			// Permissions Policy
			w.Header().Set("Permissions-Policy",
				"camera=(), microphone=(), geolocation=(), payment=()")

			// net/http sends no Server header by default, so none is removed here.

			next.ServeHTTP(w, r)
		})
	}
}

type CSRFProtection struct {
	secret []byte
	tokens map[string]CSRFToken
	mutex  sync.RWMutex
}

type CSRFToken struct {
	Token     string
	ExpiresAt time.Time
	UserID    int
}

func NewCSRFProtection(secret string) *CSRFProtection {
	return &CSRFProtection{
		secret: []byte(secret),
		tokens: make(map[string]CSRFToken),
	}
}

func (c *CSRFProtection) GenerateToken(userID int) (string, error) {
	token := make([]byte, 32)
	if _, err := rand.Read(token); err != nil {
		return "", err
	}

	tokenStr := base64.URLEncoding.EncodeToString(token)
	timestamp := time.Now().Unix()
	data := fmt.Sprintf("%d:%d:%s", userID, timestamp, tokenStr)

	signature := c.sign(data)
	fullToken := fmt.Sprintf("%s.%s", data, signature)

	c.mutex.Lock()
	defer c.mutex.Unlock()

	c.tokens[fullToken] = CSRFToken{
		Token:     fullToken,
		ExpiresAt: time.Now().Add(24 * time.Hour),
		UserID:    userID,
	}

	return fullToken, nil
}

func (c *CSRFProtection) ValidateToken(token string, userID int) bool {
	c.mutex.RLock()
	storedToken, exists := c.tokens[token]
	c.mutex.RUnlock()

	if !exists || storedToken.UserID != userID {
		return false
	}

	if time.Now().After(storedToken.ExpiresAt) {
		c.mutex.Lock()
		delete(c.tokens, token)
		c.mutex.Unlock()
		return false
	}

	parts := strings.Split(token, ".")
	if len(parts) != 2 {
		return false
	}

	// Constant-time comparison of the encoded signatures.
	expectedSignature := c.sign(parts[0])
	return hmac.Equal([]byte(parts[1]), []byte(expectedSignature))
}

// sign returns the base64url-encoded HMAC-SHA256 of data. The encoding keeps
// the signature free of '.', so the token can be split safely.
func (c *CSRFProtection) sign(data string) string {
	h := hmac.New(sha256.New, c.secret)
	h.Write([]byte(data))
	return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
}

func (c *CSRFProtection) CleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	now := time.Now()
	for token, csrfToken := range c.tokens {
		if now.After(csrfToken.ExpiresAt) {
			delete(c.tokens, token)
		}
	}
}

type AuthService struct {
	db        *sql.DB
	jwtSecret []byte
	csrf      *CSRFProtection
}

func NewAuthService(db *sql.DB, jwtSecret, csrfSecret string) *AuthService {
	return &AuthService{
		db:        db,
		jwtSecret: []byte(jwtSecret),
		csrf:      NewCSRFProtection(csrfSecret),
	}
}

func (a *AuthService) HashPassword(password string) (string, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(hash), nil
}

func (a *AuthService) VerifyPassword(hashedPassword, password string) bool {
	err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
	return err == nil
}

func (a *AuthService) CreateUser(username, email, password string) (*User, error) {
	// Validate input
	if !ValidateEmail(email) {
		return nil, fmt.Errorf("invalid email format")
	}

	passwordErrors := ValidatePassword(password)
	if len(passwordErrors) > 0 {
		return nil, fmt.Errorf("%s", strings.Join(passwordErrors, "; "))
	}

	// Hash password
	hashedPassword, err := a.HashPassword(password)
	if err != nil {
		return nil, err
	}

	// Insert user. Placeholders ($1, $2, $3) keep user input out of the SQL text.
	var userID int
	err = a.db.QueryRow(
		"INSERT INTO users (username, email, password_hash, created_at, updated_at) VALUES ($1, $2, $3, NOW(), NOW()) RETURNING id",
		username, email, hashedPassword,
	).Scan(&userID)

	if err != nil {
		return nil, err
	}

	return &User{
		ID:        userID,
		Username:  username,
		Email:     email,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}, nil
}

func (a *AuthService) AuthenticateUser(username, password string) (*User, error) {
	var user User
	var passwordHash string

	err := a.db.QueryRow(
		"SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = $1 OR email = $1",
		username,
	).Scan(&user.ID, &user.Username, &user.Email, &passwordHash, &user.CreatedAt, &user.UpdatedAt)

	if err != nil {
		return nil, err
	}

	if !a.VerifyPassword(passwordHash, password) {
		return nil, fmt.Errorf("invalid credentials")
	}

	// Update last login
	_, err = a.db.Exec("UPDATE users SET last_login = NOW() WHERE id = $1", user.ID)
	if err != nil {
		log.Printf("Failed to update last login: %v", err)
	}

	now := time.Now()
	user.LastLogin = &now
	return &user, nil
}

type SessionManager struct {
	sessions map[string]SessionData
	mutex    sync.RWMutex
}

type SessionData struct {
	UserID    int
	ExpiresAt time.Time
	CSRFToken string
}

func NewSessionManager() *SessionManager {
	return &SessionManager{
		sessions: make(map[string]SessionData),
	}
}

func (sm *SessionManager) CreateSession(userID int, csrfToken string) string {
	sessionID := make([]byte, 32)
	rand.Read(sessionID)
	sessionIDStr := base64.URLEncoding.EncodeToString(sessionID)

	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	sm.sessions[sessionIDStr] = SessionData{
		UserID:    userID,
		ExpiresAt: time.Now().Add(24 * time.Hour),
		CSRFToken: csrfToken,
	}

	return sessionIDStr
}

func (sm *SessionManager) ValidateSession(sessionID string) (*SessionData, bool) {
	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	session, exists := sm.sessions[sessionID]
	if !exists || time.Now().After(session.ExpiresAt) {
		return nil, false
	}

	return &session, true
}

// contextKey is unexported so these keys cannot collide with context values
// set by other packages.
type contextKey string

const (
	userIDKey    contextKey = "userID"
	csrfTokenKey contextKey = "csrfToken"
)

type APIHandler struct {
	authService *AuthService
	validator   *Validator
	rateLimiter *RateLimiter
	sessions    *SessionManager
}

func NewAPIHandler(authService *AuthService) *APIHandler {
	validator := NewValidator()

	// Add validation rules for registration
	validator.AddRule("register", ValidationRule{
		Name:      "username",
		Required:  true,
		MinLength: 3,
		MaxLength: 50,
		Pattern:   regexp.MustCompile(`^[a-zA-Z0-9_]+$`),
	})

	validator.AddRule("register", ValidationRule{
		Name:     "email",
		Required: true,
		CustomCheck: func(value interface{}) error {
			email, _ := value.(string)
			if !ValidateEmail(email) {
				return fmt.Errorf("invalid email format")
			}
			return nil
		},
	})

	validator.AddRule("register", ValidationRule{
		Name:     "password",
		Required: true,
		CustomCheck: func(value interface{}) error {
			password, _ := value.(string)
			errors := ValidatePassword(password)
			if len(errors) > 0 {
				return fmt.Errorf("%s", strings.Join(errors, "; "))
			}
			return nil
		},
	})

	return &APIHandler{
		authService: authService,
		validator:   validator,
		rateLimiter: NewRateLimiter(10, time.Minute), // 10 requests per minute
		sessions:    NewSessionManager(),
	}
}

// getClientIP returns the IP of the directly connected peer. X-Forwarded-For
// and X-Real-IP are set by the client unless a trusted reverse proxy
// overwrites them; trusting them blindly would let an attacker pick a new
// rate-limit key on every request.
func (h *APIHandler) getClientIP(r *http.Request) string {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}

func (h *APIHandler) writeJSONError(w http.ResponseWriter, status int, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(map[string]string{"error": message})
}

func (h *APIHandler) Register(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clientIP := h.getClientIP(r)
	if !h.rateLimiter.Allow("register:" + clientIP) {
		h.writeJSONError(w, http.StatusTooManyRequests, "Rate limit exceeded")
		return
	}

	var req struct {
		Username string `json:"username"`
		Email    string `json:"email"`
		Password string `json:"password"`
	}

	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.writeJSONError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	// Validate input
	errors := h.validator.Validate(map[string]interface{}{
		"username": req.Username,
		"email":    req.Email,
		"password": req.Password,
	}, "register")

	if len(errors) > 0 {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"errors": errors,
		})
		return
	}

	// Create user
	user, err := h.authService.CreateUser(req.Username, req.Email, req.Password)
	if err != nil {
		h.writeJSONError(w, http.StatusInternalServerError, "Failed to create user")
		return
	}

	// Generate CSRF token
	csrfToken, err := h.authService.csrf.GenerateToken(user.ID)
	if err != nil {
		h.writeJSONError(w, http.StatusInternalServerError, "Failed to generate CSRF token")
		return
	}

	// Create session
	sessionID := h.sessions.CreateSession(user.ID, csrfToken)

	// Set secure cookie
	http.SetCookie(w, &http.Cookie{
		Name:     "session_id",
		Value:    sessionID,
		Path:     "/",
		MaxAge:   86400, // 24 hours
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
	})

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"user":       user,
		"csrf_token": csrfToken,
	})
}

func (h *APIHandler) Login(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clientIP := h.getClientIP(r)
	if !h.rateLimiter.Allow("login:" + clientIP) {
		h.writeJSONError(w, http.StatusTooManyRequests, "Rate limit exceeded")
		return
	}

	var req struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}

	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.writeJSONError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	// Authenticate user; every failure gets the same message
	user, err := h.authService.AuthenticateUser(req.Username, req.Password)
	if err != nil {
		h.writeJSONError(w, http.StatusUnauthorized, "Invalid credentials")
		return
	}

	// Generate CSRF token
	csrfToken, err := h.authService.csrf.GenerateToken(user.ID)
	if err != nil {
		h.writeJSONError(w, http.StatusInternalServerError, "Failed to generate CSRF token")
		return
	}

	// Create session
	sessionID := h.sessions.CreateSession(user.ID, csrfToken)

	// Set secure cookie
	http.SetCookie(w, &http.Cookie{
		Name:     "session_id",
		Value:    sessionID,
		Path:     "/",
		MaxAge:   86400, // 24 hours
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
	})

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"user":       user,
		"csrf_token": csrfToken,
	})
}

func (h *APIHandler) requireAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("session_id")
		if err != nil {
			h.writeJSONError(w, http.StatusUnauthorized, "Authentication required")
			return
		}

		session, valid := h.sessions.ValidateSession(cookie.Value)
		if !valid {
			// Clear the stale cookie
			http.SetCookie(w, &http.Cookie{
				Name:     "session_id",
				Value:    "",
				Path:     "/",
				MaxAge:   -1,
				Secure:   true,
				HttpOnly: true,
				SameSite: http.SameSiteStrictMode,
			})
			h.writeJSONError(w, http.StatusUnauthorized, "Invalid session")
			return
		}

		// Add user context under typed keys
		ctx := context.WithValue(r.Context(), userIDKey, session.UserID)
		ctx = context.WithValue(ctx, csrfTokenKey, session.CSRFToken)

		next.ServeHTTP(w, r.WithContext(ctx))
	}
}

func (h *APIHandler) requireCSRF(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Safe methods do not change state, so they skip the CSRF check
		if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}

		csrfToken := r.Header.Get("X-CSRF-Token")
		if csrfToken == "" {
			h.writeJSONError(w, http.StatusForbidden, "CSRF token required")
			return
		}

		userID, ok := r.Context().Value(userIDKey).(int)
		if !ok {
			h.writeJSONError(w, http.StatusUnauthorized, "Authentication required")
			return
		}
		if !h.authService.csrf.ValidateToken(csrfToken, userID) {
			h.writeJSONError(w, http.StatusForbidden, "Invalid CSRF token")
			return
		}

		next.ServeHTTP(w, r)
	}
}

func (h *APIHandler) Profile(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	userID, ok := r.Context().Value(userIDKey).(int)
	if !ok {
		h.writeJSONError(w, http.StatusUnauthorized, "Authentication required")
		return
	}

	var user User
	err := h.authService.db.QueryRow(
		"SELECT id, username, email, created_at, updated_at, last_login FROM users WHERE id = $1",
		userID,
	).Scan(&user.ID, &user.Username, &user.Email, &user.CreatedAt, &user.UpdatedAt, &user.LastLogin)

	if err != nil {
		h.writeJSONError(w, http.StatusInternalServerError, "Failed to fetch user profile")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(user)
}

func main() {
	// Read signing secrets from the environment instead of hard-coding them
	// (the variable names JWT_SECRET and CSRF_SECRET are illustrative).
	jwtSecret := os.Getenv("JWT_SECRET")
	csrfSecret := os.Getenv("CSRF_SECRET")
	if jwtSecret == "" || csrfSecret == "" {
		log.Fatal("JWT_SECRET and CSRF_SECRET must be set")
	}

	// Initialize database
	db, err := sql.Open("postgres", "host=localhost user=app dbname=secureapp sslmode=require")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Initialize services
	authService := NewAuthService(db, jwtSecret, csrfSecret)
	handler := NewAPIHandler(authService)

	// Set up middleware chain
	securityConfig := SecurityConfig{
		MaxAge:            31536000, // 1 year
		IncludeSubDomains: true,
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/api/register", handler.Register)
	mux.HandleFunc("/api/login", handler.Login)
	mux.HandleFunc("/api/profile", handler.requireAuth(handler.requireCSRF(handler.Profile)))

	// Wrap with security middleware
	handlerChain := SecurityHeaders(securityConfig)(mux)

	fmt.Println("Secure API server starting on :8443")
	log.Fatal(http.ListenAndServeTLS(":8443", "server.crt", "server.key", handlerChain))
}
```

Testing Your Solution

Test your security implementation:

```bash
# Test registration with valid data
curl -X POST https://localhost:8443/api/register \
  -H "Content-Type: application/json" \
  -d '{"username":"testuser","email":"test@example.com","password":"SecurePass123!"}' \
  -k

# Test registration with weak password
curl -X POST https://localhost:8443/api/register \
  -H "Content-Type: application/json" \
  -d '{"username":"testuser2","email":"test2@example.com","password":"password"}' \
  -k

# Test rate limiting (with 10 requests per minute, the last 5 should get 429)
for i in {1..15}; do
  curl -X POST https://localhost:8443/api/login \
    -H "Content-Type: application/json" \
    -d '{"username":"testuser","password":"wrongpassword"}' \
    -k
done

# Test security headers
curl -I https://localhost:8443/api/profile \
  -H "Cookie: session_id=valid-session-id" \
  -k
```

Verify that:

  1. Input validation prevents weak passwords and invalid emails
  2. Rate limiting blocks excessive requests
  3. Security headers are properly set
  4. CSRF protection blocks state-changing requests without tokens
  5. Authentication is required for protected endpoints
  6. Sessions expire properly

Extension Challenges

  1. Add 2FA support - Implement TOTP-based two-factor authentication
  2. Implement IP allowlisting - Restrict access based on IP addresses
  3. Add audit logging - Log all security-relevant events
  4. Implement account lockout - Lock accounts after failed attempts
  5. Add API key authentication - Support for API-based access

Key Takeaways

  • Defense in depth - Multiple security layers provide better protection
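  • Input validation is a first line of defense, but parameterized queries ($1 placeholders) are what prevent SQL injection
  • Rate limiting slows credential brute-forcing and other abuse from individual clients; it does not replace network-level DDoS protection
  • Security headers protect against client-side attacks such as clickjacking, MIME sniffing, and injected scripts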
  • CSRF protection prevents state-changing requests from other sites
  • Secure session management prevents session hijacking

This exercise demonstrates how to implement comprehensive security measures in Go applications, protecting against common OWASP vulnerabilities while maintaining good user experience.