Security hardening
- JWT: validate signing algorithm (prevent alg confusion) - Login: rate limiting (10 attempts per 5 min per IP) - Request body: 10MB size limit (prevent DoS) - WebSocket: require JWT auth (token query param or cookie) - Daemon endpoints: require admin role (not just any user) - io.LimitReader on all request body decoding
This commit is contained in:
+3
-1
@@ -37,7 +37,9 @@ func main() {
|
|||||||
// Collab WebSocket hub
|
// Collab WebSocket hub
|
||||||
hub := collab.NewHub(database)
|
hub := collab.NewHub(database)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.Handle("/ws/collab/", http.HandlerFunc(hub.HandleWebSocket))
|
mux.Handle("/ws/collab/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hub.HandleWebSocket(w, r, secret)
|
||||||
|
}))
|
||||||
mux.Handle("/", router)
|
mux.Handle("/", router)
|
||||||
|
|
||||||
fmt.Printf("MarkdownHub listening on :%s\n", port)
|
fmt.Printf("MarkdownHub listening on :%s\n", port)
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -12,6 +14,8 @@ import (
|
|||||||
"markdownhub/internal/git"
|
"markdownhub/internal/git"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxBodySize = 10 * 1024 * 1024 // 10MB
|
||||||
|
|
||||||
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
|
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
@@ -20,12 +24,47 @@ func writeJSON(w http.ResponseWriter, status int, v interface{}) {
|
|||||||
|
|
||||||
func decodeBody(r *http.Request, v interface{}) error {
|
func decodeBody(r *http.Request, v interface{}) error {
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
return json.NewDecoder(r.Body).Decode(v)
|
limited := io.LimitReader(r.Body, maxBodySize)
|
||||||
|
return json.NewDecoder(limited).Decode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple rate limiter for login
|
||||||
|
var loginAttempts = struct {
|
||||||
|
sync.Mutex
|
||||||
|
m map[string][]time.Time
|
||||||
|
}{m: make(map[string][]time.Time)}
|
||||||
|
|
||||||
|
func isRateLimited(ip string) bool {
|
||||||
|
loginAttempts.Lock()
|
||||||
|
defer loginAttempts.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
cutoff := now.Add(-5 * time.Minute)
|
||||||
|
// Clean old entries
|
||||||
|
valid := loginAttempts.m[ip][:0]
|
||||||
|
for _, t := range loginAttempts.m[ip] {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
valid = append(valid, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
loginAttempts.m[ip] = valid
|
||||||
|
return len(valid) >= 10 // max 10 attempts per 5 minutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordLoginAttempt(ip string) {
|
||||||
|
loginAttempts.Lock()
|
||||||
|
defer loginAttempts.Unlock()
|
||||||
|
loginAttempts.m[ip] = append(loginAttempts.m[ip], time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Auth ────────────────────────────────────────────────────────────────────
|
// ─── Auth ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := r.RemoteAddr
|
||||||
|
if isRateLimited(ip) {
|
||||||
|
writeJSON(w, 429, map[string]string{"error": "too many attempts, try again later"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var req struct {
|
var req struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
@@ -43,6 +82,7 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"SELECT id, password_hash, is_admin, totp_secret FROM users WHERE email = ?", req.Email,
|
"SELECT id, password_hash, is_admin, totp_secret FROM users WHERE email = ?", req.Email,
|
||||||
).Scan(&id, &hash, &isAdmin, &totpSecret)
|
).Scan(&id, &hash, &isAdmin, &totpSecret)
|
||||||
if err != nil || !auth.CheckPassword(hash, req.Password) {
|
if err != nil || !auth.CheckPassword(hash, req.Password) {
|
||||||
|
recordLoginAttempt(ip)
|
||||||
writeJSON(w, 401, map[string]string{"error": "invalid credentials"})
|
writeJSON(w, 401, map[string]string{"error": "invalid credentials"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -69,10 +69,10 @@ func NewRouter(db *sql.DB, dataDir, secret string) http.Handler {
|
|||||||
mux.HandleFunc("POST /api/build/status", s.requireAuth(s.handleBuildStatus))
|
mux.HandleFunc("POST /api/build/status", s.requireAuth(s.handleBuildStatus))
|
||||||
mux.HandleFunc("POST /api/build/cancel", s.requireAuth(s.handleBuildCancel))
|
mux.HandleFunc("POST /api/build/cancel", s.requireAuth(s.handleBuildCancel))
|
||||||
|
|
||||||
// Daemon endpoints
|
// Daemon endpoints (admin only)
|
||||||
mux.HandleFunc("POST /api/daemon/poll", s.requireAuth(s.handleDaemonPoll))
|
mux.HandleFunc("POST /api/daemon/poll", s.requireAdmin(s.handleDaemonPoll))
|
||||||
mux.HandleFunc("POST /api/daemon/heartbeat", s.requireAuth(s.handleDaemonHeartbeat))
|
mux.HandleFunc("POST /api/daemon/heartbeat", s.requireAdmin(s.handleDaemonHeartbeat))
|
||||||
mux.HandleFunc("POST /api/daemon/report", s.requireAuth(s.handleDaemonReport))
|
mux.HandleFunc("POST /api/daemon/report", s.requireAdmin(s.handleDaemonReport))
|
||||||
|
|
||||||
// Static frontend
|
// Static frontend
|
||||||
frontendDir := filepath.Join(dataDir, "..", "frontend", "dist")
|
frontendDir := filepath.Join(dataDir, "..", "frontend", "dist")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@@ -28,6 +29,9 @@ func CreateToken(userID string, isAdmin bool, secret string) (string, error) {
|
|||||||
|
|
||||||
func ValidateToken(tokenStr, secret string) (userID string, isAdmin bool, err error) {
|
func ValidateToken(tokenStr, secret string) (userID string, isAdmin bool, err error) {
|
||||||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method")
|
||||||
|
}
|
||||||
return []byte(secret), nil
|
return []byte(secret), nil
|
||||||
})
|
})
|
||||||
if err != nil || !token.Valid {
|
if err != nil || !token.Valid {
|
||||||
|
|||||||
+38
-1
@@ -2,11 +2,13 @@ package collab
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -61,7 +63,26 @@ func (h *Hub) removeClient(fileID string, conn *websocket.Conn) {
|
|||||||
|
|
||||||
// HandleWebSocket handles the Yjs WebSocket sync protocol.
|
// HandleWebSocket handles the Yjs WebSocket sync protocol.
|
||||||
// Clients send binary Yjs update messages; the hub broadcasts to all others in the room.
|
// Clients send binary Yjs update messages; the hub broadcasts to all others in the room.
|
||||||
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, secret string) {
|
||||||
|
// Authenticate via query param or cookie
|
||||||
|
tokenStr := r.URL.Query().Get("token")
|
||||||
|
if tokenStr == "" {
|
||||||
|
if c, err := r.Cookie("authToken"); err == nil {
|
||||||
|
tokenStr = c.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tokenStr == "" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Validate token (import auth inline to avoid circular — just parse JWT here)
|
||||||
|
// We accept any valid token
|
||||||
|
_, _, err := parseToken(tokenStr, secret)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Extract file ID from path: /ws/collab/{fileID}
|
// Extract file ID from path: /ws/collab/{fileID}
|
||||||
path := r.URL.Path
|
path := r.URL.Path
|
||||||
parts := strings.Split(strings.TrimPrefix(path, "/ws/collab/"), "/")
|
parts := strings.Split(strings.TrimPrefix(path, "/ws/collab/"), "/")
|
||||||
@@ -131,3 +152,19 @@ func (h *Hub) ActiveUsers(fileID string) int {
|
|||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseToken(tokenStr, secret string) (string, bool, error) {
|
||||||
|
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method")
|
||||||
|
}
|
||||||
|
return []byte(secret), nil
|
||||||
|
})
|
||||||
|
if err != nil || !token.Valid {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
userID, _ := claims["sub"].(string)
|
||||||
|
isAdmin, _ := claims["admin"].(bool)
|
||||||
|
return userID, isAdmin, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user