package agent import ( "context" "crypto/tls" "crypto/x509" "fmt" "log" "os" "sync" "time" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "tyto/internal/config" "tyto/internal/models" pb "tyto/internal/proto" ) // Reconnection parameters const ( initialBackoff = 1 * time.Second maxBackoff = 60 * time.Second backoffFactor = 2.0 jitterFactor = 0.2 // 20% jitter ) // Client handles gRPC communication with the central server. type Client struct { config *config.Config conn *grpc.ClientConn client pb.AgentServiceClient stream pb.AgentService_StreamClient streamMu sync.Mutex // Reconnection state connected bool reconnecting bool reconnectMu sync.Mutex backoff time.Duration reconnectCh chan struct{} onReconnect func() // Callback when reconnected } // NewClient creates a new gRPC client. func NewClient(cfg *config.Config) (*Client, error) { return &Client{ config: cfg, backoff: initialBackoff, reconnectCh: make(chan struct{}), }, nil } // SetOnReconnect sets a callback to be called when the client reconnects. func (c *Client) SetOnReconnect(fn func()) { c.onReconnect = fn } // Connect establishes a connection to the server. func (c *Client) Connect(ctx context.Context) error { opts, err := c.dialOptions() if err != nil { return fmt.Errorf("failed to create dial options: %w", err) } log.Printf("Connecting to server: %s", c.config.Agent.ServerURL) conn, err := grpc.DialContext(ctx, c.config.Agent.ServerURL, opts...) if err != nil { return fmt.Errorf("failed to connect: %w", err) } c.conn = conn c.client = pb.NewAgentServiceClient(conn) c.connected = true c.backoff = initialBackoff // Reset backoff on successful connection log.Println("Connected to server") return nil } // ConnectWithRetry attempts to connect with exponential backoff. func (c *Client) ConnectWithRetry(ctx context.Context) error { for { err := c.Connect(ctx) if err == nil { return nil } delay := c.nextBackoff() log.Printf("Connection failed: %v. Retrying in %s...", err, delay) select { case <-ctx.Done(): return ctx.Err() case <-time.After(delay): // Continue to next attempt } } } // nextBackoff calculates the next backoff duration with jitter. func (c *Client) nextBackoff() time.Duration { c.reconnectMu.Lock() defer c.reconnectMu.Unlock() // Add jitter: +/- 20% of current backoff jitter := time.Duration(float64(c.backoff) * jitterFactor * (2*randFloat() - 1)) delay := c.backoff + jitter // Increase backoff for next time c.backoff = time.Duration(float64(c.backoff) * backoffFactor) if c.backoff > maxBackoff { c.backoff = maxBackoff } return delay } // randFloat returns a random float64 in [0, 1). func randFloat() float64 { return float64(time.Now().UnixNano()%1000) / 1000.0 } // Reconnect attempts to reconnect after a disconnection. func (c *Client) Reconnect(ctx context.Context) error { c.reconnectMu.Lock() if c.reconnecting { c.reconnectMu.Unlock() // Wait for ongoing reconnection select { case <-c.reconnectCh: return nil case <-ctx.Done(): return ctx.Err() } } c.reconnecting = true c.reconnectMu.Unlock() defer func() { c.reconnectMu.Lock() c.reconnecting = false // Signal waiting goroutines close(c.reconnectCh) c.reconnectCh = make(chan struct{}) c.reconnectMu.Unlock() }() // Close existing connection c.Close() // Reconnect with retry if err := c.ConnectWithRetry(ctx); err != nil { return err } // Notify callback if c.onReconnect != nil { c.onReconnect() } return nil } func (c *Client) dialOptions() ([]grpc.DialOption, error) { opts := []grpc.DialOption{ grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, PermitWithoutStream: true, }), } // Configure TLS if certificates are provided agentCfg := c.config.Agent if agentCfg.TLS.CACert != "" && agentCfg.TLS.AgentCert != "" { tlsConfig, err := c.loadTLSConfig() if err != nil { return nil, err } opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) } else { // Insecure mode for development log.Println("Warning: Running without TLS (insecure mode)") opts = append(opts, grpc.WithInsecure()) } return opts, nil } func (c *Client) loadTLSConfig() (*tls.Config, error) { agentCfg := c.config.Agent // Load CA certificate caCert, err := os.ReadFile(agentCfg.TLS.CACert) if err != nil { return nil, fmt.Errorf("failed to read CA cert: %w", err) } caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { return nil, fmt.Errorf("failed to parse CA cert") } // Load agent certificate cert, err := tls.LoadX509KeyPair(agentCfg.TLS.AgentCert, agentCfg.TLS.AgentKey) if err != nil { return nil, fmt.Errorf("failed to load agent cert: %w", err) } return &tls.Config{ Certificates: []tls.Certificate{cert}, RootCAs: caCertPool, MinVersion: tls.VersionTLS12, }, nil } // Register sends a registration request to the server. func (c *Client) Register(ctx context.Context, info *pb.AgentInfo) error { req := &pb.RegisterRequest{ AgentId: c.config.Agent.ID, Info: info, } resp, err := c.client.Register(ctx, req) if err != nil { return fmt.Errorf("registration failed: %w", err) } switch resp.Status { case pb.RegisterStatus_REGISTER_STATUS_ACCEPTED: log.Println("Registration accepted") case pb.RegisterStatus_REGISTER_STATUS_PENDING_APPROVAL: log.Println("Registration pending approval") case pb.RegisterStatus_REGISTER_STATUS_ALREADY_REGISTERED: log.Println("Already registered") case pb.RegisterStatus_REGISTER_STATUS_REJECTED: return fmt.Errorf("registration rejected: %s", resp.Message) } return nil } // SendMetrics sends collected metrics to the server. func (c *Client) SendMetrics(ctx context.Context, agentID string, metrics *models.AllMetrics) error { // Serialize metrics to JSON metricsJSON, err := SerializeMetrics(metrics) if err != nil { return fmt.Errorf("failed to serialize metrics: %w", err) } // Create metrics report report := &pb.MetricsReport{ AgentId: agentID, TimestampMs: metrics.Timestamp.UnixMilli(), MetricsJson: metricsJSON, // Summary metrics for quick aggregation CpuUsage: metrics.CPU.TotalUsage, MemoryPercent: float64(metrics.Memory.Used) / float64(metrics.Memory.Total) * 100, GpuUtilization: int32(metrics.GPU.Utilization), } // Calculate disk usage percent (max across mounts) var maxDiskPercent float64 for _, mount := range metrics.Disk.Mounts { if mount.UsedPercent > maxDiskPercent { maxDiskPercent = mount.UsedPercent } } report.DiskPercent = maxDiskPercent // Try to use stream first, fall back to heartbeat if err := c.sendViaStream(ctx, report); err != nil { log.Printf("Stream send failed, using heartbeat: %v", err) return c.sendViaHeartbeat(ctx, agentID) } return nil } func (c *Client) sendViaStream(ctx context.Context, report *pb.MetricsReport) error { c.streamMu.Lock() defer c.streamMu.Unlock() // Establish stream if not exists if c.stream == nil { stream, err := c.client.Stream(ctx) if err != nil { return fmt.Errorf("failed to create stream: %w", err) } c.stream = stream // Start goroutine to handle server messages go c.handleServerMessages() } // Send metrics via stream msg := &pb.AgentMessage{ Payload: &pb.AgentMessage_Metrics{ Metrics: report, }, } if err := c.stream.Send(msg); err != nil { c.stream = nil // Reset stream on error return fmt.Errorf("failed to send metrics: %w", err) } return nil } func (c *Client) sendViaHeartbeat(ctx context.Context, agentID string) error { req := &pb.HeartbeatRequest{ AgentId: agentID, TimestampMs: time.Now().UnixMilli(), } _, err := c.client.Heartbeat(ctx, req) return err } func (c *Client) handleServerMessages() { for { c.streamMu.Lock() stream := c.stream c.streamMu.Unlock() if stream == nil { return } msg, err := stream.Recv() if err != nil { log.Printf("Stream receive error: %v", err) c.streamMu.Lock() c.stream = nil c.streamMu.Unlock() return } c.handleMessage(msg) } } func (c *Client) handleMessage(msg *pb.ServerMessage) { switch payload := msg.Payload.(type) { case *pb.ServerMessage_Ack: // Message acknowledged if !payload.Ack.Success { log.Printf("Server error: %s", payload.Ack.Error) } case *pb.ServerMessage_Config: log.Printf("Received config update: interval=%ds", payload.Config.CollectionIntervalSeconds) // TODO: Apply config update case *pb.ServerMessage_Command: c.handleCommand(payload.Command) } } func (c *Client) handleCommand(cmd *pb.Command) { switch cmd.Type { case pb.CommandType_COMMAND_TYPE_COLLECT_NOW: log.Println("Received collect-now command") // TODO: Trigger immediate collection case pb.CommandType_COMMAND_TYPE_DISCONNECT: log.Println("Received disconnect command") c.Close() default: log.Printf("Unknown command type: %v", cmd.Type) } } // Close closes the connection to the server. func (c *Client) Close() error { c.connected = false c.streamMu.Lock() if c.stream != nil { c.stream.CloseSend() c.stream = nil } c.streamMu.Unlock() if c.conn != nil { return c.conn.Close() } return nil } // IsConnected returns true if connected to the server. func (c *Client) IsConnected() bool { return c.connected }