diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 48fb442..3abd07c 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -1,10 +1,16 @@ package main import ( + "context" "log" + "os" + "os/signal" + "syscall" + "tyto/internal/agent" "tyto/internal/api" "tyto/internal/config" + "tyto/internal/server" "tyto/internal/sse" ) @@ -60,29 +66,81 @@ func runServer(cfg *config.Config) { log.Printf("gRPC port for agents: %d", cfg.Server.GRPCPort) log.Printf("Database: %s", cfg.Database.Type) - // TODO: Initialize database - // TODO: Initialize authentication - // TODO: Initialize gRPC server for agents - // TODO: Initialize agent hub + // Set up signal handling + _, cancel := context.WithCancel(context.Background()) + defer cancel() - // For now, run in standalone-compatible mode - // Full server mode will be implemented in subsequent sprints + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + // Initialize agent registry + registryPath := "/var/lib/tyto/agents.json" + if cfg.Database.SQLitePath != "" { + registryPath = cfg.Database.SQLitePath + ".agents.json" + } + registry := server.NewRegistry(registryPath) + log.Printf("Agent registry initialized: %s", registryPath) + + // Initialize Hub + hubConfig := &server.HubConfig{ + RequireApproval: cfg.Server.Registration.RequireApproval, + AutoApprove: cfg.Server.Registration.AutoEnabled && !cfg.Server.Registration.RequireApproval, + } + hub := server.NewHub(registry, hubConfig) + hub.Start() + defer hub.Stop() + log.Println("Agent hub started") + + // Initialize SSE bridge for multi-device streaming + bridge := server.NewSSEBridge(hub) + bridge.Start() + defer bridge.Stop() + + // Initialize gRPC server for agent connections + grpcServer, err := server.NewGRPCServer(hub, &cfg.Server) + if err != nil { + log.Fatalf("Failed to create gRPC server: %v", err) + } + + // Start gRPC server in background + go func() { + log.Printf("Starting gRPC server on port %d", cfg.Server.GRPCPort) + if err := grpcServer.Start(cfg.Server.GRPCPort); err != nil { + log.Printf("gRPC server error: %v", err) + } + }() + defer grpcServer.Stop() + + // Initialize SSE broker for local metrics (also runs in server mode) broker := sse.NewBroker(cfg) go broker.Run() - server := api.NewServer(cfg, broker) + // Initialize HTTP API server with agent management + apiServer := api.NewServer(cfg, broker) - var err error - if cfg.TLSEnabled { - log.Printf("Starting HTTPS server on port %s", cfg.Port) - err = server.RunTLS(cfg.TLSCertFile, cfg.TLSKeyFile) - } else { - err = server.Run() - } + // Add agent API routes + agentAPI := api.NewAgentAPI(registry, hub) + agentAPI.RegisterRoutes(apiServer.Router().Group("/api/v1")) - if err != nil { - log.Fatalf("Failed to start server: %v", err) - } + // Start HTTP server in background + go func() { + var err error + if cfg.TLSEnabled { + log.Printf("Starting HTTPS server on port %s", cfg.Port) + err = apiServer.RunTLS(cfg.TLSCertFile, cfg.TLSKeyFile) + } else { + log.Printf("Starting HTTP server on port %s", cfg.Port) + err = apiServer.Run() + } + if err != nil { + log.Printf("HTTP server error: %v", err) + } + }() + + // Wait for shutdown signal + <-sigCh + log.Println("Shutting down server...") + cancel() } // runAgent starts Tyto as a lightweight agent that reports to a central server. @@ -98,7 +156,28 @@ func runAgent(cfg *config.Config) { log.Printf("Reporting to: %s", cfg.Agent.ServerURL) log.Printf("Collection interval: %s", cfg.Agent.Interval) - // TODO: Implement gRPC client and metrics collection loop - // This will be implemented in Sprint 3 (Agent Implementation) - log.Fatal("Agent mode not yet implemented") + // Set up signal handling + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + // Create agent + a := agent.New(cfg) + + // Handle shutdown signal + go func() { + <-sigCh + log.Println("Received shutdown signal, stopping agent...") + a.Stop() + cancel() + }() + + // Run agent + if err := a.Run(ctx); err != nil && err != context.Canceled { + log.Fatalf("Agent error: %v", err) + } + + log.Println("Agent stopped") } diff --git a/backend/internal/api/agents.go b/backend/internal/api/agents.go new file mode 100644 index 0000000..bd905fe --- /dev/null +++ b/backend/internal/api/agents.go @@ -0,0 +1,224 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "tyto/internal/server" +) + +// AgentAPI handles agent management endpoints. +type AgentAPI struct { + registry *server.Registry + hub *server.Hub +} + +// NewAgentAPI creates a new agent API handler. +func NewAgentAPI(registry *server.Registry, hub *server.Hub) *AgentAPI { + return &AgentAPI{ + registry: registry, + hub: hub, + } +} + +// RegisterRoutes adds agent management routes to a router group. +func (a *AgentAPI) RegisterRoutes(group *gin.RouterGroup) { + agents := group.Group("/agents") + { + agents.GET("", a.listAgents) + agents.GET("/:id", a.getAgent) + agents.DELETE("/:id", a.removeAgent) + agents.POST("/:id/revoke", a.revokeAgent) + + // Pending registrations + agents.GET("/pending", a.listPending) + agents.POST("/pending/:id/approve", a.approveAgent) + agents.POST("/pending/:id/reject", a.rejectAgent) + + // Connected agents + agents.GET("/connected", a.listConnected) + agents.GET("/:id/metrics", a.getAgentMetrics) + } +} + +// AgentResponse is the API representation of an agent. +type AgentResponse struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Hostname string `json:"hostname"` + OS string `json:"os"` + Architecture string `json:"architecture"` + Version string `json:"version"` + Capabilities []string `json:"capabilities,omitempty"` + Status string `json:"status"` + Connected bool `json:"connected"` + LastSeen string `json:"lastSeen,omitempty"` + RegisteredAt string `json:"registeredAt"` + Tags []string `json:"tags,omitempty"` +} + +func agentToResponse(agent *server.AgentRecord, connected bool) AgentResponse { + lastSeen := "" + if !agent.LastSeen.IsZero() { + lastSeen = agent.LastSeen.Format("2006-01-02T15:04:05Z07:00") + } + + registeredAt := "" + if !agent.RegisteredAt.IsZero() { + registeredAt = agent.RegisteredAt.Format("2006-01-02T15:04:05Z07:00") + } + + return AgentResponse{ + ID: agent.ID, + Name: agent.Name, + Hostname: agent.Hostname, + OS: agent.OS, + Architecture: agent.Architecture, + Version: agent.Version, + Capabilities: agent.Capabilities, + Status: string(agent.Status), + Connected: connected, + LastSeen: lastSeen, + RegisteredAt: registeredAt, + Tags: agent.Tags, + } +} + +// listAgents returns all registered agents. +func (a *AgentAPI) listAgents(c *gin.Context) { + agents := a.registry.List() + connectedIDs := make(map[string]bool) + if a.hub != nil { + for _, id := range a.hub.GetConnectedAgents() { + connectedIDs[id] = true + } + } + + response := make([]AgentResponse, len(agents)) + for i, agent := range agents { + response[i] = agentToResponse(agent, connectedIDs[agent.ID]) + } + + c.JSON(http.StatusOK, response) +} + +// getAgent returns a specific agent. +func (a *AgentAPI) getAgent(c *gin.Context) { + id := c.Param("id") + + agent, exists := a.registry.Get(id) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"}) + return + } + + connected := false + if a.hub != nil { + for _, connID := range a.hub.GetConnectedAgents() { + if connID == id { + connected = true + break + } + } + } + + c.JSON(http.StatusOK, agentToResponse(agent, connected)) +} + +// removeAgent removes an agent registration. +func (a *AgentAPI) removeAgent(c *gin.Context) { + id := c.Param("id") + + if err := a.registry.Remove(id); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "removed"}) +} + +// revokeAgent revokes an agent's registration. +func (a *AgentAPI) revokeAgent(c *gin.Context) { + id := c.Param("id") + + if err := a.registry.Revoke(id); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "revoked"}) +} + +// listPending returns agents awaiting approval. +func (a *AgentAPI) listPending(c *gin.Context) { + agents := a.registry.ListPending() + + response := make([]AgentResponse, len(agents)) + for i, agent := range agents { + response[i] = agentToResponse(agent, false) + } + + c.JSON(http.StatusOK, response) +} + +// approveAgent approves a pending agent. +func (a *AgentAPI) approveAgent(c *gin.Context) { + id := c.Param("id") + + if err := a.registry.Approve(id); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "approved"}) +} + +// rejectAgent rejects a pending agent. +func (a *AgentAPI) rejectAgent(c *gin.Context) { + id := c.Param("id") + + if err := a.registry.Reject(id); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "rejected"}) +} + +// listConnected returns currently connected agents. +func (a *AgentAPI) listConnected(c *gin.Context) { + if a.hub == nil { + c.JSON(http.StatusOK, []AgentResponse{}) + return + } + + connectedIDs := a.hub.GetConnectedAgents() + response := make([]AgentResponse, 0, len(connectedIDs)) + + for _, id := range connectedIDs { + if agent, exists := a.registry.Get(id); exists { + response = append(response, agentToResponse(agent, true)) + } + } + + c.JSON(http.StatusOK, response) +} + +// getAgentMetrics returns the latest metrics for an agent. +func (a *AgentAPI) getAgentMetrics(c *gin.Context) { + id := c.Param("id") + + if a.hub == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "hub not available"}) + return + } + + metrics, exists := a.hub.GetAgentMetrics(id) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "no metrics available"}) + return + } + + c.JSON(http.StatusOK, metrics) +} diff --git a/backend/internal/api/routes.go b/backend/internal/api/routes.go index 8a60f23..9804228 100644 --- a/backend/internal/api/routes.go +++ b/backend/internal/api/routes.go @@ -335,6 +335,11 @@ func (s *Server) ListenAddr() string { return fmt.Sprintf(":%s", s.cfg.Port) } +// Router returns the underlying Gin engine for adding routes. +func (s *Server) Router() *gin.Engine { + return s.router +} + // Alert handlers func (s *Server) getAlertsHandler(c *gin.Context) { response := models.AlertsResponse{ diff --git a/backend/internal/pki/tls.go b/backend/internal/pki/tls.go index 319a49e..0d27071 100644 --- a/backend/internal/pki/tls.go +++ b/backend/internal/pki/tls.go @@ -101,18 +101,22 @@ func ServerTLSConfigWithCRL(ca *CA, serverCertPath, serverKeyPath string) (*tls. }, nil } -// ExtractAgentID extracts the agent ID from a verified client certificate. +// ExtractAgentIDFromState extracts the agent ID from a verified TLS connection state. // The agent ID is stored in the certificate's CommonName. -func ExtractAgentID(state *tls.ConnectionState) (string, error) { +func ExtractAgentIDFromState(state *tls.ConnectionState) (string, error) { if len(state.VerifiedChains) == 0 || len(state.VerifiedChains[0]) == 0 { return "", fmt.Errorf("no verified certificate chain") } cert := state.VerifiedChains[0][0] - agentID := cert.Subject.CommonName - if agentID == "" { - return "", fmt.Errorf("certificate has no CommonName") - } - - return agentID, nil + return ExtractAgentID(cert), nil +} + +// ExtractAgentID extracts the agent ID from a certificate. +// The agent ID is stored in the certificate's CommonName. +func ExtractAgentID(cert *x509.Certificate) string { + if cert == nil { + return "" + } + return cert.Subject.CommonName } diff --git a/backend/internal/server/bridge.go b/backend/internal/server/bridge.go new file mode 100644 index 0000000..819d01d --- /dev/null +++ b/backend/internal/server/bridge.go @@ -0,0 +1,275 @@ +package server + +import ( + "encoding/json" + "log" + "sync" + "time" + + "tyto/internal/models" + pb "tyto/internal/proto" +) + +// SSEBridge bridges the Hub to SSE clients for multi-device monitoring. +// It implements MetricsSubscriber to receive agent updates. +type SSEBridge struct { + hub *Hub + clients map[chan []byte]bool + mu sync.RWMutex + broadcast chan []byte + + // Cache of latest metrics per device + metricsCache map[string]*DeviceMetrics + cacheMu sync.RWMutex + + stopCh chan struct{} + wg sync.WaitGroup +} + +// DeviceMetrics wraps metrics with device info. +type DeviceMetrics struct { + DeviceID string `json:"deviceId"` + Hostname string `json:"hostname,omitempty"` + Status string `json:"status"` + LastUpdated time.Time `json:"lastUpdated"` + Metrics *models.AllMetrics `json:"metrics"` +} + +// MultiDeviceSnapshot is sent to SSE clients. +type MultiDeviceSnapshot struct { + Type string `json:"type"` + Timestamp string `json:"timestamp"` + Devices map[string]*DeviceMetrics `json:"devices"` +} + +// DeviceUpdate is sent when a single device's metrics change. +type DeviceUpdate struct { + Type string `json:"type"` + DeviceID string `json:"deviceId"` + Status string `json:"status"` + Metrics *models.AllMetrics `json:"metrics"` +} + +// DeviceStatusChange is sent when a device connects/disconnects. +type DeviceStatusChange struct { + Type string `json:"type"` + DeviceID string `json:"deviceId"` + Status string `json:"status"` + Hostname string `json:"hostname,omitempty"` +} + +// NewSSEBridge creates a new SSE bridge. +func NewSSEBridge(hub *Hub) *SSEBridge { + bridge := &SSEBridge{ + hub: hub, + clients: make(map[chan []byte]bool), + broadcast: make(chan []byte, 256), + metricsCache: make(map[string]*DeviceMetrics), + stopCh: make(chan struct{}), + } + + // Subscribe to hub events + hub.Subscribe(bridge) + + return bridge +} + +// Start begins the broadcast loop. +func (b *SSEBridge) Start() { + b.wg.Add(1) + go b.broadcastLoop() +} + +// Stop stops the bridge. +func (b *SSEBridge) Stop() { + close(b.stopCh) + b.wg.Wait() +} + +func (b *SSEBridge) broadcastLoop() { + defer b.wg.Done() + + for { + select { + case <-b.stopCh: + return + case data := <-b.broadcast: + b.mu.RLock() + for client := range b.clients { + select { + case client <- data: + default: + // Client buffer full, skip + } + } + b.mu.RUnlock() + } + } +} + +// Register adds an SSE client. +func (b *SSEBridge) Register(client chan []byte) { + b.mu.Lock() + b.clients[client] = true + b.mu.Unlock() + + // Send current snapshot to new client + go b.sendSnapshot(client) +} + +// Unregister removes an SSE client. +func (b *SSEBridge) Unregister(client chan []byte) { + b.mu.Lock() + delete(b.clients, client) + b.mu.Unlock() +} + +func (b *SSEBridge) sendSnapshot(client chan []byte) { + b.cacheMu.RLock() + snapshot := MultiDeviceSnapshot{ + Type: "snapshot", + Timestamp: time.Now().Format(time.RFC3339), + Devices: b.metricsCache, + } + b.cacheMu.RUnlock() + + data, err := json.Marshal(snapshot) + if err != nil { + log.Printf("Failed to marshal snapshot: %v", err) + return + } + + select { + case client <- data: + default: + } +} + +// MetricsSubscriber implementation + +// OnAgentMetrics is called when an agent sends metrics. +func (b *SSEBridge) OnAgentMetrics(agentID string, metrics *models.AllMetrics) { + // Update cache + b.cacheMu.Lock() + if existing, ok := b.metricsCache[agentID]; ok { + existing.Metrics = metrics + existing.LastUpdated = time.Now() + existing.Status = "online" + } else { + b.metricsCache[agentID] = &DeviceMetrics{ + DeviceID: agentID, + Status: "online", + LastUpdated: time.Now(), + Metrics: metrics, + } + } + b.cacheMu.Unlock() + + // Broadcast update + update := DeviceUpdate{ + Type: "update", + DeviceID: agentID, + Status: "online", + Metrics: metrics, + } + + data, err := json.Marshal(update) + if err != nil { + log.Printf("Failed to marshal update: %v", err) + return + } + + select { + case b.broadcast <- data: + default: + log.Println("Broadcast buffer full, dropping update") + } +} + +// OnAgentConnected is called when an agent connects. +func (b *SSEBridge) OnAgentConnected(agentID string, info *pb.AgentInfo) { + hostname := "" + if info != nil { + hostname = info.Hostname + } + + // Update cache + b.cacheMu.Lock() + if existing, ok := b.metricsCache[agentID]; ok { + existing.Status = "online" + existing.Hostname = hostname + } else { + b.metricsCache[agentID] = &DeviceMetrics{ + DeviceID: agentID, + Hostname: hostname, + Status: "online", + } + } + b.cacheMu.Unlock() + + // Broadcast status change + change := DeviceStatusChange{ + Type: "connected", + DeviceID: agentID, + Status: "online", + Hostname: hostname, + } + + data, err := json.Marshal(change) + if err != nil { + return + } + + select { + case b.broadcast <- data: + default: + } +} + +// OnAgentDisconnected is called when an agent disconnects. +func (b *SSEBridge) OnAgentDisconnected(agentID string) { + // Update cache + b.cacheMu.Lock() + if existing, ok := b.metricsCache[agentID]; ok { + existing.Status = "offline" + } + b.cacheMu.Unlock() + + // Broadcast status change + change := DeviceStatusChange{ + Type: "disconnected", + DeviceID: agentID, + Status: "offline", + } + + data, err := json.Marshal(change) + if err != nil { + return + } + + select { + case b.broadcast <- data: + default: + } +} + +// GetSnapshot returns the current metrics snapshot. +func (b *SSEBridge) GetSnapshot() map[string]*DeviceMetrics { + b.cacheMu.RLock() + defer b.cacheMu.RUnlock() + + snapshot := make(map[string]*DeviceMetrics, len(b.metricsCache)) + for k, v := range b.metricsCache { + snapshot[k] = v + } + return snapshot +} + +// GetDeviceMetrics returns metrics for a specific device. +func (b *SSEBridge) GetDeviceMetrics(deviceID string) (*DeviceMetrics, bool) { + b.cacheMu.RLock() + defer b.cacheMu.RUnlock() + + metrics, ok := b.metricsCache[deviceID] + return metrics, ok +} diff --git a/backend/internal/server/grpc.go b/backend/internal/server/grpc.go new file mode 100644 index 0000000..1e451d9 --- /dev/null +++ b/backend/internal/server/grpc.go @@ -0,0 +1,232 @@ +package server + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "log" + "net" + "os" + "strings" + + "tyto/internal/config" + "tyto/internal/pki" + pb "tyto/internal/proto" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +// GRPCServer wraps the gRPC server and Hub. +type GRPCServer struct { + pb.UnimplementedAgentServiceServer + + hub *Hub + server *grpc.Server + listener net.Listener + config *config.ServerConfig +} + +// NewGRPCServer creates a new gRPC server. +func NewGRPCServer(hub *Hub, cfg *config.ServerConfig) (*GRPCServer, error) { + s := &GRPCServer{ + hub: hub, + config: cfg, + } + + // Build server options + opts, err := s.serverOptions() + if err != nil { + return nil, fmt.Errorf("failed to create server options: %w", err) + } + + s.server = grpc.NewServer(opts...) + pb.RegisterAgentServiceServer(s.server, s) + + return s, nil +} + +func (s *GRPCServer) serverOptions() ([]grpc.ServerOption, error) { + var opts []grpc.ServerOption + + tlsCfg := s.config.TLS + if tlsCfg.CACert != "" && tlsCfg.ServerCert != "" { + // Load mTLS configuration + tlsConfig, err := s.loadTLSConfig() + if err != nil { + return nil, err + } + opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig))) + } else { + log.Println("Warning: gRPC server running without TLS (insecure mode)") + } + + // Add interceptors for authentication and logging + opts = append(opts, + grpc.UnaryInterceptor(s.unaryInterceptor), + grpc.StreamInterceptor(s.streamInterceptor), + ) + + return opts, nil +} + +func (s *GRPCServer) loadTLSConfig() (*tls.Config, error) { + tlsCfg := s.config.TLS + + // Load CA certificate + caCert, err := os.ReadFile(tlsCfg.CACert) + if err != nil { + return nil, fmt.Errorf("failed to read CA cert: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA cert") + } + + // Load server certificate + cert, err := tls.LoadX509KeyPair(tlsCfg.ServerCert, tlsCfg.ServerKey) + if err != nil { + return nil, fmt.Errorf("failed to load server cert: %w", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + }, nil +} + +// Start begins listening for connections. +func (s *GRPCServer) Start(port int) error { + addr := fmt.Sprintf(":%d", port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + + s.listener = listener + log.Printf("gRPC server listening on %s", addr) + + return s.server.Serve(listener) +} + +// Stop gracefully stops the server. +func (s *GRPCServer) Stop() { + if s.server != nil { + s.server.GracefulStop() + } +} + +// Interceptors + +func (s *GRPCServer) unaryInterceptor( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (interface{}, error) { + // Extract agent ID from TLS certificate + agentID, err := s.extractAgentID(ctx) + if err != nil { + // Allow register without pre-auth for initial registration + if !strings.Contains(info.FullMethod, "Register") { + return nil, err + } + // For registration, use the agent ID from the request + if regReq, ok := req.(*pb.RegisterRequest); ok { + agentID = regReq.AgentId + } + } + + // Add agent ID to context + ctx = ContextWithAgentID(ctx, agentID) + + // Log the request + log.Printf("gRPC %s from agent %s", info.FullMethod, agentID) + + return handler(ctx, req) +} + +func (s *GRPCServer) streamInterceptor( + srv interface{}, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + // Extract agent ID from TLS certificate + agentID, err := s.extractAgentID(ss.Context()) + if err != nil { + return err + } + + // Wrap the stream with context containing agent ID + wrapped := &wrappedServerStream{ + ServerStream: ss, + ctx: ContextWithAgentID(ss.Context(), agentID), + } + + log.Printf("gRPC stream %s from agent %s", info.FullMethod, agentID) + + return handler(srv, wrapped) +} + +func (s *GRPCServer) extractAgentID(ctx context.Context) (string, error) { + // Try to get peer info + p, ok := peer.FromContext(ctx) + if !ok { + return "", status.Error(codes.Unauthenticated, "no peer info") + } + + // Check for TLS info + tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo) + if !ok { + // No TLS, might be insecure mode - check for existing context value + if id, ok := AgentIDFromContext(ctx); ok { + return id, nil + } + return "", status.Error(codes.Unauthenticated, "no TLS info") + } + + // Extract CN from client certificate + if len(tlsInfo.State.PeerCertificates) == 0 { + return "", status.Error(codes.Unauthenticated, "no client certificate") + } + + cert := tlsInfo.State.PeerCertificates[0] + agentID := pki.ExtractAgentID(cert) + if agentID == "" { + return "", status.Error(codes.Unauthenticated, "no agent ID in certificate") + } + + return agentID, nil +} + +// wrappedServerStream wraps a ServerStream to override Context(). +type wrappedServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedServerStream) Context() context.Context { + return w.ctx +} + +// Service methods - delegate to Hub + +func (s *GRPCServer) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) { + return s.hub.Register(ctx, req) +} + +func (s *GRPCServer) Stream(stream pb.AgentService_StreamServer) error { + return s.hub.Stream(stream) +} + +func (s *GRPCServer) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) { + return s.hub.Heartbeat(ctx, req) +} diff --git a/backend/internal/server/hub.go b/backend/internal/server/hub.go new file mode 100644 index 0000000..caa9aec --- /dev/null +++ b/backend/internal/server/hub.go @@ -0,0 +1,513 @@ +package server + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + "tyto/internal/models" + pb "tyto/internal/proto" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// AgentConnection represents a connected agent. +type AgentConnection struct { + ID string + Stream pb.AgentService_StreamServer + Info *pb.AgentInfo + LastMetrics *models.AllMetrics + LastSeen time.Time + Connected bool + SendCh chan *pb.ServerMessage + cancel context.CancelFunc +} + +// MetricsSubscriber receives aggregated metrics from all agents. +type MetricsSubscriber interface { + OnAgentMetrics(agentID string, metrics *models.AllMetrics) + OnAgentConnected(agentID string, info *pb.AgentInfo) + OnAgentDisconnected(agentID string) +} + +// Hub manages agent connections and message routing. +type Hub struct { + registry *Registry + config *HubConfig + agents map[string]*AgentConnection + mu sync.RWMutex + subscribers []MetricsSubscriber + + // Channels for internal coordination + registerCh chan *AgentConnection + unregisterCh chan string + stopCh chan struct{} + wg sync.WaitGroup +} + +// HubConfig contains Hub configuration. +type HubConfig struct { + RequireApproval bool + AutoApprove bool +} + +// NewHub creates a new Hub instance. +func NewHub(registry *Registry, config *HubConfig) *Hub { + if config == nil { + config = &HubConfig{ + RequireApproval: true, + AutoApprove: false, + } + } + + return &Hub{ + registry: registry, + config: config, + agents: make(map[string]*AgentConnection), + registerCh: make(chan *AgentConnection, 16), + unregisterCh: make(chan string, 16), + stopCh: make(chan struct{}), + } +} + +// Subscribe adds a metrics subscriber. +func (h *Hub) Subscribe(sub MetricsSubscriber) { + h.mu.Lock() + defer h.mu.Unlock() + h.subscribers = append(h.subscribers, sub) +} + +// Start begins the hub's event loop. +func (h *Hub) Start() { + h.wg.Add(1) + go h.run() +} + +// Stop gracefully shuts down the hub. +func (h *Hub) Stop() { + close(h.stopCh) + h.wg.Wait() +} + +func (h *Hub) run() { + defer h.wg.Done() + + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-h.stopCh: + h.disconnectAll() + return + + case conn := <-h.registerCh: + h.handleRegister(conn) + + case agentID := <-h.unregisterCh: + h.handleUnregister(agentID) + + case <-ticker.C: + h.checkStaleConnections() + } + } +} + +func (h *Hub) handleRegister(conn *AgentConnection) { + h.mu.Lock() + defer h.mu.Unlock() + + // Close existing connection if any + if existing, ok := h.agents[conn.ID]; ok { + existing.Connected = false + if existing.cancel != nil { + existing.cancel() + } + close(existing.SendCh) + } + + h.agents[conn.ID] = conn + log.Printf("Agent registered: %s", conn.ID) + + // Notify subscribers + for _, sub := range h.subscribers { + sub.OnAgentConnected(conn.ID, conn.Info) + } +} + +func (h *Hub) handleUnregister(agentID string) { + h.mu.Lock() + defer h.mu.Unlock() + + conn, ok := h.agents[agentID] + if !ok { + return + } + + conn.Connected = false + delete(h.agents, agentID) + log.Printf("Agent unregistered: %s", agentID) + + // Update registry status + h.registry.UpdateStatus(agentID, AgentStatusOffline) + + // Notify subscribers + for _, sub := range h.subscribers { + sub.OnAgentDisconnected(agentID) + } +} + +func (h *Hub) checkStaleConnections() { + h.mu.RLock() + staleIDs := make([]string, 0) + for id, conn := range h.agents { + if time.Since(conn.LastSeen) > 60*time.Second { + staleIDs = append(staleIDs, id) + } + } + h.mu.RUnlock() + + for _, id := range staleIDs { + log.Printf("Removing stale agent: %s", id) + h.unregisterCh <- id + } +} + +func (h *Hub) disconnectAll() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, conn := range h.agents { + conn.Connected = false + if conn.cancel != nil { + conn.cancel() + } + close(conn.SendCh) + } + h.agents = make(map[string]*AgentConnection) +} + +// Register handles agent registration requests. +func (h *Hub) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) { + agentID := req.AgentId + info := req.Info + + // Check if already registered + existing, exists := h.registry.Get(agentID) + if exists { + switch existing.Status { + case AgentStatusRevoked: + return &pb.RegisterResponse{ + Status: pb.RegisterStatus_REGISTER_STATUS_REJECTED, + Message: "agent certificate has been revoked", + }, nil + + case AgentStatusApproved, AgentStatusConnected: + // Update info and return success + record := &AgentRecord{ + ID: agentID, + Hostname: info.Hostname, + OS: info.Os, + Architecture: info.Architecture, + Version: info.Version, + Capabilities: info.Capabilities, + } + h.registry.Register(record) + + return &pb.RegisterResponse{ + Status: pb.RegisterStatus_REGISTER_STATUS_ALREADY_REGISTERED, + Message: "already registered", + Config: h.getAgentConfig(), + }, nil + + case AgentStatusPending: + return &pb.RegisterResponse{ + Status: pb.RegisterStatus_REGISTER_STATUS_PENDING_APPROVAL, + Message: "awaiting approval", + }, nil + } + } + + // New registration + record := &AgentRecord{ + ID: agentID, + Hostname: info.Hostname, + OS: info.Os, + Architecture: info.Architecture, + Version: info.Version, + Capabilities: info.Capabilities, + Status: AgentStatusPending, + } + + // Auto-approve if configured + if h.config.AutoApprove || !h.config.RequireApproval { + record.Status = AgentStatusApproved + } + + if err := h.registry.Register(record); err != nil { + return nil, status.Errorf(codes.Internal, "registration failed: %v", err) + } + + if record.Status == AgentStatusApproved { + return &pb.RegisterResponse{ + Status: pb.RegisterStatus_REGISTER_STATUS_ACCEPTED, + Message: "registration accepted", + Config: h.getAgentConfig(), + }, nil + } + + return &pb.RegisterResponse{ + Status: pb.RegisterStatus_REGISTER_STATUS_PENDING_APPROVAL, + Message: "awaiting approval", + }, nil +} + +// Stream handles the bidirectional streaming RPC. +func (h *Hub) Stream(stream pb.AgentService_StreamServer) error { + // Extract agent ID from context (set by auth interceptor) + agentID, ok := AgentIDFromContext(stream.Context()) + if !ok { + return status.Error(codes.Unauthenticated, "agent ID not found in context") + } + + // Verify agent is approved + if !h.registry.IsApproved(agentID) { + return status.Error(codes.PermissionDenied, "agent not approved") + } + + // Create connection + ctx, cancel := context.WithCancel(stream.Context()) + conn := &AgentConnection{ + ID: agentID, + Stream: stream, + Connected: true, + LastSeen: time.Now(), + SendCh: make(chan *pb.ServerMessage, 16), + cancel: cancel, + } + + // Register connection + h.registerCh <- conn + + // Update registry status + h.registry.UpdateStatus(agentID, AgentStatusConnected) + + // Start sender goroutine + h.wg.Add(1) + go h.sendLoop(conn) + + // Receive loop + err := h.receiveLoop(ctx, conn) + + // Cleanup + h.unregisterCh <- agentID + + return err +} + +func (h *Hub) sendLoop(conn *AgentConnection) { + defer h.wg.Done() + + for msg := range conn.SendCh { + if err := conn.Stream.Send(msg); err != nil { + log.Printf("Send error for agent %s: %v", conn.ID, err) + return + } + } +} + +func (h *Hub) receiveLoop(ctx context.Context, conn *AgentConnection) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + msg, err := conn.Stream.Recv() + if err != nil { + return err + } + + conn.LastSeen = time.Now() + h.registry.UpdateLastSeen(conn.ID) + + switch payload := msg.Payload.(type) { + case *pb.AgentMessage_Metrics: + h.handleMetrics(conn, payload.Metrics) + + case *pb.AgentMessage_Heartbeat: + h.handleHeartbeat(conn, payload.Heartbeat) + + case *pb.AgentMessage_Info: + conn.Info = payload.Info + } + } +} + +func (h *Hub) handleMetrics(conn *AgentConnection, report *pb.MetricsReport) { + // Deserialize metrics + var metrics models.AllMetrics + if err := json.Unmarshal(report.MetricsJson, &metrics); err != nil { + log.Printf("Failed to unmarshal metrics from %s: %v", conn.ID, err) + return + } + + metrics.Timestamp = time.UnixMilli(report.TimestampMs) + conn.LastMetrics = &metrics + + // Notify subscribers + h.mu.RLock() + for _, sub := range h.subscribers { + sub.OnAgentMetrics(conn.ID, &metrics) + } + h.mu.RUnlock() + + // Send acknowledgment + conn.SendCh <- &pb.ServerMessage{ + Payload: &pb.ServerMessage_Ack{ + Ack: &pb.Ack{Success: true}, + }, + } +} + +func (h *Hub) handleHeartbeat(conn *AgentConnection, hb *pb.HeartbeatRequest) { + // Just update last seen (already done in receive loop) + log.Printf("Heartbeat from %s (uptime: %ds)", conn.ID, hb.UptimeSeconds) +} + +// Heartbeat handles simple heartbeat RPCs. +func (h *Hub) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) { + h.registry.UpdateLastSeen(req.AgentId) + + return &pb.HeartbeatResponse{ + ServerTimeMs: time.Now().UnixMilli(), + ConfigChanged: false, + }, nil +} + +// SendCommand sends a command to a specific agent. +func (h *Hub) SendCommand(agentID string, cmd *pb.Command) error { + h.mu.RLock() + conn, ok := h.agents[agentID] + h.mu.RUnlock() + + if !ok || !conn.Connected { + return status.Error(codes.NotFound, "agent not connected") + } + + select { + case conn.SendCh <- &pb.ServerMessage{ + Payload: &pb.ServerMessage_Command{Command: cmd}, + }: + return nil + default: + return status.Error(codes.ResourceExhausted, "agent send buffer full") + } +} + +// SendConfigUpdate sends a config update to a specific agent. +func (h *Hub) SendConfigUpdate(agentID string, config *pb.ConfigUpdate) error { + h.mu.RLock() + conn, ok := h.agents[agentID] + h.mu.RUnlock() + + if !ok || !conn.Connected { + return status.Error(codes.NotFound, "agent not connected") + } + + select { + case conn.SendCh <- &pb.ServerMessage{ + Payload: &pb.ServerMessage_Config{Config: config}, + }: + return nil + default: + return status.Error(codes.ResourceExhausted, "agent send buffer full") + } +} + +// BroadcastCommand sends a command to all connected agents. +func (h *Hub) BroadcastCommand(cmd *pb.Command) { + h.mu.RLock() + defer h.mu.RUnlock() + + msg := &pb.ServerMessage{ + Payload: &pb.ServerMessage_Command{Command: cmd}, + } + + for _, conn := range h.agents { + if conn.Connected { + select { + case conn.SendCh <- msg: + default: + log.Printf("Send buffer full for agent %s", conn.ID) + } + } + } +} + +// GetConnectedAgents returns a list of currently connected agent IDs. +func (h *Hub) GetConnectedAgents() []string { + h.mu.RLock() + defer h.mu.RUnlock() + + ids := make([]string, 0, len(h.agents)) + for id, conn := range h.agents { + if conn.Connected { + ids = append(ids, id) + } + } + return ids +} + +// GetAgentMetrics returns the last metrics for an agent. +func (h *Hub) GetAgentMetrics(agentID string) (*models.AllMetrics, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + conn, ok := h.agents[agentID] + if !ok || conn.LastMetrics == nil { + return nil, false + } + return conn.LastMetrics, true +} + +// GetAllMetrics returns the last metrics for all connected agents. +func (h *Hub) GetAllMetrics() map[string]*models.AllMetrics { + h.mu.RLock() + defer h.mu.RUnlock() + + result := make(map[string]*models.AllMetrics, len(h.agents)) + for id, conn := range h.agents { + if conn.LastMetrics != nil { + result[id] = conn.LastMetrics + } + } + return result +} + +func (h *Hub) getAgentConfig() *pb.AgentConfig { + return &pb.AgentConfig{ + CollectionIntervalSeconds: 5, + EnabledCollectors: []string{"cpu", "memory", "disk", "network", "process", "temperature", "gpu"}, + } +} + +// Context key for agent ID +type contextKey string + +const agentIDKey contextKey = "agentID" + +// AgentIDFromContext extracts the agent ID from context. +func AgentIDFromContext(ctx context.Context) (string, bool) { + id, ok := ctx.Value(agentIDKey).(string) + return id, ok +} + +// ContextWithAgentID returns a context with the agent ID set. +func ContextWithAgentID(ctx context.Context, agentID string) context.Context { + return context.WithValue(ctx, agentIDKey, agentID) +} diff --git a/backend/internal/server/registry.go b/backend/internal/server/registry.go new file mode 100644 index 0000000..b04d435 --- /dev/null +++ b/backend/internal/server/registry.go @@ -0,0 +1,284 @@ +// Package server implements the central Tyto server for multi-device monitoring. +package server + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" +) + +// AgentStatus represents the current state of an agent. +type AgentStatus string + +const ( + AgentStatusPending AgentStatus = "pending" + AgentStatusApproved AgentStatus = "approved" + AgentStatusConnected AgentStatus = "connected" + AgentStatusOffline AgentStatus = "offline" + AgentStatusRevoked AgentStatus = "revoked" +) + +// AgentRecord stores metadata about a registered agent. +type AgentRecord struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Hostname string `json:"hostname"` + OS string `json:"os"` + Architecture string `json:"architecture"` + Version string `json:"version"` + Capabilities []string `json:"capabilities,omitempty"` + Status AgentStatus `json:"status"` + CertSerial string `json:"certSerial,omitempty"` + CertExpiry time.Time `json:"certExpiry,omitempty"` + LastSeen time.Time `json:"lastSeen,omitempty"` + RegisteredAt time.Time `json:"registeredAt"` + Tags []string `json:"tags,omitempty"` +} + +// Registry manages agent registrations. +type Registry struct { + mu sync.RWMutex + agents map[string]*AgentRecord + filePath string +} + +// NewRegistry creates a new agent registry. +func NewRegistry(filePath string) *Registry { + r := &Registry{ + agents: make(map[string]*AgentRecord), + filePath: filePath, + } + + // Load existing registrations + if filePath != "" { + r.load() + } + + return r +} + +// Register adds or updates an agent registration. +func (r *Registry) Register(agent *AgentRecord) error { + r.mu.Lock() + defer r.mu.Unlock() + + existing, exists := r.agents[agent.ID] + if exists { + // Update existing record + existing.Hostname = agent.Hostname + existing.OS = agent.OS + existing.Architecture = agent.Architecture + existing.Version = agent.Version + existing.Capabilities = agent.Capabilities + existing.LastSeen = time.Now() + + // Don't change status if already approved/connected + if existing.Status == AgentStatusRevoked { + return fmt.Errorf("agent %s is revoked", agent.ID) + } + } else { + // New registration + agent.RegisteredAt = time.Now() + agent.LastSeen = time.Now() + if agent.Status == "" { + agent.Status = AgentStatusPending + } + r.agents[agent.ID] = agent + } + + return r.save() +} + +// Get returns an agent record by ID. +func (r *Registry) Get(id string) (*AgentRecord, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + agent, exists := r.agents[id] + if !exists { + return nil, false + } + + // Return a copy + copy := *agent + return ©, true +} + +// List returns all registered agents. +func (r *Registry) List() []*AgentRecord { + r.mu.RLock() + defer r.mu.RUnlock() + + result := make([]*AgentRecord, 0, len(r.agents)) + for _, agent := range r.agents { + copy := *agent + result = append(result, ©) + } + return result +} + +// ListPending returns agents awaiting approval. +func (r *Registry) ListPending() []*AgentRecord { + r.mu.RLock() + defer r.mu.RUnlock() + + var result []*AgentRecord + for _, agent := range r.agents { + if agent.Status == AgentStatusPending { + copy := *agent + result = append(result, ©) + } + } + return result +} + +// Approve marks an agent as approved. +func (r *Registry) Approve(id string) error { + r.mu.Lock() + defer r.mu.Unlock() + + agent, exists := r.agents[id] + if !exists { + return fmt.Errorf("agent %s not found", id) + } + + if agent.Status == AgentStatusRevoked { + return fmt.Errorf("agent %s is revoked", id) + } + + agent.Status = AgentStatusApproved + return r.save() +} + +// Reject removes a pending agent registration. +func (r *Registry) Reject(id string) error { + r.mu.Lock() + defer r.mu.Unlock() + + agent, exists := r.agents[id] + if !exists { + return fmt.Errorf("agent %s not found", id) + } + + if agent.Status != AgentStatusPending { + return fmt.Errorf("agent %s is not pending", id) + } + + delete(r.agents, id) + return r.save() +} + +// Revoke marks an agent as revoked. +func (r *Registry) Revoke(id string) error { + r.mu.Lock() + defer r.mu.Unlock() + + agent, exists := r.agents[id] + if !exists { + return fmt.Errorf("agent %s not found", id) + } + + agent.Status = AgentStatusRevoked + return r.save() +} + +// Remove deletes an agent registration. +func (r *Registry) Remove(id string) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.agents[id]; !exists { + return fmt.Errorf("agent %s not found", id) + } + + delete(r.agents, id) + return r.save() +} + +// IsApproved checks if an agent is approved to connect. +func (r *Registry) IsApproved(id string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + agent, exists := r.agents[id] + if !exists { + return false + } + + return agent.Status == AgentStatusApproved || agent.Status == AgentStatusConnected +} + +// UpdateStatus updates an agent's connection status. +func (r *Registry) UpdateStatus(id string, status AgentStatus) error { + r.mu.Lock() + defer r.mu.Unlock() + + agent, exists := r.agents[id] + if !exists { + return fmt.Errorf("agent %s not found", id) + } + + agent.Status = status + agent.LastSeen = time.Now() + return r.save() +} + +// UpdateLastSeen updates the last seen timestamp. +func (r *Registry) UpdateLastSeen(id string) { + r.mu.Lock() + defer r.mu.Unlock() + + if agent, exists := r.agents[id]; exists { + agent.LastSeen = time.Now() + } +} + +// load reads the registry from disk. +func (r *Registry) load() error { + data, err := os.ReadFile(r.filePath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var agents []*AgentRecord + if err := json.Unmarshal(data, &agents); err != nil { + return err + } + + for _, agent := range agents { + r.agents[agent.ID] = agent + } + + return nil +} + +// save writes the registry to disk. +func (r *Registry) save() error { + if r.filePath == "" { + return nil + } + + agents := make([]*AgentRecord, 0, len(r.agents)) + for _, agent := range r.agents { + agents = append(agents, agent) + } + + data, err := json.MarshalIndent(agents, "", " ") + if err != nil { + return err + } + + // Ensure directory exists + dir := filepath.Dir(r.filePath) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + return os.WriteFile(r.filePath, data, 0644) +}