Security Hardening for Go Applications
Exercise Overview
Build a secure web API that implements multiple layers of security protection. You'll defend against common OWASP vulnerabilities including injection attacks, authentication bypass, and data exposure.
Learning Objectives
- Implement input validation and sanitization
- Secure password handling and authentication
- Add rate limiting to throttle brute-force attempts and request floods
- Implement security headers and CSP
- Handle secrets securely
- Prevent common web vulnerabilities
Initial Code
1package main
2
3import (
4 "context"
5 "crypto/rand"
6 "database/sql"
7 "encoding/json"
8 "fmt"
9 "log"
10 "net/http"
11 "regexp"
12 "strings"
13 "sync"
14 "time"
15
16 "golang.org/x/crypto/bcrypt"
17 _ "github.com/lib/pq"
18)
19
// User is the account record exposed by the API.
// TODO: Implement secure user model (add security-related fields).
type User struct {
	ID       int    `json:"id"`
	Username string `json:"username"`
	Email    string `json:"email"`
	// Password is tagged "-" so encoding/json never serializes it.
	Password string `json:"-"`
}

// ValidationRule describes one constraint on an input field.
// TODO: Implement input validation (define the rule structure).
type ValidationRule struct {
}

// Validator holds validation rules grouped by target type.
type Validator struct {
	rules map[string][]ValidationRule
}

// AuthService performs password handling and authentication.
// TODO: Implement secure password handling (add authentication fields).
type AuthService struct {
	db *sql.DB
}

// RateLimiter throttles requests per client key.
// TODO: Implement rate limiting (token bucket or sliding window).
type RateLimiter struct {
}

// SecurityHeaders holds the security header configuration.
// TODO: Implement security headers middleware.
type SecurityHeaders struct {
}

// CSRFProtection manages CSRF tokens.
// TODO: Implement CSRF protection (token management).
type CSRFProtection struct {
}

// SessionManager tracks authenticated sessions.
// TODO: Implement secure session management.
type SessionManager struct {
}

// APIHandler bundles the services the HTTP handlers depend on.
type APIHandler struct {
	authService *AuthService
	validator   *Validator
	rateLimiter *RateLimiter
	sessions    *SessionManager
}
71
72// TODO: Implement secure registration endpoint
73func Register(w http.ResponseWriter, r *http.Request) {
74 // Implement secure user registration
75}
76
77// TODO: Implement secure login endpoint
78func Login(w http.ResponseWriter, r *http.Request) {
79 // Implement secure authentication
80}
81
82// TODO: Implement protected profile endpoint
83func Profile(w http.ResponseWriter, r *http.Request) {
84 // Implement secure profile access
85}
86
87// TODO: Implement secure password change
88func ChangePassword(w http.ResponseWriter, r *http.Request) {
89 // Implement secure password update
90}
91
92// TODO: Implement secure search endpoint
93func Search(w http.ResponseWriter, r *http.Request) {
94 // Implement secure search with input validation
95}
96
97// TODO: Implement secure file upload
98func UploadFile(w http.ResponseWriter, r *http.Request) {
99 // Implement secure file upload with validation
100}
101
102func main() {
103 // Initialize database
104 db, err := sql.Open("postgres", "host=localhost user=app dbname=secureapp sslmode=require")
105 if err != nil {
106 log.Fatal(err)
107 }
108 defer db.Close()
109
110 // Initialize services
111 authService := &AuthService{db: db}
112 validator := &Validator{
113 rules: make(map[string][]ValidationRule),
114 }
115 rateLimiter := &RateLimiter{}
116 sessions := &SessionManager{}
117
118 handler := &APIHandler{
119 authService: authService,
120 validator: validator,
121 rateLimiter: rateLimiter,
122 sessions: sessions,
123 }
124
125 // TODO: Set up secure middleware chain
126 mux := http.NewServeMux()
127 mux.HandleFunc("/api/register", handler.Register)
128 mux.HandleFunc("/api/login", handler.Login)
129 mux.HandleFunc("/api/profile", handler.Profile)
130 mux.HandleFunc("/api/change-password", handler.ChangePassword)
131 mux.HandleFunc("/api/search", handler.Search)
132 mux.HandleFunc("/api/upload", handler.UploadFile)
133
134 fmt.Println("Secure API server starting on :8443")
135 // TODO: Configure HTTPS properly
136 log.Fatal(http.ListenAndServeTLS(":8443", "server.crt", "server.key", mux))
137}
Tasks
Task 1: Input Validation System
Implement comprehensive input validation:
1type ValidationRule struct {
2 Name string
3 Required bool
4 MinLength int
5 MaxLength int
6 Pattern *regexp.Regexp
7 CustomCheck func(interface{}) error
8}
9
10type ValidationError struct {
11 Field string `json:"field"`
12 Message string `json:"message"`
13}
14
15func Validate(data interface{}, targetType string) []ValidationError {
16 // Implement validation logic
17 return nil
18}
19
// Built-in validators

// emailAddressPattern is compiled once at package load; compiling it inside
// ValidateEmail would repeat that work on every request.
var emailAddressPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)

// ValidateEmail reports whether email is at most 254 bytes long and matches a
// pragmatic address pattern (not a full RFC 5322 parser).
func ValidateEmail(email string) bool {
	// The cheap length bound runs first so oversized input never reaches the regex.
	return len(email) <= 254 && emailAddressPattern.MatchString(email)
}
25
// Character-class patterns used by ValidatePassword, compiled once at package
// load instead of on every call.
var (
	passwordHasUpper   = regexp.MustCompile(`[A-Z]`)
	passwordHasLower   = regexp.MustCompile(`[a-z]`)
	passwordHasDigit   = regexp.MustCompile(`[0-9]`)
	passwordHasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)

// ValidatePassword returns one message per complexity rule that password
// fails. A nil result means the password is acceptable.
func ValidatePassword(password string) []string {
	var errors []string

	if len(password) < 8 {
		errors = append(errors, "Password must be at least 8 characters")
	}

	if !passwordHasUpper.MatchString(password) {
		errors = append(errors, "Password must contain uppercase letter")
	}

	if !passwordHasLower.MatchString(password) {
		errors = append(errors, "Password must contain lowercase letter")
	}

	if !passwordHasDigit.MatchString(password) {
		errors = append(errors, "Password must contain number")
	}

	if !passwordHasSpecial.MatchString(password) {
		errors = append(errors, "Password must contain special character")
	}

	return errors
}
Task 2: Secure Authentication System
Implement secure authentication with bcrypt:
1type AuthService struct {
2 db *sql.DB
3 tokenSecret []byte
4}
5
6func HashPassword(password string) {
7 hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
8 if err != nil {
9 return "", err
10 }
11 return string(hash), nil
12}
13
14func VerifyPassword(hashedPassword, password string) bool {
15 err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
16 return err == nil
17}
18
19func GenerateJWT(userID int, username string) {
20 // Implement JWT generation with proper claims
21 return "", nil
22}
23
24func ValidateJWT(tokenString string) {
25 // Implement JWT validation
26 return nil, nil
27}
Task 3: Rate Limiting Implementation
Implement token bucket rate limiting:
// TokenBucket holds the remaining request allowance for one key.
type TokenBucket struct {
	tokens     int       // requests currently available
	maxTokens  int       // capacity, i.e. the largest permitted burst
	refillRate int       // tokens earned per second
	lastRefill time.Time // instant up to which refills have been credited
	mutex      sync.Mutex
}

// RateLimiter keeps one TokenBucket per key (for example "login:<ip>").
type RateLimiter struct {
	buckets map[string]*TokenBucket
	mutex   sync.RWMutex
	limit   int
	window  time.Duration
}

// Allow reports whether a request for key may proceed, consuming one token if
// so. A new key starts with a full bucket of rl.limit tokens.
func (rl *RateLimiter) Allow(key string) bool {
	rl.mutex.Lock()
	defer rl.mutex.Unlock()

	// Create the map lazily so a zero-value RateLimiter does not panic.
	if rl.buckets == nil {
		rl.buckets = make(map[string]*TokenBucket)
	}

	now := time.Now()
	bucket, exists := rl.buckets[key]
	if !exists {
		bucket = &TokenBucket{
			maxTokens:  rl.limit,
			tokens:     rl.limit,
			refillRate: 1,
			lastRefill: now,
		}
		rl.buckets[key] = bucket
	}

	bucket.mutex.Lock()
	defer bucket.mutex.Unlock()

	// Refill based on elapsed time. Only whole tokens are credited, and
	// lastRefill advances by exactly the time those tokens represent, so
	// partial progress carries over to the next call. Resetting lastRefill to
	// now on every call would discard it: a client calling more often than
	// once per refill period would never earn a token back.
	if bucket.refillRate > 0 && bucket.tokens < bucket.maxTokens {
		perToken := time.Second / time.Duration(bucket.refillRate)
		if perToken <= 0 {
			perToken = time.Nanosecond
		}
		earned := now.Sub(bucket.lastRefill) / perToken
		if missing := time.Duration(bucket.maxTokens - bucket.tokens); earned >= missing {
			bucket.tokens = bucket.maxTokens
		} else if earned > 0 {
			bucket.tokens += int(earned)
			bucket.lastRefill = bucket.lastRefill.Add(earned * perToken)
		}
	}
	if bucket.tokens >= bucket.maxTokens {
		// A full bucket cannot bank idle time, so restart the refill clock.
		bucket.lastRefill = now
	}

	if bucket.tokens > 0 {
		bucket.tokens--
		return true
	}
	return false
}
Task 4: Security Headers Middleware
Implement comprehensive security headers:
// SecurityConfig controls the configurable parts of SecurityHeaders.
type SecurityConfig struct {
	MaxAge            int    // HSTS max-age in seconds
	IncludeSubDomains bool   // append includeSubDomains to the HSTS header
	ReportURI         string // if non-empty, CSP violation reports are sent here
	ReportOnly        bool   // report CSP violations without enforcing the policy
}

// SecurityHeaders returns middleware that sets defensive response headers on
// every response before calling next.
func SecurityHeaders(config SecurityConfig) func(http.Handler) http.Handler {
	// The values depend only on config, so they are computed once, not per request.
	// script-src omits 'unsafe-inline': allowing inline scripts lets an
	// injected <script> run and defeats CSP's main XSS protection.
	csp := "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; frame-ancestors 'none'"
	if config.ReportURI != "" {
		csp += "; report-uri " + config.ReportURI
	}
	cspHeader := "Content-Security-Policy"
	if config.ReportOnly {
		cspHeader = "Content-Security-Policy-Report-Only"
	}

	hsts := fmt.Sprintf("max-age=%d", config.MaxAge)
	if config.IncludeSubDomains {
		hsts += "; includeSubDomains"
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Set(cspHeader, csp)
			h.Set("X-Frame-Options", "DENY")
			h.Set("X-Content-Type-Options", "nosniff")
			// "0" disables the legacy browser XSS auditor. Its blocking mode
			// has been removed from modern browsers and could itself be abused;
			// CSP is the protection that matters.
			h.Set("X-XSS-Protection", "0")
			h.Set("Strict-Transport-Security", hsts)
			h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
			h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
			next.ServeHTTP(w, r)
		})
	}
}
Task 5: CSRF Protection
Implement CSRF token validation:
// CSRFProtection issues and verifies signed anti-CSRF tokens.
type CSRFProtection struct {
	secret []byte // signing key; load it from configuration, never hard-code it
}

// GenerateToken returns a token of the form
// "<userID>:<unix-seconds>:<hex nonce>.<signature>".
func (c *CSRFProtection) GenerateToken(userID int) (string, error) {
	// Generate a 256-bit nonce from the OS CSPRNG.
	token := make([]byte, 32)
	if _, err := rand.Read(token); err != nil {
		return "", err
	}

	// Combine with user ID and timestamp
	data := fmt.Sprintf("%d:%d:%x", userID, time.Now().Unix(), token)

	// Sign the token
	signature := c.sign(data)
	return fmt.Sprintf("%s.%s", data, signature), nil
}

// ValidateToken reports whether token is a valid, unexpired token for userID.
// TODO: Implement token validation logic (signature, user ID, age).
// Until it is implemented it rejects every token: a security check must fail
// closed, never open.
func (c *CSRFProtection) ValidateToken(token string, userID int) bool {
	return false
}

// sign returns the signature of data.
// TODO: Implement HMAC signing (e.g. HMAC-SHA256 keyed with c.secret).
func (c *CSRFProtection) sign(data string) string {
	return ""
}
Solution Approach
Click to see detailed solution
Complete Implementation:
1package main
2
import (
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	_ "github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
)
22
// User is the public view of an account. Password is tagged json:"-" so it is
// never serialized, even if populated.
type User struct {
	ID        int        `json:"id"`
	Username  string     `json:"username"`
	Email     string     `json:"email"`
	Password  string     `json:"-"`
	CreatedAt time.Time  `json:"created_at"`
	UpdatedAt time.Time  `json:"updated_at"`
	LastLogin *time.Time `json:"last_login,omitempty"` // nil (and omitted) when no login time is known
}
32
// ValidationRule describes the constraints applied to one named field.
type ValidationRule struct {
	Name        string                  // key looked up in the input map
	Required    bool                    // field must be present, non-nil and (for strings) non-empty
	MinLength   int                     // minimum length in characters; 0 disables the check
	MaxLength   int                     // maximum length in characters; 0 disables the check
	Pattern     *regexp.Regexp          // optional format the value must match
	CustomCheck func(interface{}) error // optional extra check; a non-nil error is reported
}

// ValidationError reports one failed rule for a field.
type ValidationError struct {
	Field   string `json:"field"`
	Message string `json:"message"`
}

// Validator holds validation rules grouped by target type (e.g. "register").
type Validator struct {
	rules map[string][]ValidationRule
}

// NewValidator returns a Validator with no rules.
func NewValidator() *Validator {
	return &Validator{
		rules: make(map[string][]ValidationRule),
	}
}

// AddRule appends rule to the rules checked for targetType.
func (v *Validator) AddRule(targetType string, rule ValidationRule) {
	if v.rules == nil {
		v.rules = make(map[string][]ValidationRule)
	}
	v.rules[targetType] = append(v.rules[targetType], rule)
}

// Validate checks data, which must be a map[string]interface{}, against the
// rules registered for targetType and returns every failure found. A nil
// result means the data is valid; a target type with no rules always
// validates.
func (v *Validator) Validate(data interface{}, targetType string) []ValidationError {
	dataMap, ok := data.(map[string]interface{})
	if !ok {
		return []ValidationError{{Field: "root", Message: "Invalid data format"}}
	}

	var errs []ValidationError
	report := func(field, message string) {
		errs = append(errs, ValidationError{Field: field, Message: message})
	}

	for _, rule := range v.rules[targetType] {
		value, present := dataMap[rule.Name]

		// A nil value is treated as absent rather than validated as "<nil>".
		missing := !present || value == nil
		// Handlers build this map from a decoded struct, so every key is
		// present even when the client omitted it. For Required to have any
		// effect, an empty string must count as missing too.
		if !missing && rule.Required {
			if s, isString := value.(string); isString && s == "" {
				missing = true
			}
		}
		if missing {
			if rule.Required {
				report(rule.Name, fmt.Sprintf("%s is required", rule.Name))
			}
			continue
		}

		strValue := fmt.Sprintf("%v", value)
		// Count characters (runes), as the messages say, not bytes.
		length := len([]rune(strValue))

		if rule.MinLength > 0 && length < rule.MinLength {
			report(rule.Name, fmt.Sprintf("%s must be at least %d characters", rule.Name, rule.MinLength))
		}

		if rule.MaxLength > 0 && length > rule.MaxLength {
			report(rule.Name, fmt.Sprintf("%s must be no more than %d characters", rule.Name, rule.MaxLength))
		}

		if rule.Pattern != nil && !rule.Pattern.MatchString(strValue) {
			report(rule.Name, fmt.Sprintf("%s format is invalid", rule.Name))
		}

		if rule.CustomCheck != nil {
			if err := rule.CustomCheck(value); err != nil {
				report(rule.Name, err.Error())
			}
		}
	}

	return errs
}
129
// emailRe is compiled once at package load; compiling it per call repeats the
// work on every request.
var emailRe = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)

// ValidateEmail reports whether email is at most 254 bytes long and matches a
// pragmatic address pattern (not a full RFC 5322 parser).
func ValidateEmail(email string) bool {
	// Check the cheap length bound first so oversized input never reaches the regex.
	return len(email) <= 254 && emailRe.MatchString(email)
}
134
// Character-class patterns used by ValidatePassword, compiled once at package load.
var (
	passwordUpperRe   = regexp.MustCompile(`[A-Z]`)
	passwordLowerRe   = regexp.MustCompile(`[a-z]`)
	passwordDigitRe   = regexp.MustCompile(`[0-9]`)
	passwordSpecialRe = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)

// maxPasswordBytes is bcrypt's input limit. bcrypt only uses the first 72
// bytes: older golang.org/x/crypto/bcrypt versions silently ignore the rest,
// and newer ones reject longer input (ErrPasswordTooLong). Enforcing the limit
// here turns that into a validation message instead of a failed registration.
const maxPasswordBytes = 72

// ValidatePassword returns one message per policy rule that password fails.
// A nil result means the password is acceptable.
func ValidatePassword(password string) []string {
	var problems []string

	if len(password) < 8 {
		problems = append(problems, "Password must be at least 8 characters")
	}

	if len(password) > maxPasswordBytes {
		problems = append(problems, "Password must be no more than 72 bytes")
	}

	if !passwordUpperRe.MatchString(password) {
		problems = append(problems, "Password must contain uppercase letter")
	}

	if !passwordLowerRe.MatchString(password) {
		problems = append(problems, "Password must contain lowercase letter")
	}

	if !passwordDigitRe.MatchString(password) {
		problems = append(problems, "Password must contain number")
	}

	if !passwordSpecialRe.MatchString(password) {
		problems = append(problems, "Password must contain special character")
	}

	// Reject a few well-known weak passwords (compared case-insensitively).
	lowered := strings.ToLower(password)
	if lowered == "password" || lowered == "12345678" || strings.Contains(lowered, "qwerty") {
		problems = append(problems, "Password is too common")
	}

	return problems
}
171
// TokenBucket is the request allowance of one rate-limit key.
type TokenBucket struct {
	tokens         int           // requests currently available
	maxTokens      int           // capacity, i.e. the largest permitted burst
	refillInterval time.Duration // time needed to earn one token; <= 0 disables refill
	lastRefill     time.Time     // instant up to which refills have been credited
	mutex          sync.Mutex
}

// RateLimiter allows about limit requests per window for each key, earning
// one token back every window/limit.
//
// NOTE(review): buckets are never evicted, so memory grows with the number of
// distinct keys seen; consider periodically dropping idle buckets.
type RateLimiter struct {
	buckets map[string]*TokenBucket
	mutex   sync.RWMutex
	limit   int
	window  time.Duration
}

// NewRateLimiter returns a limiter allowing limit requests per window per key.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	return &RateLimiter{
		buckets: make(map[string]*TokenBucket),
		limit:   limit,
		window:  window,
	}
}

// Allow reports whether a request for key may proceed, consuming one token if
// so. A new key starts with a full bucket.
func (rl *RateLimiter) Allow(key string) bool {
	rl.mutex.Lock()
	defer rl.mutex.Unlock()

	// Create the map lazily so a zero-value RateLimiter does not panic.
	if rl.buckets == nil {
		rl.buckets = make(map[string]*TokenBucket)
	}

	now := time.Now()
	bucket, exists := rl.buckets[key]
	if !exists {
		bucket = &TokenBucket{
			maxTokens:      rl.limit,
			tokens:         rl.limit,
			refillInterval: rl.refillInterval(),
			lastRefill:     now,
		}
		rl.buckets[key] = bucket
	}

	bucket.mutex.Lock()
	defer bucket.mutex.Unlock()

	if bucket.tokens < bucket.maxTokens && bucket.refillInterval > 0 {
		earned := now.Sub(bucket.lastRefill) / bucket.refillInterval
		if missing := time.Duration(bucket.maxTokens - bucket.tokens); earned >= missing {
			bucket.tokens = bucket.maxTokens
		} else if earned > 0 {
			bucket.tokens += int(earned)
			// Advance only by the time converted into tokens so the fractional
			// remainder carries over to the next call.
			bucket.lastRefill = bucket.lastRefill.Add(earned * bucket.refillInterval)
		}
	}
	if bucket.tokens >= bucket.maxTokens {
		// A full bucket cannot bank idle time, so restart the refill clock.
		bucket.lastRefill = now
	}

	if bucket.tokens > 0 {
		bucket.tokens--
		return true
	}

	return false
}

// refillInterval returns the time needed to earn one token (window/limit), or
// 0 when limit or window is not positive. It is computed as a Duration: an
// integer tokens-per-second rate (limit / seconds) truncates to 0 for
// anything slower than one request per second, such as 10 per minute, and
// the bucket would then never refill. A window of less than one second would
// also divide by zero.
func (rl *RateLimiter) refillInterval() time.Duration {
	if rl.limit <= 0 || rl.window <= 0 {
		return 0
	}
	if d := rl.window / time.Duration(rl.limit); d > 0 {
		return d
	}
	return time.Nanosecond
}
232
// SecurityConfig controls the configurable parts of SecurityHeaders.
type SecurityConfig struct {
	MaxAge            int    // HSTS max-age in seconds
	IncludeSubDomains bool   // append includeSubDomains to the HSTS header
	ReportURI         string // if non-empty, CSP violation reports are sent here
	ReportOnly        bool   // report CSP violations without enforcing the policy
}

// hstsPreloadMinAge is the minimum HSTS max-age (one year, in seconds) that
// goes with the preload directive.
const hstsPreloadMinAge = 31536000

// SecurityHeaders returns middleware that sets defensive response headers on
// every response before calling next.
func SecurityHeaders(config SecurityConfig) func(http.Handler) http.Handler {
	// Header values depend only on config, so they are built once, not per request.
	// script-src omits 'unsafe-inline': allowing inline scripts lets an
	// injected <script> run and defeats CSP's main XSS protection.
	csp := strings.Join([]string{
		"default-src 'self'",
		"script-src 'self'",
		"style-src 'self' 'unsafe-inline'",
		"img-src 'self' data: https:",
		"font-src 'self'",
		"connect-src 'self'",
		"frame-ancestors 'none'",
		"base-uri 'self'",
		"form-action 'self'",
	}, "; ")
	if config.ReportURI != "" {
		csp += "; report-uri " + config.ReportURI
	}
	cspHeader := "Content-Security-Policy"
	if config.ReportOnly {
		cspHeader = "Content-Security-Policy-Report-Only"
	}

	hsts := fmt.Sprintf("max-age=%d", config.MaxAge)
	if config.IncludeSubDomains {
		hsts += "; includeSubDomains"
		// preload is a long-lived commitment, and preload lists require both
		// includeSubDomains and at least a one-year max-age. Only advertise it
		// when the configuration actually meets those conditions.
		if config.MaxAge >= hstsPreloadMinAge {
			hsts += "; preload"
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Set(cspHeader, csp)
			h.Set("X-Frame-Options", "DENY")
			h.Set("X-Content-Type-Options", "nosniff")
			// "0" disables the legacy browser XSS auditor. Its blocking mode
			// has been removed from modern browsers and could itself be abused;
			// CSP is the protection that matters.
			h.Set("X-XSS-Protection", "0")
			h.Set("Strict-Transport-Security", hsts)
			h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
			h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
			// No Server header is set: net/http does not add one itself, and
			// setting it to "" would send an empty header rather than none.
			next.ServeHTTP(w, r)
		})
	}
}
286
// CSRFProtection issues HMAC-signed anti-CSRF tokens and remembers them
// server-side, so that each token is bound to one user and expires.
type CSRFProtection struct {
	secret []byte               // HMAC-SHA256 key
	tokens map[string]CSRFToken // issued tokens, keyed by the full token string
	mutex  sync.RWMutex         // guards tokens
}

// CSRFToken is the server-side record of an issued token.
type CSRFToken struct {
	Token     string
	ExpiresAt time.Time
	UserID    int
}

// csrfTokenTTL is how long an issued token remains valid.
const csrfTokenTTL = 24 * time.Hour

// NewCSRFProtection returns a CSRFProtection keyed with secret.
func NewCSRFProtection(secret string) *CSRFProtection {
	return &CSRFProtection{
		secret: []byte(secret),
		tokens: make(map[string]CSRFToken),
	}
}

// GenerateToken issues and stores a token for userID. The token has the form
// "<userID>:<unix-seconds>:<nonce>.<signature>": nonce is 32 random bytes and
// signature is the HMAC of everything before the dot, both base64url-encoded.
func (c *CSRFProtection) GenerateToken(userID int) (string, error) {
	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		return "", err
	}

	now := time.Now()
	data := fmt.Sprintf("%d:%d:%s", userID, now.Unix(), base64.URLEncoding.EncodeToString(nonce))
	// The signature must be encoded. Raw HMAC bytes can contain '.', which
	// breaks the split in ValidateToken. They are also usually invalid UTF-8,
	// which encoding/json replaces with U+FFFD when the token is sent to the
	// client, so the token echoed back would never match the stored one.
	fullToken := data + "." + base64.RawURLEncoding.EncodeToString(c.sign(data))

	c.mutex.Lock()
	defer c.mutex.Unlock()

	if c.tokens == nil {
		c.tokens = make(map[string]CSRFToken)
	}
	c.tokens[fullToken] = CSRFToken{
		Token:     fullToken,
		ExpiresAt: now.Add(csrfTokenTTL),
		UserID:    userID,
	}

	return fullToken, nil
}

// ValidateToken reports whether token was issued here for userID, has not
// expired, and carries a valid signature. An expired token is removed when
// it is presented.
func (c *CSRFProtection) ValidateToken(token string, userID int) bool {
	c.mutex.RLock()
	stored, exists := c.tokens[token]
	c.mutex.RUnlock()

	if !exists || stored.UserID != userID {
		return false
	}

	if time.Now().After(stored.ExpiresAt) {
		c.mutex.Lock()
		delete(c.tokens, token)
		c.mutex.Unlock()
		return false
	}

	// Defence in depth: the stored token must also verify against the key.
	parts := strings.Split(token, ".")
	if len(parts) != 2 {
		return false
	}
	signature, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return false
	}
	return hmac.Equal(signature, c.sign(parts[0]))
}

// sign returns the raw HMAC-SHA256 of data keyed with c.secret.
func (c *CSRFProtection) sign(data string) []byte {
	mac := hmac.New(sha256.New, c.secret)
	mac.Write([]byte(data)) // hash.Hash.Write never returns an error
	return mac.Sum(nil)
}

// CleanupExpired deletes every stored token whose expiry time has passed.
func (c *CSRFProtection) CleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	now := time.Now()
	for token, record := range c.tokens {
		if now.After(record.ExpiresAt) {
			delete(c.tokens, token)
		}
	}
}
374
375type AuthService struct {
376 db *sql.DB
377 jwtSecret []byte
378 csrf *CSRFProtection
379}
380
381func NewAuthService(db *sql.DB, jwtSecret, csrfSecret string) *AuthService {
382 return &AuthService{
383 db: db,
384 jwtSecret: []byte(jwtSecret),
385 csrf: NewCSRFProtection(csrfSecret),
386 }
387}
388
389func HashPassword(password string) {
390 hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
391 if err != nil {
392 return "", err
393 }
394 return string(hash), nil
395}
396
397func VerifyPassword(hashedPassword, password string) bool {
398 err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
399 return err == nil
400}
401
402func CreateUser(username, email, password string) {
403 // Validate input
404 if !ValidateEmail(email) {
405 return nil, fmt.Errorf("invalid email format")
406 }
407
408 passwordErrors := ValidatePassword(password)
409 if len(passwordErrors) > 0 {
410 return nil, fmt.Errorf(strings.Join(passwordErrors, "; "))
411 }
412
413 // Hash password
414 hashedPassword, err := a.HashPassword(password)
415 if err != nil {
416 return nil, err
417 }
418
419 // Insert user
420 var userID int
421 err = a.db.QueryRow(
422 "INSERT INTO users VALUES, NOW()) RETURNING id",
423 username, email, hashedPassword,
424 ).Scan(&userID)
425
426 if err != nil {
427 return nil, err
428 }
429
430 return &User{
431 ID: userID,
432 Username: username,
433 Email: email,
434 CreatedAt: time.Now(),
435 UpdatedAt: time.Now(),
436 }, nil
437}
438
439func AuthenticateUser(username, password string) {
440 var user User
441 var passwordHash string
442
443 err := a.db.QueryRow(
444 "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = $1 OR email = $1",
445 username,
446 ).Scan(&user.ID, &user.Username, &user.Email, &passwordHash, &user.CreatedAt, &user.UpdatedAt)
447
448 if err != nil {
449 return nil, err
450 }
451
452 if !a.VerifyPassword(passwordHash, password) {
453 return nil, fmt.Errorf("invalid credentials")
454 }
455
456 // Update last login
457 _, err = a.db.Exec("UPDATE users SET last_login = NOW() WHERE id = $1", user.ID)
458 if err != nil {
459 log.Printf("Failed to update last login: %v", err)
460 }
461
462 user.LastLogin = &[]time.Time{time.Now()}[0]
463 return &user, nil
464}
465
// SessionManager keeps sessions in memory, keyed by an opaque random ID.
type SessionManager struct {
	sessions map[string]SessionData
	mutex    sync.RWMutex
}

// SessionData is what the server remembers about one session.
type SessionData struct {
	UserID    int
	ExpiresAt time.Time
	CSRFToken string
}

// sessionTTL matches the 24-hour MaxAge of the session cookie the handlers set.
const sessionTTL = 24 * time.Hour

// NewSessionManager returns an empty SessionManager.
func NewSessionManager() *SessionManager {
	return &SessionManager{
		sessions: make(map[string]SessionData),
	}
}

// CreateSession stores a new session for userID and returns its ID: 32 bytes
// from crypto/rand, base64url-encoded.
//
// It panics if the system random source fails, because continuing would hand
// out a predictable (all-zero) session ID. net/http recovers panics in
// handlers, so the affected request fails instead of succeeding insecurely.
func (sm *SessionManager) CreateSession(userID int, csrfToken string) string {
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		panic("session: reading crypto/rand: " + err.Error())
	}
	sessionID := base64.URLEncoding.EncodeToString(raw)

	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	if sm.sessions == nil {
		sm.sessions = make(map[string]SessionData)
	}
	sm.sessions[sessionID] = SessionData{
		UserID:    userID,
		ExpiresAt: time.Now().Add(sessionTTL),
		CSRFToken: csrfToken,
	}

	return sessionID
}

// ValidateSession returns the session for sessionID and true if it exists and
// has not expired; otherwise nil and false. Expired sessions are deleted when
// they are looked up, so the map does not keep them forever.
func (sm *SessionManager) ValidateSession(sessionID string) (*SessionData, bool) {
	sm.mutex.RLock()
	session, exists := sm.sessions[sessionID]
	sm.mutex.RUnlock()

	if !exists {
		return nil, false
	}

	if time.Now().After(session.ExpiresAt) {
		sm.mutex.Lock()
		delete(sm.sessions, sessionID)
		sm.mutex.Unlock()
		return nil, false
	}

	return &session, true
}
511
512type APIHandler struct {
513 authService *AuthService
514 validator *Validator
515 rateLimiter *RateLimiter
516 sessions *SessionManager
517}
518
519func NewAPIHandler(authService *AuthService) *APIHandler {
520 validator := NewValidator()
521
522 // Add validation rules for registration
523 validator.AddRule("register", ValidationRule{
524 Name: "username",
525 Required: true,
526 MinLength: 3,
527 MaxLength: 50,
528 Pattern: regexp.MustCompile(`^[a-zA-Z0-9_]+$`),
529 })
530
531 validator.AddRule("register", ValidationRule{
532 Name: "email",
533 Required: true,
534 CustomCheck: func(value interface{}) error {
535 email := value.(string)
536 if !ValidateEmail(email) {
537 return fmt.Errorf("invalid email format")
538 }
539 return nil
540 },
541 })
542
543 validator.AddRule("register", ValidationRule{
544 Name: "password",
545 Required: true,
546 CustomCheck: func(value interface{}) error {
547 password := value.(string)
548 errors := ValidatePassword(password)
549 if len(errors) > 0 {
550 return fmt.Errorf(strings.Join(errors, "; "))
551 }
552 return nil
553 },
554 })
555
556 return &APIHandler{
557 authService: authService,
558 validator: validator,
559 rateLimiter: NewRateLimiter(10, time.Minute), // 10 requests per minute
560 sessions: NewSessionManager(),
561 }
562}
563
564func getClientIP(r *http.Request) string {
565 // Check X-Forwarded-For header
566 if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
567 ips := strings.Split(xff, ",")
568 return strings.TrimSpace(ips[0])
569 }
570
571 // Check X-Real-IP header
572 if xri := r.Header.Get("X-Real-IP"); xri != "" {
573 return xri
574 }
575
576 // Fall back to RemoteAddr
577 return strings.Split(r.RemoteAddr, ":")[0]
578}
579
580func writeJSONError(w http.ResponseWriter, status int, message string) {
581 w.Header().Set("Content-Type", "application/json")
582 w.WriteHeader(status)
583 json.NewEncoder(w).Encode(map[string]string{"error": message})
584}
585
586func Register(w http.ResponseWriter, r *http.Request) {
587 if r.Method != http.MethodPost {
588 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
589 return
590 }
591
592 clientIP := h.getClientIP(r)
593 if !h.rateLimiter.Allow("register:"+clientIP) {
594 h.writeJSONError(w, http.StatusTooManyRequests, "Rate limit exceeded")
595 return
596 }
597
598 var req struct {
599 Username string `json:"username"`
600 Email string `json:"email"`
601 Password string `json:"password"`
602 }
603
604 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
605 h.writeJSONError(w, http.StatusBadRequest, "Invalid request body")
606 return
607 }
608
609 // Validate input
610 errors := h.validator.Validate(map[string]interface{}{
611 "username": req.Username,
612 "email": req.Email,
613 "password": req.Password,
614 }, "register")
615
616 if len(errors) > 0 {
617 w.Header().Set("Content-Type", "application/json")
618 w.WriteHeader(http.StatusBadRequest)
619 json.NewEncoder(w).Encode(map[string]interface{}{
620 "errors": errors,
621 })
622 return
623 }
624
625 // Create user
626 user, err := h.authService.CreateUser(req.Username, req.Email, req.Password)
627 if err != nil {
628 h.writeJSONError(w, http.StatusInternalServerError, "Failed to create user")
629 return
630 }
631
632 // Generate CSRF token
633 csrfToken, err := h.authService.csrf.GenerateToken(user.ID)
634 if err != nil {
635 h.writeJSONError(w, http.StatusInternalServerError, "Failed to generate CSRF token")
636 return
637 }
638
639 // Create session
640 sessionID := h.sessions.CreateSession(user.ID, csrfToken)
641
642 // Set secure cookie
643 http.SetCookie(w, &http.Cookie{
644 Name: "session_id",
645 Value: sessionID,
646 Path: "/",
647 MaxAge: 86400, // 24 hours
648 Secure: true,
649 HttpOnly: true,
650 SameSite: http.SameSiteStrictMode,
651 })
652
653 w.Header().Set("Content-Type", "application/json")
654 w.WriteHeader(http.StatusCreated)
655 json.NewEncoder(w).Encode(map[string]interface{}{
656 "user": user,
657 "csrf_token": csrfToken,
658 })
659}
660
661func Login(w http.ResponseWriter, r *http.Request) {
662 if r.Method != http.MethodPost {
663 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
664 return
665 }
666
667 clientIP := h.getClientIP(r)
668 if !h.rateLimiter.Allow("login:"+clientIP) {
669 h.writeJSONError(w, http.StatusTooManyRequests, "Rate limit exceeded")
670 return
671 }
672
673 var req struct {
674 Username string `json:"username"`
675 Password string `json:"password"`
676 }
677
678 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
679 h.writeJSONError(w, http.StatusBadRequest, "Invalid request body")
680 return
681 }
682
683 // Authenticate user
684 user, err := h.authService.AuthenticateUser(req.Username, req.Password)
685 if err != nil {
686 h.writeJSONError(w, http.StatusUnauthorized, "Invalid credentials")
687 return
688 }
689
690 // Generate CSRF token
691 csrfToken, err := h.authService.csrf.GenerateToken(user.ID)
692 if err != nil {
693 h.writeJSONError(w, http.StatusInternalServerError, "Failed to generate CSRF token")
694 return
695 }
696
697 // Create session
698 sessionID := h.sessions.CreateSession(user.ID, csrfToken)
699
700 // Set secure cookie
701 http.SetCookie(w, &http.Cookie{
702 Name: "session_id",
703 Value: sessionID,
704 Path: "/",
705 MaxAge: 86400, // 24 hours
706 Secure: true,
707 HttpOnly: true,
708 SameSite: http.SameSiteStrictMode,
709 })
710
711 w.Header().Set("Content-Type", "application/json")
712 json.NewEncoder(w).Encode(map[string]interface{}{
713 "user": user,
714 "csrf_token": csrfToken,
715 })
716}
717
718func requireAuth(next http.HandlerFunc) http.HandlerFunc {
719 return func(w http.ResponseWriter, r *http.Request) {
720 cookie, err := r.Cookie("session_id")
721 if err != nil {
722 h.writeJSONError(w, http.StatusUnauthorized, "Authentication required")
723 return
724 }
725
726 session, valid := h.sessions.ValidateSession(cookie.Value)
727 if !valid {
728 http.SetCookie(w, &http.Cookie{
729 Name: "session_id",
730 Value: "",
731 Path: "/",
732 MaxAge: -1,
733 Secure: true,
734 HttpOnly: true,
735 SameSite: http.SameSiteStrictMode,
736 })
737 h.writeJSONError(w, http.StatusUnauthorized, "Invalid session")
738 return
739 }
740
741 // Add user context
742 ctx := context.WithValue(r.Context(), "userID", session.UserID)
743 ctx = context.WithValue(ctx, "csrfToken", session.CSRFToken)
744
745 next.ServeHTTP(w, r.WithContext(ctx))
746 }
747}
748
749func requireCSRF(next http.HandlerFunc) http.HandlerFunc {
750 return func(w http.ResponseWriter, r *http.Request) {
751 if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
752 next.ServeHTTP(w, r)
753 return
754 }
755
756 csrfToken := r.Header.Get("X-CSRF-Token")
757 if csrfToken == "" {
758 h.writeJSONError(w, http.StatusForbidden, "CSRF token required")
759 return
760 }
761
762 userID := r.Context().Value("userID").(int)
763 if !h.authService.csrf.ValidateToken(csrfToken, userID) {
764 h.writeJSONError(w, http.StatusForbidden, "Invalid CSRF token")
765 return
766 }
767
768 next.ServeHTTP(w, r)
769 }
770}
771
772func Profile(w http.ResponseWriter, r *http.Request) {
773 if r.Method != http.MethodGet {
774 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
775 return
776 }
777
778 userID := r.Context().Value("userID").(int)
779
780 var user User
781 err := h.authService.db.QueryRow(
782 "SELECT id, username, email, created_at, updated_at, last_login FROM users WHERE id = $1",
783 userID,
784 ).Scan(&user.ID, &user.Username, &user.Email, &user.CreatedAt, &user.UpdatedAt, &user.LastLogin)
785
786 if err != nil {
787 h.writeJSONError(w, http.StatusInternalServerError, "Failed to fetch user profile")
788 return
789 }
790
791 w.Header().Set("Content-Type", "application/json")
792 json.NewEncoder(w).Encode(user)
793}
794
795func main() {
796 // Initialize database
797 db, _ := sql.Open("postgres", "host=localhost user=app dbname=secureapp sslmode=require")
798 defer db.Close()
799
800 // Initialize services
801 authService := NewAuthService(db, "your-jwt-secret-key", "your-csrf-secret-key")
802 handler := NewAPIHandler(authService)
803
804 // Set up middleware chain
805 securityConfig := SecurityConfig{
806 MaxAge: 31536000, // 1 year
807 IncludeSubDomains: true,
808 }
809
810 mux := http.NewServeMux()
811 mux.HandleFunc("/api/register", handler.Register)
812 mux.HandleFunc("/api/login", handler.Login)
813 mux.HandleFunc("/api/profile", handler.requireAuth(handler.requireCSRF(handler.Profile)))
814
815 // Wrap with security middleware
816 handlerChain := SecurityHeaders(securityConfig)(mux)
817
818 fmt.Println("Secure API server starting on :8443")
819 log.Fatal(http.ListenAndServeTLS(":8443", "server.crt", "server.key", handlerChain))
820}
Testing Your Solution
Test your security implementation:
# Test registration with valid data (expect 201 with user and csrf_token)
curl -sk -X POST https://localhost:8443/api/register \
  -H "Content-Type: application/json" \
  -d '{"username":"testuser","email":"test@example.com","password":"SecurePass123!"}'

# Test registration with weak password (expect 400 with validation errors)
curl -sk -X POST https://localhost:8443/api/register \
  -H "Content-Type: application/json" \
  -d '{"username":"testuser2","email":"test2@example.com","password":"password"}'

# Test rate limiting (bash): login allows 10 requests per minute per client,
# so expect 401 for the first 10 attempts and 429 after that
for i in {1..15}; do
  curl -sk -o /dev/null -w "attempt $i: %{http_code}\n" \
    -X POST https://localhost:8443/api/login \
    -H "Content-Type: application/json" \
    -d '{"username":"testuser","password":"wrongpassword"}'
done

# Test security headers (-I sends HEAD and prints the response headers).
# Replace the cookie value with a session_id returned by register/login;
# an invalid session still returns the security headers on its 401.
curl -sk -I https://localhost:8443/api/profile \
  -H "Cookie: session_id=valid-session-id"
Verify that:
- Input validation prevents weak passwords and invalid emails
- Rate limiting blocks excessive requests
- Security headers are properly set
- CSRF protection blocks state-changing requests without tokens
- Authentication is required for protected endpoints
- Sessions expire properly
Extension Challenges
- Add 2FA support - Implement TOTP-based two-factor authentication
- Implement IP allowlisting - Restrict access based on IP addresses
- Add audit logging - Log all security-relevant events
- Implement account lockout - Lock accounts after failed attempts
- Add API key authentication - Support for API-based access
Key Takeaways
- Defense in depth - Multiple security layers provide better protection
- Input validation is your first line of defense against attacks
- Rate limiting curbs brute-force abuse and limits the impact of request floods (it does not replace network-level DDoS protection)
- Security headers protect against various client-side attacks
- CSRF protection prevents state-changing requests from other sites
- Secure session management (random session IDs, Secure/HttpOnly/SameSite cookies, expiry) reduces the risk of session hijacking
This exercise demonstrates how to implement comprehensive security measures in Go applications, protecting against common OWASP vulnerabilities while maintaining good user experience.