diff --git a/cmd/server/main.go b/cmd/server/main.go index 26f9286..a6965eb 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -37,7 +37,9 @@ func main() { // Collab WebSocket hub hub := collab.NewHub(database) mux := http.NewServeMux() - mux.Handle("/ws/collab/", http.HandlerFunc(hub.HandleWebSocket)) + mux.Handle("/ws/collab/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hub.HandleWebSocket(w, r, secret) + })) mux.Handle("/", router) fmt.Printf("MarkdownHub listening on :%s\n", port) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index db7bfc1..14a9fdf 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -2,7 +2,9 @@ package api import ( "encoding/json" + "io" "net/http" + "sync" "time" "github.com/google/uuid" @@ -12,6 +14,8 @@ import ( "markdownhub/internal/git" ) +const maxBodySize = 10 * 1024 * 1024 // 10MB + func writeJSON(w http.ResponseWriter, status int, v interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) @@ -20,12 +24,47 @@ func writeJSON(w http.ResponseWriter, status int, v interface{}) { func decodeBody(r *http.Request, v interface{}) error { defer r.Body.Close() - return json.NewDecoder(r.Body).Decode(v) + limited := io.LimitReader(r.Body, maxBodySize) + return json.NewDecoder(limited).Decode(v) +} + +// Simple rate limiter for login +var loginAttempts = struct { + sync.Mutex + m map[string][]time.Time +}{m: make(map[string][]time.Time)} + +func isRateLimited(ip string) bool { + loginAttempts.Lock() + defer loginAttempts.Unlock() + now := time.Now() + cutoff := now.Add(-5 * time.Minute) + // Clean old entries + valid := loginAttempts.m[ip][:0] + for _, t := range loginAttempts.m[ip] { + if t.After(cutoff) { + valid = append(valid, t) + } + } + loginAttempts.m[ip] = valid + return len(valid) >= 10 // max 10 attempts per 5 minutes +} + +func recordLoginAttempt(ip string) { + loginAttempts.Lock() + defer loginAttempts.Unlock() + loginAttempts.m[ip] = append(loginAttempts.m[ip], time.Now()) } // ─── Auth ──────────────────────────────────────────────────────────────────── func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + ip := r.RemoteAddr + if isRateLimited(ip) { + writeJSON(w, 429, map[string]string{"error": "too many attempts, try again later"}) + return + } + var req struct { Email string `json:"email"` Password string `json:"password"` @@ -43,6 +82,7 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { "SELECT id, password_hash, is_admin, totp_secret FROM users WHERE email = ?", req.Email, ).Scan(&id, &hash, &isAdmin, &totpSecret) if err != nil || !auth.CheckPassword(hash, req.Password) { + recordLoginAttempt(ip) writeJSON(w, 401, map[string]string{"error": "invalid credentials"}) return } diff --git a/internal/api/router.go b/internal/api/router.go index 38d961e..d3affc9 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -69,10 +69,10 @@ func NewRouter(db *sql.DB, dataDir, secret string) http.Handler { mux.HandleFunc("POST /api/build/status", s.requireAuth(s.handleBuildStatus)) mux.HandleFunc("POST /api/build/cancel", s.requireAuth(s.handleBuildCancel)) - // Daemon endpoints - mux.HandleFunc("POST /api/daemon/poll", s.requireAuth(s.handleDaemonPoll)) - mux.HandleFunc("POST /api/daemon/heartbeat", s.requireAuth(s.handleDaemonHeartbeat)) - mux.HandleFunc("POST /api/daemon/report", s.requireAuth(s.handleDaemonReport)) + // Daemon endpoints (admin only) + mux.HandleFunc("POST /api/daemon/poll", s.requireAdmin(s.handleDaemonPoll)) + mux.HandleFunc("POST /api/daemon/heartbeat", s.requireAdmin(s.handleDaemonHeartbeat)) + mux.HandleFunc("POST /api/daemon/report", s.requireAdmin(s.handleDaemonReport)) // Static frontend frontendDir := filepath.Join(dataDir, "..", "frontend", "dist") diff --git a/internal/auth/auth.go b/internal/auth/auth.go index e045ff6..51fbc87 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "fmt" "time" "github.com/golang-jwt/jwt/v5" @@ -28,6 +29,9 @@ func CreateToken(userID string, isAdmin bool, secret string) (string, error) { func ValidateToken(tokenStr, secret string) (userID string, isAdmin bool, err error) { token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") + } return []byte(secret), nil }) if err != nil || !token.Valid { diff --git a/internal/collab/hub.go b/internal/collab/hub.go index 0670dfd..422ca65 100644 --- a/internal/collab/hub.go +++ b/internal/collab/hub.go @@ -2,11 +2,13 @@ package collab import ( "database/sql" + "fmt" "log" "net/http" "strings" "sync" + "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" ) @@ -61,7 +63,26 @@ func (h *Hub) removeClient(fileID string, conn *websocket.Conn) { // HandleWebSocket handles the Yjs WebSocket sync protocol. // Clients send binary Yjs update messages; the hub broadcasts to all others in the room. -func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { +func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, secret string) { + // Authenticate via query param or cookie + tokenStr := r.URL.Query().Get("token") + if tokenStr == "" { + if c, err := r.Cookie("authToken"); err == nil { + tokenStr = c.Value + } + } + if tokenStr == "" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + // Validate token (import auth inline to avoid circular — just parse JWT here) + // We accept any valid token + _, _, err := parseToken(tokenStr, secret) + if err != nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + // Extract file ID from path: /ws/collab/{fileID} path := r.URL.Path parts := strings.Split(strings.TrimPrefix(path, "/ws/collab/"), "/") @@ -131,3 +152,19 @@ func (h *Hub) ActiveUsers(fileID string) int { } return 0 } + +func parseToken(tokenStr, secret string) (string, bool, error) { + token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") + } + return []byte(secret), nil + }) + if err != nil || !token.Valid { + return "", false, err + } + claims := token.Claims.(jwt.MapClaims) + userID, _ := claims["sub"].(string) + isAdmin, _ := claims["admin"].(bool) + return userID, isAdmin, nil +}