feat: implement server hub for multi-device agent management
Server Package (internal/server/): - Registry: Agent registration with approval workflow, persistence - Hub: Connection manager for connected agents, message routing - GRPCServer: mTLS-enabled gRPC server with interceptors - SSEBridge: Bridges agent metrics to browser SSE clients Registry Features: - JSON file-based persistence - Agent lifecycle: pending -> approved -> connected -> offline - Revocation support for certificate-based agent removal - Automatic last-seen tracking Hub Features: - Bidirectional gRPC stream handling - MetricsSubscriber interface for metric distribution - Stale connection detection and cleanup - Broadcast and per-agent command sending gRPC Server: - Unary and stream interceptors for auth - Agent ID extraction from mTLS certificates - Delegation to Hub for business logic Agent Management API: - GET/DELETE /api/v1/agents - List/remove agents - GET /api/v1/agents/pending - Pending approvals - POST /api/v1/agents/pending/:id/approve|reject - GET /api/v1/agents/:id/metrics - Latest agent metrics - GET /api/v1/agents/connected - Connected agents Server Mode Startup: - Full initialization of registry, hub, gRPC, SSE bridge - Graceful shutdown with signal handling - Agent mode now uses the agent package 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"tyto/internal/agent"
|
||||
"tyto/internal/api"
|
||||
"tyto/internal/config"
|
||||
"tyto/internal/server"
|
||||
"tyto/internal/sse"
|
||||
)
|
||||
|
||||
@@ -60,29 +66,81 @@ func runServer(cfg *config.Config) {
|
||||
log.Printf("gRPC port for agents: %d", cfg.Server.GRPCPort)
|
||||
log.Printf("Database: %s", cfg.Database.Type)
|
||||
|
||||
// TODO: Initialize database
|
||||
// TODO: Initialize authentication
|
||||
// TODO: Initialize gRPC server for agents
|
||||
// TODO: Initialize agent hub
|
||||
// Set up signal handling
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// For now, run in standalone-compatible mode
|
||||
// Full server mode will be implemented in subsequent sprints
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Initialize agent registry
|
||||
registryPath := "/var/lib/tyto/agents.json"
|
||||
if cfg.Database.SQLitePath != "" {
|
||||
registryPath = cfg.Database.SQLitePath + ".agents.json"
|
||||
}
|
||||
registry := server.NewRegistry(registryPath)
|
||||
log.Printf("Agent registry initialized: %s", registryPath)
|
||||
|
||||
// Initialize Hub
|
||||
hubConfig := &server.HubConfig{
|
||||
RequireApproval: cfg.Server.Registration.RequireApproval,
|
||||
AutoApprove: cfg.Server.Registration.AutoEnabled && !cfg.Server.Registration.RequireApproval,
|
||||
}
|
||||
hub := server.NewHub(registry, hubConfig)
|
||||
hub.Start()
|
||||
defer hub.Stop()
|
||||
log.Println("Agent hub started")
|
||||
|
||||
// Initialize SSE bridge for multi-device streaming
|
||||
bridge := server.NewSSEBridge(hub)
|
||||
bridge.Start()
|
||||
defer bridge.Stop()
|
||||
|
||||
// Initialize gRPC server for agent connections
|
||||
grpcServer, err := server.NewGRPCServer(hub, &cfg.Server)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create gRPC server: %v", err)
|
||||
}
|
||||
|
||||
// Start gRPC server in background
|
||||
go func() {
|
||||
log.Printf("Starting gRPC server on port %d", cfg.Server.GRPCPort)
|
||||
if err := grpcServer.Start(cfg.Server.GRPCPort); err != nil {
|
||||
log.Printf("gRPC server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
// Initialize SSE broker for local metrics (also runs in server mode)
|
||||
broker := sse.NewBroker(cfg)
|
||||
go broker.Run()
|
||||
|
||||
server := api.NewServer(cfg, broker)
|
||||
// Initialize HTTP API server with agent management
|
||||
apiServer := api.NewServer(cfg, broker)
|
||||
|
||||
var err error
|
||||
if cfg.TLSEnabled {
|
||||
log.Printf("Starting HTTPS server on port %s", cfg.Port)
|
||||
err = server.RunTLS(cfg.TLSCertFile, cfg.TLSKeyFile)
|
||||
} else {
|
||||
err = server.Run()
|
||||
}
|
||||
// Add agent API routes
|
||||
agentAPI := api.NewAgentAPI(registry, hub)
|
||||
agentAPI.RegisterRoutes(apiServer.Router().Group("/api/v1"))
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
// Start HTTP server in background
|
||||
go func() {
|
||||
var err error
|
||||
if cfg.TLSEnabled {
|
||||
log.Printf("Starting HTTPS server on port %s", cfg.Port)
|
||||
err = apiServer.RunTLS(cfg.TLSCertFile, cfg.TLSKeyFile)
|
||||
} else {
|
||||
log.Printf("Starting HTTP server on port %s", cfg.Port)
|
||||
err = apiServer.Run()
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for shutdown signal
|
||||
<-sigCh
|
||||
log.Println("Shutting down server...")
|
||||
cancel()
|
||||
}
|
||||
|
||||
// runAgent starts Tyto as a lightweight agent that reports to a central server.
|
||||
@@ -98,7 +156,28 @@ func runAgent(cfg *config.Config) {
|
||||
log.Printf("Reporting to: %s", cfg.Agent.ServerURL)
|
||||
log.Printf("Collection interval: %s", cfg.Agent.Interval)
|
||||
|
||||
// TODO: Implement gRPC client and metrics collection loop
|
||||
// This will be implemented in Sprint 3 (Agent Implementation)
|
||||
log.Fatal("Agent mode not yet implemented")
|
||||
// Set up signal handling
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create agent
|
||||
a := agent.New(cfg)
|
||||
|
||||
// Handle shutdown signal
|
||||
go func() {
|
||||
<-sigCh
|
||||
log.Println("Received shutdown signal, stopping agent...")
|
||||
a.Stop()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Run agent
|
||||
if err := a.Run(ctx); err != nil && err != context.Canceled {
|
||||
log.Fatalf("Agent error: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Agent stopped")
|
||||
}
|
||||
|
||||
224
backend/internal/api/agents.go
Normal file
224
backend/internal/api/agents.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"tyto/internal/server"
|
||||
)
|
||||
|
||||
// AgentAPI handles agent management endpoints.
|
||||
type AgentAPI struct {
|
||||
registry *server.Registry
|
||||
hub *server.Hub
|
||||
}
|
||||
|
||||
// NewAgentAPI creates a new agent API handler.
|
||||
func NewAgentAPI(registry *server.Registry, hub *server.Hub) *AgentAPI {
|
||||
return &AgentAPI{
|
||||
registry: registry,
|
||||
hub: hub,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes adds agent management routes to a router group.
|
||||
func (a *AgentAPI) RegisterRoutes(group *gin.RouterGroup) {
|
||||
agents := group.Group("/agents")
|
||||
{
|
||||
agents.GET("", a.listAgents)
|
||||
agents.GET("/:id", a.getAgent)
|
||||
agents.DELETE("/:id", a.removeAgent)
|
||||
agents.POST("/:id/revoke", a.revokeAgent)
|
||||
|
||||
// Pending registrations
|
||||
agents.GET("/pending", a.listPending)
|
||||
agents.POST("/pending/:id/approve", a.approveAgent)
|
||||
agents.POST("/pending/:id/reject", a.rejectAgent)
|
||||
|
||||
// Connected agents
|
||||
agents.GET("/connected", a.listConnected)
|
||||
agents.GET("/:id/metrics", a.getAgentMetrics)
|
||||
}
|
||||
}
|
||||
|
||||
// AgentResponse is the API representation of an agent.
|
||||
type AgentResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Hostname string `json:"hostname"`
|
||||
OS string `json:"os"`
|
||||
Architecture string `json:"architecture"`
|
||||
Version string `json:"version"`
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Connected bool `json:"connected"`
|
||||
LastSeen string `json:"lastSeen,omitempty"`
|
||||
RegisteredAt string `json:"registeredAt"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
func agentToResponse(agent *server.AgentRecord, connected bool) AgentResponse {
|
||||
lastSeen := ""
|
||||
if !agent.LastSeen.IsZero() {
|
||||
lastSeen = agent.LastSeen.Format("2006-01-02T15:04:05Z07:00")
|
||||
}
|
||||
|
||||
registeredAt := ""
|
||||
if !agent.RegisteredAt.IsZero() {
|
||||
registeredAt = agent.RegisteredAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
}
|
||||
|
||||
return AgentResponse{
|
||||
ID: agent.ID,
|
||||
Name: agent.Name,
|
||||
Hostname: agent.Hostname,
|
||||
OS: agent.OS,
|
||||
Architecture: agent.Architecture,
|
||||
Version: agent.Version,
|
||||
Capabilities: agent.Capabilities,
|
||||
Status: string(agent.Status),
|
||||
Connected: connected,
|
||||
LastSeen: lastSeen,
|
||||
RegisteredAt: registeredAt,
|
||||
Tags: agent.Tags,
|
||||
}
|
||||
}
|
||||
|
||||
// listAgents returns all registered agents.
|
||||
func (a *AgentAPI) listAgents(c *gin.Context) {
|
||||
agents := a.registry.List()
|
||||
connectedIDs := make(map[string]bool)
|
||||
if a.hub != nil {
|
||||
for _, id := range a.hub.GetConnectedAgents() {
|
||||
connectedIDs[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
response := make([]AgentResponse, len(agents))
|
||||
for i, agent := range agents {
|
||||
response[i] = agentToResponse(agent, connectedIDs[agent.ID])
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// getAgent returns a specific agent.
|
||||
func (a *AgentAPI) getAgent(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
agent, exists := a.registry.Get(id)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
|
||||
return
|
||||
}
|
||||
|
||||
connected := false
|
||||
if a.hub != nil {
|
||||
for _, connID := range a.hub.GetConnectedAgents() {
|
||||
if connID == id {
|
||||
connected = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, agentToResponse(agent, connected))
|
||||
}
|
||||
|
||||
// removeAgent removes an agent registration.
|
||||
func (a *AgentAPI) removeAgent(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := a.registry.Remove(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "removed"})
|
||||
}
|
||||
|
||||
// revokeAgent revokes an agent's registration.
|
||||
func (a *AgentAPI) revokeAgent(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := a.registry.Revoke(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "revoked"})
|
||||
}
|
||||
|
||||
// listPending returns agents awaiting approval.
|
||||
func (a *AgentAPI) listPending(c *gin.Context) {
|
||||
agents := a.registry.ListPending()
|
||||
|
||||
response := make([]AgentResponse, len(agents))
|
||||
for i, agent := range agents {
|
||||
response[i] = agentToResponse(agent, false)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// approveAgent approves a pending agent.
|
||||
func (a *AgentAPI) approveAgent(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := a.registry.Approve(id); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "approved"})
|
||||
}
|
||||
|
||||
// rejectAgent rejects a pending agent.
|
||||
func (a *AgentAPI) rejectAgent(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := a.registry.Reject(id); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "rejected"})
|
||||
}
|
||||
|
||||
// listConnected returns currently connected agents.
|
||||
func (a *AgentAPI) listConnected(c *gin.Context) {
|
||||
if a.hub == nil {
|
||||
c.JSON(http.StatusOK, []AgentResponse{})
|
||||
return
|
||||
}
|
||||
|
||||
connectedIDs := a.hub.GetConnectedAgents()
|
||||
response := make([]AgentResponse, 0, len(connectedIDs))
|
||||
|
||||
for _, id := range connectedIDs {
|
||||
if agent, exists := a.registry.Get(id); exists {
|
||||
response = append(response, agentToResponse(agent, true))
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// getAgentMetrics returns the latest metrics for an agent.
|
||||
func (a *AgentAPI) getAgentMetrics(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if a.hub == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "hub not available"})
|
||||
return
|
||||
}
|
||||
|
||||
metrics, exists := a.hub.GetAgentMetrics(id)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no metrics available"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, metrics)
|
||||
}
|
||||
@@ -335,6 +335,11 @@ func (s *Server) ListenAddr() string {
|
||||
return fmt.Sprintf(":%s", s.cfg.Port)
|
||||
}
|
||||
|
||||
// Router returns the underlying Gin engine for adding routes.
|
||||
func (s *Server) Router() *gin.Engine {
|
||||
return s.router
|
||||
}
|
||||
|
||||
// Alert handlers
|
||||
func (s *Server) getAlertsHandler(c *gin.Context) {
|
||||
response := models.AlertsResponse{
|
||||
|
||||
@@ -101,18 +101,22 @@ func ServerTLSConfigWithCRL(ca *CA, serverCertPath, serverKeyPath string) (*tls.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtractAgentID extracts the agent ID from a verified client certificate.
|
||||
// ExtractAgentIDFromState extracts the agent ID from a verified TLS connection state.
|
||||
// The agent ID is stored in the certificate's CommonName.
|
||||
func ExtractAgentID(state *tls.ConnectionState) (string, error) {
|
||||
func ExtractAgentIDFromState(state *tls.ConnectionState) (string, error) {
|
||||
if len(state.VerifiedChains) == 0 || len(state.VerifiedChains[0]) == 0 {
|
||||
return "", fmt.Errorf("no verified certificate chain")
|
||||
}
|
||||
|
||||
cert := state.VerifiedChains[0][0]
|
||||
agentID := cert.Subject.CommonName
|
||||
if agentID == "" {
|
||||
return "", fmt.Errorf("certificate has no CommonName")
|
||||
}
|
||||
|
||||
return agentID, nil
|
||||
return ExtractAgentID(cert), nil
|
||||
}
|
||||
|
||||
// ExtractAgentID extracts the agent ID from a certificate.
|
||||
// The agent ID is stored in the certificate's CommonName.
|
||||
func ExtractAgentID(cert *x509.Certificate) string {
|
||||
if cert == nil {
|
||||
return ""
|
||||
}
|
||||
return cert.Subject.CommonName
|
||||
}
|
||||
|
||||
275
backend/internal/server/bridge.go
Normal file
275
backend/internal/server/bridge.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyto/internal/models"
|
||||
pb "tyto/internal/proto"
|
||||
)
|
||||
|
||||
// SSEBridge bridges the Hub to SSE clients for multi-device monitoring.
|
||||
// It implements MetricsSubscriber to receive agent updates.
|
||||
type SSEBridge struct {
|
||||
hub *Hub
|
||||
clients map[chan []byte]bool
|
||||
mu sync.RWMutex
|
||||
broadcast chan []byte
|
||||
|
||||
// Cache of latest metrics per device
|
||||
metricsCache map[string]*DeviceMetrics
|
||||
cacheMu sync.RWMutex
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// DeviceMetrics wraps metrics with device info.
|
||||
type DeviceMetrics struct {
|
||||
DeviceID string `json:"deviceId"`
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Status string `json:"status"`
|
||||
LastUpdated time.Time `json:"lastUpdated"`
|
||||
Metrics *models.AllMetrics `json:"metrics"`
|
||||
}
|
||||
|
||||
// MultiDeviceSnapshot is sent to SSE clients.
|
||||
type MultiDeviceSnapshot struct {
|
||||
Type string `json:"type"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Devices map[string]*DeviceMetrics `json:"devices"`
|
||||
}
|
||||
|
||||
// DeviceUpdate is sent when a single device's metrics change.
|
||||
type DeviceUpdate struct {
|
||||
Type string `json:"type"`
|
||||
DeviceID string `json:"deviceId"`
|
||||
Status string `json:"status"`
|
||||
Metrics *models.AllMetrics `json:"metrics"`
|
||||
}
|
||||
|
||||
// DeviceStatusChange is sent when a device connects/disconnects.
|
||||
type DeviceStatusChange struct {
|
||||
Type string `json:"type"`
|
||||
DeviceID string `json:"deviceId"`
|
||||
Status string `json:"status"`
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
}
|
||||
|
||||
// NewSSEBridge creates a new SSE bridge.
|
||||
func NewSSEBridge(hub *Hub) *SSEBridge {
|
||||
bridge := &SSEBridge{
|
||||
hub: hub,
|
||||
clients: make(map[chan []byte]bool),
|
||||
broadcast: make(chan []byte, 256),
|
||||
metricsCache: make(map[string]*DeviceMetrics),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Subscribe to hub events
|
||||
hub.Subscribe(bridge)
|
||||
|
||||
return bridge
|
||||
}
|
||||
|
||||
// Start begins the broadcast loop.
|
||||
func (b *SSEBridge) Start() {
|
||||
b.wg.Add(1)
|
||||
go b.broadcastLoop()
|
||||
}
|
||||
|
||||
// Stop stops the bridge.
|
||||
func (b *SSEBridge) Stop() {
|
||||
close(b.stopCh)
|
||||
b.wg.Wait()
|
||||
}
|
||||
|
||||
func (b *SSEBridge) broadcastLoop() {
|
||||
defer b.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
return
|
||||
case data := <-b.broadcast:
|
||||
b.mu.RLock()
|
||||
for client := range b.clients {
|
||||
select {
|
||||
case client <- data:
|
||||
default:
|
||||
// Client buffer full, skip
|
||||
}
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds an SSE client.
|
||||
func (b *SSEBridge) Register(client chan []byte) {
|
||||
b.mu.Lock()
|
||||
b.clients[client] = true
|
||||
b.mu.Unlock()
|
||||
|
||||
// Send current snapshot to new client
|
||||
go b.sendSnapshot(client)
|
||||
}
|
||||
|
||||
// Unregister removes an SSE client.
|
||||
func (b *SSEBridge) Unregister(client chan []byte) {
|
||||
b.mu.Lock()
|
||||
delete(b.clients, client)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *SSEBridge) sendSnapshot(client chan []byte) {
|
||||
b.cacheMu.RLock()
|
||||
snapshot := MultiDeviceSnapshot{
|
||||
Type: "snapshot",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Devices: b.metricsCache,
|
||||
}
|
||||
b.cacheMu.RUnlock()
|
||||
|
||||
data, err := json.Marshal(snapshot)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal snapshot: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case client <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsSubscriber implementation
|
||||
|
||||
// OnAgentMetrics is called when an agent sends metrics.
|
||||
func (b *SSEBridge) OnAgentMetrics(agentID string, metrics *models.AllMetrics) {
|
||||
// Update cache
|
||||
b.cacheMu.Lock()
|
||||
if existing, ok := b.metricsCache[agentID]; ok {
|
||||
existing.Metrics = metrics
|
||||
existing.LastUpdated = time.Now()
|
||||
existing.Status = "online"
|
||||
} else {
|
||||
b.metricsCache[agentID] = &DeviceMetrics{
|
||||
DeviceID: agentID,
|
||||
Status: "online",
|
||||
LastUpdated: time.Now(),
|
||||
Metrics: metrics,
|
||||
}
|
||||
}
|
||||
b.cacheMu.Unlock()
|
||||
|
||||
// Broadcast update
|
||||
update := DeviceUpdate{
|
||||
Type: "update",
|
||||
DeviceID: agentID,
|
||||
Status: "online",
|
||||
Metrics: metrics,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(update)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal update: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case b.broadcast <- data:
|
||||
default:
|
||||
log.Println("Broadcast buffer full, dropping update")
|
||||
}
|
||||
}
|
||||
|
||||
// OnAgentConnected is called when an agent connects.
|
||||
func (b *SSEBridge) OnAgentConnected(agentID string, info *pb.AgentInfo) {
|
||||
hostname := ""
|
||||
if info != nil {
|
||||
hostname = info.Hostname
|
||||
}
|
||||
|
||||
// Update cache
|
||||
b.cacheMu.Lock()
|
||||
if existing, ok := b.metricsCache[agentID]; ok {
|
||||
existing.Status = "online"
|
||||
existing.Hostname = hostname
|
||||
} else {
|
||||
b.metricsCache[agentID] = &DeviceMetrics{
|
||||
DeviceID: agentID,
|
||||
Hostname: hostname,
|
||||
Status: "online",
|
||||
}
|
||||
}
|
||||
b.cacheMu.Unlock()
|
||||
|
||||
// Broadcast status change
|
||||
change := DeviceStatusChange{
|
||||
Type: "connected",
|
||||
DeviceID: agentID,
|
||||
Status: "online",
|
||||
Hostname: hostname,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(change)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case b.broadcast <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// OnAgentDisconnected is called when an agent disconnects.
|
||||
func (b *SSEBridge) OnAgentDisconnected(agentID string) {
|
||||
// Update cache
|
||||
b.cacheMu.Lock()
|
||||
if existing, ok := b.metricsCache[agentID]; ok {
|
||||
existing.Status = "offline"
|
||||
}
|
||||
b.cacheMu.Unlock()
|
||||
|
||||
// Broadcast status change
|
||||
change := DeviceStatusChange{
|
||||
Type: "disconnected",
|
||||
DeviceID: agentID,
|
||||
Status: "offline",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(change)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case b.broadcast <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// GetSnapshot returns the current metrics snapshot.
|
||||
func (b *SSEBridge) GetSnapshot() map[string]*DeviceMetrics {
|
||||
b.cacheMu.RLock()
|
||||
defer b.cacheMu.RUnlock()
|
||||
|
||||
snapshot := make(map[string]*DeviceMetrics, len(b.metricsCache))
|
||||
for k, v := range b.metricsCache {
|
||||
snapshot[k] = v
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// GetDeviceMetrics returns metrics for a specific device.
|
||||
func (b *SSEBridge) GetDeviceMetrics(deviceID string) (*DeviceMetrics, bool) {
|
||||
b.cacheMu.RLock()
|
||||
defer b.cacheMu.RUnlock()
|
||||
|
||||
metrics, ok := b.metricsCache[deviceID]
|
||||
return metrics, ok
|
||||
}
|
||||
232
backend/internal/server/grpc.go
Normal file
232
backend/internal/server/grpc.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"tyto/internal/config"
|
||||
"tyto/internal/pki"
|
||||
pb "tyto/internal/proto"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// GRPCServer wraps the gRPC server and Hub.
|
||||
type GRPCServer struct {
|
||||
pb.UnimplementedAgentServiceServer
|
||||
|
||||
hub *Hub
|
||||
server *grpc.Server
|
||||
listener net.Listener
|
||||
config *config.ServerConfig
|
||||
}
|
||||
|
||||
// NewGRPCServer creates a new gRPC server.
|
||||
func NewGRPCServer(hub *Hub, cfg *config.ServerConfig) (*GRPCServer, error) {
|
||||
s := &GRPCServer{
|
||||
hub: hub,
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// Build server options
|
||||
opts, err := s.serverOptions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create server options: %w", err)
|
||||
}
|
||||
|
||||
s.server = grpc.NewServer(opts...)
|
||||
pb.RegisterAgentServiceServer(s.server, s)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) serverOptions() ([]grpc.ServerOption, error) {
|
||||
var opts []grpc.ServerOption
|
||||
|
||||
tlsCfg := s.config.TLS
|
||||
if tlsCfg.CACert != "" && tlsCfg.ServerCert != "" {
|
||||
// Load mTLS configuration
|
||||
tlsConfig, err := s.loadTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
} else {
|
||||
log.Println("Warning: gRPC server running without TLS (insecure mode)")
|
||||
}
|
||||
|
||||
// Add interceptors for authentication and logging
|
||||
opts = append(opts,
|
||||
grpc.UnaryInterceptor(s.unaryInterceptor),
|
||||
grpc.StreamInterceptor(s.streamInterceptor),
|
||||
)
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) loadTLSConfig() (*tls.Config, error) {
|
||||
tlsCfg := s.config.TLS
|
||||
|
||||
// Load CA certificate
|
||||
caCert, err := os.ReadFile(tlsCfg.CACert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA cert: %w", err)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA cert")
|
||||
}
|
||||
|
||||
// Load server certificate
|
||||
cert, err := tls.LoadX509KeyPair(tlsCfg.ServerCert, tlsCfg.ServerKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load server cert: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins listening for connections.
|
||||
func (s *GRPCServer) Start(port int) error {
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
log.Printf("gRPC server listening on %s", addr)
|
||||
|
||||
return s.server.Serve(listener)
|
||||
}
|
||||
|
||||
// Stop gracefully stops the server.
|
||||
func (s *GRPCServer) Stop() {
|
||||
if s.server != nil {
|
||||
s.server.GracefulStop()
|
||||
}
|
||||
}
|
||||
|
||||
// Interceptors
|
||||
|
||||
func (s *GRPCServer) unaryInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
// Extract agent ID from TLS certificate
|
||||
agentID, err := s.extractAgentID(ctx)
|
||||
if err != nil {
|
||||
// Allow register without pre-auth for initial registration
|
||||
if !strings.Contains(info.FullMethod, "Register") {
|
||||
return nil, err
|
||||
}
|
||||
// For registration, use the agent ID from the request
|
||||
if regReq, ok := req.(*pb.RegisterRequest); ok {
|
||||
agentID = regReq.AgentId
|
||||
}
|
||||
}
|
||||
|
||||
// Add agent ID to context
|
||||
ctx = ContextWithAgentID(ctx, agentID)
|
||||
|
||||
// Log the request
|
||||
log.Printf("gRPC %s from agent %s", info.FullMethod, agentID)
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) streamInterceptor(
|
||||
srv interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
// Extract agent ID from TLS certificate
|
||||
agentID, err := s.extractAgentID(ss.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wrap the stream with context containing agent ID
|
||||
wrapped := &wrappedServerStream{
|
||||
ServerStream: ss,
|
||||
ctx: ContextWithAgentID(ss.Context(), agentID),
|
||||
}
|
||||
|
||||
log.Printf("gRPC stream %s from agent %s", info.FullMethod, agentID)
|
||||
|
||||
return handler(srv, wrapped)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) extractAgentID(ctx context.Context) (string, error) {
|
||||
// Try to get peer info
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return "", status.Error(codes.Unauthenticated, "no peer info")
|
||||
}
|
||||
|
||||
// Check for TLS info
|
||||
tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
|
||||
if !ok {
|
||||
// No TLS, might be insecure mode - check for existing context value
|
||||
if id, ok := AgentIDFromContext(ctx); ok {
|
||||
return id, nil
|
||||
}
|
||||
return "", status.Error(codes.Unauthenticated, "no TLS info")
|
||||
}
|
||||
|
||||
// Extract CN from client certificate
|
||||
if len(tlsInfo.State.PeerCertificates) == 0 {
|
||||
return "", status.Error(codes.Unauthenticated, "no client certificate")
|
||||
}
|
||||
|
||||
cert := tlsInfo.State.PeerCertificates[0]
|
||||
agentID := pki.ExtractAgentID(cert)
|
||||
if agentID == "" {
|
||||
return "", status.Error(codes.Unauthenticated, "no agent ID in certificate")
|
||||
}
|
||||
|
||||
return agentID, nil
|
||||
}
|
||||
|
||||
// wrappedServerStream wraps a ServerStream to override Context().
|
||||
type wrappedServerStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *wrappedServerStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
|
||||
// Service methods - delegate to Hub
|
||||
|
||||
func (s *GRPCServer) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
||||
return s.hub.Register(ctx, req)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) Stream(stream pb.AgentService_StreamServer) error {
|
||||
return s.hub.Stream(stream)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) {
|
||||
return s.hub.Heartbeat(ctx, req)
|
||||
}
|
||||
513
backend/internal/server/hub.go
Normal file
513
backend/internal/server/hub.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyto/internal/models"
|
||||
pb "tyto/internal/proto"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// AgentConnection represents a connected agent.
|
||||
type AgentConnection struct {
|
||||
ID string
|
||||
Stream pb.AgentService_StreamServer
|
||||
Info *pb.AgentInfo
|
||||
LastMetrics *models.AllMetrics
|
||||
LastSeen time.Time
|
||||
Connected bool
|
||||
SendCh chan *pb.ServerMessage
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// MetricsSubscriber receives aggregated metrics from all agents.
|
||||
type MetricsSubscriber interface {
|
||||
OnAgentMetrics(agentID string, metrics *models.AllMetrics)
|
||||
OnAgentConnected(agentID string, info *pb.AgentInfo)
|
||||
OnAgentDisconnected(agentID string)
|
||||
}
|
||||
|
||||
// Hub manages agent connections and message routing.
|
||||
type Hub struct {
|
||||
registry *Registry
|
||||
config *HubConfig
|
||||
agents map[string]*AgentConnection
|
||||
mu sync.RWMutex
|
||||
subscribers []MetricsSubscriber
|
||||
|
||||
// Channels for internal coordination
|
||||
registerCh chan *AgentConnection
|
||||
unregisterCh chan string
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// HubConfig contains Hub configuration.
|
||||
type HubConfig struct {
|
||||
RequireApproval bool
|
||||
AutoApprove bool
|
||||
}
|
||||
|
||||
// NewHub creates a new Hub instance.
|
||||
func NewHub(registry *Registry, config *HubConfig) *Hub {
|
||||
if config == nil {
|
||||
config = &HubConfig{
|
||||
RequireApproval: true,
|
||||
AutoApprove: false,
|
||||
}
|
||||
}
|
||||
|
||||
return &Hub{
|
||||
registry: registry,
|
||||
config: config,
|
||||
agents: make(map[string]*AgentConnection),
|
||||
registerCh: make(chan *AgentConnection, 16),
|
||||
unregisterCh: make(chan string, 16),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a metrics subscriber.
|
||||
func (h *Hub) Subscribe(sub MetricsSubscriber) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.subscribers = append(h.subscribers, sub)
|
||||
}
|
||||
|
||||
// Start begins the hub's event loop.
|
||||
func (h *Hub) Start() {
|
||||
h.wg.Add(1)
|
||||
go h.run()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the hub.
|
||||
func (h *Hub) Stop() {
|
||||
close(h.stopCh)
|
||||
h.wg.Wait()
|
||||
}
|
||||
|
||||
func (h *Hub) run() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.stopCh:
|
||||
h.disconnectAll()
|
||||
return
|
||||
|
||||
case conn := <-h.registerCh:
|
||||
h.handleRegister(conn)
|
||||
|
||||
case agentID := <-h.unregisterCh:
|
||||
h.handleUnregister(agentID)
|
||||
|
||||
case <-ticker.C:
|
||||
h.checkStaleConnections()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) handleRegister(conn *AgentConnection) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Close existing connection if any
|
||||
if existing, ok := h.agents[conn.ID]; ok {
|
||||
existing.Connected = false
|
||||
if existing.cancel != nil {
|
||||
existing.cancel()
|
||||
}
|
||||
close(existing.SendCh)
|
||||
}
|
||||
|
||||
h.agents[conn.ID] = conn
|
||||
log.Printf("Agent registered: %s", conn.ID)
|
||||
|
||||
// Notify subscribers
|
||||
for _, sub := range h.subscribers {
|
||||
sub.OnAgentConnected(conn.ID, conn.Info)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) handleUnregister(agentID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
conn, ok := h.agents[agentID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conn.Connected = false
|
||||
delete(h.agents, agentID)
|
||||
log.Printf("Agent unregistered: %s", agentID)
|
||||
|
||||
// Update registry status
|
||||
h.registry.UpdateStatus(agentID, AgentStatusOffline)
|
||||
|
||||
// Notify subscribers
|
||||
for _, sub := range h.subscribers {
|
||||
sub.OnAgentDisconnected(agentID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) checkStaleConnections() {
|
||||
h.mu.RLock()
|
||||
staleIDs := make([]string, 0)
|
||||
for id, conn := range h.agents {
|
||||
if time.Since(conn.LastSeen) > 60*time.Second {
|
||||
staleIDs = append(staleIDs, id)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
for _, id := range staleIDs {
|
||||
log.Printf("Removing stale agent: %s", id)
|
||||
h.unregisterCh <- id
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) disconnectAll() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, conn := range h.agents {
|
||||
conn.Connected = false
|
||||
if conn.cancel != nil {
|
||||
conn.cancel()
|
||||
}
|
||||
close(conn.SendCh)
|
||||
}
|
||||
h.agents = make(map[string]*AgentConnection)
|
||||
}
|
||||
|
||||
// Register handles agent registration requests.
|
||||
func (h *Hub) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
||||
agentID := req.AgentId
|
||||
info := req.Info
|
||||
|
||||
// Check if already registered
|
||||
existing, exists := h.registry.Get(agentID)
|
||||
if exists {
|
||||
switch existing.Status {
|
||||
case AgentStatusRevoked:
|
||||
return &pb.RegisterResponse{
|
||||
Status: pb.RegisterStatus_REGISTER_STATUS_REJECTED,
|
||||
Message: "agent certificate has been revoked",
|
||||
}, nil
|
||||
|
||||
case AgentStatusApproved, AgentStatusConnected:
|
||||
// Update info and return success
|
||||
record := &AgentRecord{
|
||||
ID: agentID,
|
||||
Hostname: info.Hostname,
|
||||
OS: info.Os,
|
||||
Architecture: info.Architecture,
|
||||
Version: info.Version,
|
||||
Capabilities: info.Capabilities,
|
||||
}
|
||||
h.registry.Register(record)
|
||||
|
||||
return &pb.RegisterResponse{
|
||||
Status: pb.RegisterStatus_REGISTER_STATUS_ALREADY_REGISTERED,
|
||||
Message: "already registered",
|
||||
Config: h.getAgentConfig(),
|
||||
}, nil
|
||||
|
||||
case AgentStatusPending:
|
||||
return &pb.RegisterResponse{
|
||||
Status: pb.RegisterStatus_REGISTER_STATUS_PENDING_APPROVAL,
|
||||
Message: "awaiting approval",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// New registration
|
||||
record := &AgentRecord{
|
||||
ID: agentID,
|
||||
Hostname: info.Hostname,
|
||||
OS: info.Os,
|
||||
Architecture: info.Architecture,
|
||||
Version: info.Version,
|
||||
Capabilities: info.Capabilities,
|
||||
Status: AgentStatusPending,
|
||||
}
|
||||
|
||||
// Auto-approve if configured
|
||||
if h.config.AutoApprove || !h.config.RequireApproval {
|
||||
record.Status = AgentStatusApproved
|
||||
}
|
||||
|
||||
if err := h.registry.Register(record); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "registration failed: %v", err)
|
||||
}
|
||||
|
||||
if record.Status == AgentStatusApproved {
|
||||
return &pb.RegisterResponse{
|
||||
Status: pb.RegisterStatus_REGISTER_STATUS_ACCEPTED,
|
||||
Message: "registration accepted",
|
||||
Config: h.getAgentConfig(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &pb.RegisterResponse{
|
||||
Status: pb.RegisterStatus_REGISTER_STATUS_PENDING_APPROVAL,
|
||||
Message: "awaiting approval",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stream handles the bidirectional streaming RPC.
|
||||
func (h *Hub) Stream(stream pb.AgentService_StreamServer) error {
|
||||
// Extract agent ID from context (set by auth interceptor)
|
||||
agentID, ok := AgentIDFromContext(stream.Context())
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "agent ID not found in context")
|
||||
}
|
||||
|
||||
// Verify agent is approved
|
||||
if !h.registry.IsApproved(agentID) {
|
||||
return status.Error(codes.PermissionDenied, "agent not approved")
|
||||
}
|
||||
|
||||
// Create connection
|
||||
ctx, cancel := context.WithCancel(stream.Context())
|
||||
conn := &AgentConnection{
|
||||
ID: agentID,
|
||||
Stream: stream,
|
||||
Connected: true,
|
||||
LastSeen: time.Now(),
|
||||
SendCh: make(chan *pb.ServerMessage, 16),
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Register connection
|
||||
h.registerCh <- conn
|
||||
|
||||
// Update registry status
|
||||
h.registry.UpdateStatus(agentID, AgentStatusConnected)
|
||||
|
||||
// Start sender goroutine
|
||||
h.wg.Add(1)
|
||||
go h.sendLoop(conn)
|
||||
|
||||
// Receive loop
|
||||
err := h.receiveLoop(ctx, conn)
|
||||
|
||||
// Cleanup
|
||||
h.unregisterCh <- agentID
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Hub) sendLoop(conn *AgentConnection) {
|
||||
defer h.wg.Done()
|
||||
|
||||
for msg := range conn.SendCh {
|
||||
if err := conn.Stream.Send(msg); err != nil {
|
||||
log.Printf("Send error for agent %s: %v", conn.ID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) receiveLoop(ctx context.Context, conn *AgentConnection) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
msg, err := conn.Stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.LastSeen = time.Now()
|
||||
h.registry.UpdateLastSeen(conn.ID)
|
||||
|
||||
switch payload := msg.Payload.(type) {
|
||||
case *pb.AgentMessage_Metrics:
|
||||
h.handleMetrics(conn, payload.Metrics)
|
||||
|
||||
case *pb.AgentMessage_Heartbeat:
|
||||
h.handleHeartbeat(conn, payload.Heartbeat)
|
||||
|
||||
case *pb.AgentMessage_Info:
|
||||
conn.Info = payload.Info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) handleMetrics(conn *AgentConnection, report *pb.MetricsReport) {
|
||||
// Deserialize metrics
|
||||
var metrics models.AllMetrics
|
||||
if err := json.Unmarshal(report.MetricsJson, &metrics); err != nil {
|
||||
log.Printf("Failed to unmarshal metrics from %s: %v", conn.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
metrics.Timestamp = time.UnixMilli(report.TimestampMs)
|
||||
conn.LastMetrics = &metrics
|
||||
|
||||
// Notify subscribers
|
||||
h.mu.RLock()
|
||||
for _, sub := range h.subscribers {
|
||||
sub.OnAgentMetrics(conn.ID, &metrics)
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Send acknowledgment
|
||||
conn.SendCh <- &pb.ServerMessage{
|
||||
Payload: &pb.ServerMessage_Ack{
|
||||
Ack: &pb.Ack{Success: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) handleHeartbeat(conn *AgentConnection, hb *pb.HeartbeatRequest) {
|
||||
// Just update last seen (already done in receive loop)
|
||||
log.Printf("Heartbeat from %s (uptime: %ds)", conn.ID, hb.UptimeSeconds)
|
||||
}
|
||||
|
||||
// Heartbeat handles simple heartbeat RPCs.
|
||||
func (h *Hub) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) {
|
||||
h.registry.UpdateLastSeen(req.AgentId)
|
||||
|
||||
return &pb.HeartbeatResponse{
|
||||
ServerTimeMs: time.Now().UnixMilli(),
|
||||
ConfigChanged: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendCommand sends a command to a specific agent.
|
||||
func (h *Hub) SendCommand(agentID string, cmd *pb.Command) error {
|
||||
h.mu.RLock()
|
||||
conn, ok := h.agents[agentID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !ok || !conn.Connected {
|
||||
return status.Error(codes.NotFound, "agent not connected")
|
||||
}
|
||||
|
||||
select {
|
||||
case conn.SendCh <- &pb.ServerMessage{
|
||||
Payload: &pb.ServerMessage_Command{Command: cmd},
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
return status.Error(codes.ResourceExhausted, "agent send buffer full")
|
||||
}
|
||||
}
|
||||
|
||||
// SendConfigUpdate sends a config update to a specific agent.
|
||||
func (h *Hub) SendConfigUpdate(agentID string, config *pb.ConfigUpdate) error {
|
||||
h.mu.RLock()
|
||||
conn, ok := h.agents[agentID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !ok || !conn.Connected {
|
||||
return status.Error(codes.NotFound, "agent not connected")
|
||||
}
|
||||
|
||||
select {
|
||||
case conn.SendCh <- &pb.ServerMessage{
|
||||
Payload: &pb.ServerMessage_Config{Config: config},
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
return status.Error(codes.ResourceExhausted, "agent send buffer full")
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastCommand sends a command to all connected agents.
|
||||
func (h *Hub) BroadcastCommand(cmd *pb.Command) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
msg := &pb.ServerMessage{
|
||||
Payload: &pb.ServerMessage_Command{Command: cmd},
|
||||
}
|
||||
|
||||
for _, conn := range h.agents {
|
||||
if conn.Connected {
|
||||
select {
|
||||
case conn.SendCh <- msg:
|
||||
default:
|
||||
log.Printf("Send buffer full for agent %s", conn.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectedAgents returns a list of currently connected agent IDs.
|
||||
func (h *Hub) GetConnectedAgents() []string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
ids := make([]string, 0, len(h.agents))
|
||||
for id, conn := range h.agents {
|
||||
if conn.Connected {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetAgentMetrics returns the last metrics for an agent.
|
||||
func (h *Hub) GetAgentMetrics(agentID string) (*models.AllMetrics, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, ok := h.agents[agentID]
|
||||
if !ok || conn.LastMetrics == nil {
|
||||
return nil, false
|
||||
}
|
||||
return conn.LastMetrics, true
|
||||
}
|
||||
|
||||
// GetAllMetrics returns the last metrics for all connected agents.
|
||||
func (h *Hub) GetAllMetrics() map[string]*models.AllMetrics {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*models.AllMetrics, len(h.agents))
|
||||
for id, conn := range h.agents {
|
||||
if conn.LastMetrics != nil {
|
||||
result[id] = conn.LastMetrics
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *Hub) getAgentConfig() *pb.AgentConfig {
|
||||
return &pb.AgentConfig{
|
||||
CollectionIntervalSeconds: 5,
|
||||
EnabledCollectors: []string{"cpu", "memory", "disk", "network", "process", "temperature", "gpu"},
|
||||
}
|
||||
}
|
||||
|
||||
// Context key for agent ID
|
||||
type contextKey string
|
||||
|
||||
const agentIDKey contextKey = "agentID"
|
||||
|
||||
// AgentIDFromContext extracts the agent ID from context.
|
||||
func AgentIDFromContext(ctx context.Context) (string, bool) {
|
||||
id, ok := ctx.Value(agentIDKey).(string)
|
||||
return id, ok
|
||||
}
|
||||
|
||||
// ContextWithAgentID returns a context with the agent ID set.
|
||||
func ContextWithAgentID(ctx context.Context, agentID string) context.Context {
|
||||
return context.WithValue(ctx, agentIDKey, agentID)
|
||||
}
|
||||
284
backend/internal/server/registry.go
Normal file
284
backend/internal/server/registry.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Package server implements the central Tyto server for multi-device monitoring.
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AgentStatus represents the current state of an agent.
|
||||
type AgentStatus string
|
||||
|
||||
const (
|
||||
AgentStatusPending AgentStatus = "pending"
|
||||
AgentStatusApproved AgentStatus = "approved"
|
||||
AgentStatusConnected AgentStatus = "connected"
|
||||
AgentStatusOffline AgentStatus = "offline"
|
||||
AgentStatusRevoked AgentStatus = "revoked"
|
||||
)
|
||||
|
||||
// AgentRecord stores metadata about a registered agent.
|
||||
type AgentRecord struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Hostname string `json:"hostname"`
|
||||
OS string `json:"os"`
|
||||
Architecture string `json:"architecture"`
|
||||
Version string `json:"version"`
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
Status AgentStatus `json:"status"`
|
||||
CertSerial string `json:"certSerial,omitempty"`
|
||||
CertExpiry time.Time `json:"certExpiry,omitempty"`
|
||||
LastSeen time.Time `json:"lastSeen,omitempty"`
|
||||
RegisteredAt time.Time `json:"registeredAt"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// Registry manages agent registrations.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
agents map[string]*AgentRecord
|
||||
filePath string
|
||||
}
|
||||
|
||||
// NewRegistry creates a new agent registry.
|
||||
func NewRegistry(filePath string) *Registry {
|
||||
r := &Registry{
|
||||
agents: make(map[string]*AgentRecord),
|
||||
filePath: filePath,
|
||||
}
|
||||
|
||||
// Load existing registrations
|
||||
if filePath != "" {
|
||||
r.load()
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Register adds or updates an agent registration.
|
||||
func (r *Registry) Register(agent *AgentRecord) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
existing, exists := r.agents[agent.ID]
|
||||
if exists {
|
||||
// Update existing record
|
||||
existing.Hostname = agent.Hostname
|
||||
existing.OS = agent.OS
|
||||
existing.Architecture = agent.Architecture
|
||||
existing.Version = agent.Version
|
||||
existing.Capabilities = agent.Capabilities
|
||||
existing.LastSeen = time.Now()
|
||||
|
||||
// Don't change status if already approved/connected
|
||||
if existing.Status == AgentStatusRevoked {
|
||||
return fmt.Errorf("agent %s is revoked", agent.ID)
|
||||
}
|
||||
} else {
|
||||
// New registration
|
||||
agent.RegisteredAt = time.Now()
|
||||
agent.LastSeen = time.Now()
|
||||
if agent.Status == "" {
|
||||
agent.Status = AgentStatusPending
|
||||
}
|
||||
r.agents[agent.ID] = agent
|
||||
}
|
||||
|
||||
return r.save()
|
||||
}
|
||||
|
||||
// Get returns an agent record by ID.
|
||||
func (r *Registry) Get(id string) (*AgentRecord, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
agent, exists := r.agents[id]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Return a copy
|
||||
copy := *agent
|
||||
return ©, true
|
||||
}
|
||||
|
||||
// List returns all registered agents.
|
||||
func (r *Registry) List() []*AgentRecord {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make([]*AgentRecord, 0, len(r.agents))
|
||||
for _, agent := range r.agents {
|
||||
copy := *agent
|
||||
result = append(result, ©)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ListPending returns agents awaiting approval.
|
||||
func (r *Registry) ListPending() []*AgentRecord {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []*AgentRecord
|
||||
for _, agent := range r.agents {
|
||||
if agent.Status == AgentStatusPending {
|
||||
copy := *agent
|
||||
result = append(result, ©)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Approve marks an agent as approved.
|
||||
func (r *Registry) Approve(id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
agent, exists := r.agents[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("agent %s not found", id)
|
||||
}
|
||||
|
||||
if agent.Status == AgentStatusRevoked {
|
||||
return fmt.Errorf("agent %s is revoked", id)
|
||||
}
|
||||
|
||||
agent.Status = AgentStatusApproved
|
||||
return r.save()
|
||||
}
|
||||
|
||||
// Reject removes a pending agent registration.
|
||||
func (r *Registry) Reject(id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
agent, exists := r.agents[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("agent %s not found", id)
|
||||
}
|
||||
|
||||
if agent.Status != AgentStatusPending {
|
||||
return fmt.Errorf("agent %s is not pending", id)
|
||||
}
|
||||
|
||||
delete(r.agents, id)
|
||||
return r.save()
|
||||
}
|
||||
|
||||
// Revoke marks an agent as revoked.
|
||||
func (r *Registry) Revoke(id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
agent, exists := r.agents[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("agent %s not found", id)
|
||||
}
|
||||
|
||||
agent.Status = AgentStatusRevoked
|
||||
return r.save()
|
||||
}
|
||||
|
||||
// Remove deletes an agent registration.
|
||||
func (r *Registry) Remove(id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.agents[id]; !exists {
|
||||
return fmt.Errorf("agent %s not found", id)
|
||||
}
|
||||
|
||||
delete(r.agents, id)
|
||||
return r.save()
|
||||
}
|
||||
|
||||
// IsApproved checks if an agent is approved to connect.
|
||||
func (r *Registry) IsApproved(id string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
agent, exists := r.agents[id]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return agent.Status == AgentStatusApproved || agent.Status == AgentStatusConnected
|
||||
}
|
||||
|
||||
// UpdateStatus updates an agent's connection status.
|
||||
func (r *Registry) UpdateStatus(id string, status AgentStatus) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
agent, exists := r.agents[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("agent %s not found", id)
|
||||
}
|
||||
|
||||
agent.Status = status
|
||||
agent.LastSeen = time.Now()
|
||||
return r.save()
|
||||
}
|
||||
|
||||
// UpdateLastSeen updates the last seen timestamp.
|
||||
func (r *Registry) UpdateLastSeen(id string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if agent, exists := r.agents[id]; exists {
|
||||
agent.LastSeen = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// load reads the registry from disk.
|
||||
func (r *Registry) load() error {
|
||||
data, err := os.ReadFile(r.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var agents []*AgentRecord
|
||||
if err := json.Unmarshal(data, &agents); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, agent := range agents {
|
||||
r.agents[agent.ID] = agent
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// save writes the registry to disk.
|
||||
func (r *Registry) save() error {
|
||||
if r.filePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
agents := make([]*AgentRecord, 0, len(r.agents))
|
||||
for _, agent := range r.agents {
|
||||
agents = append(agents, agent)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(agents, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(r.filePath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(r.filePath, data, 0644)
|
||||
}
|
||||
Reference in New Issue
Block a user