Files
tyto/backend/internal/agent/client.go
vikingowl 5e781c0e04 feat: implement lightweight agent with gRPC and mTLS support
Agent Package (internal/agent/):
- Agent struct with all collectors and memory-efficient pooling
- Run loop with configurable collection interval
- Graceful shutdown with context cancellation
- Auto-reconnection callback for re-registration

gRPC Client (internal/agent/client.go):
- mTLS support with CA, agent cert, and key
- Bidirectional streaming for metrics
- Heartbeat fallback when streaming fails
- Exponential backoff with jitter for reconnection
- Concurrent reconnection handling with mutex

Protocol Buffers (proto/tyto.proto):
- AgentService with Stream, Register, Heartbeat RPCs
- MetricsReport with summary fields for aggregation
- ConfigUpdate and Command messages for server control
- RegisterStatus enum for registration workflow

CLI Integration (cmd/tyto/main.go):
- Full agent subcommand with flag parsing
- Support for --id, --server, --interval, --ca-cert, etc.
- Environment variable overrides (TYTO_AGENT_*)
- Signal handling for graceful shutdown

Build System (Makefile):
- Cross-compilation for linux/amd64, arm64, armv7
- Stripped binaries with version info
- Proto generation target
- Test and coverage targets

Config Updates:
- DefaultConfig() and LoadFromPath() functions
- Agent config properly parsed from YAML

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-28 07:42:44 +01:00

403 lines
9.3 KiB
Go

package agent
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"log"
	"math/rand"
	"os"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/keepalive"

	"tyto/internal/config"
	"tyto/internal/models"
	pb "tyto/internal/proto"
)
// Reconnection parameters controlling the exponential backoff used by
// ConnectWithRetry and nextBackoff.
const (
	initialBackoff = 1 * time.Second  // delay before the first retry
	maxBackoff     = 60 * time.Second // backoff growth is capped at this value
	backoffFactor  = 2.0              // multiplier applied after each failed attempt
	jitterFactor   = 0.2              // 20% jitter
)
// Client handles gRPC communication with the central server.
//
// NOTE(review): the connected flag is written by Connect/Close and read by
// IsConnected without synchronization — if callers run these concurrently
// this is a data race; consider atomic.Bool.
type Client struct {
	config *config.Config
	conn   *grpc.ClientConn
	client pb.AgentServiceClient

	// stream is the active bidirectional metrics stream; nil when none is
	// established. Guarded by streamMu.
	stream   pb.AgentService_StreamClient
	streamMu sync.Mutex

	// Reconnection state
	connected    bool
	reconnecting bool          // true while a Reconnect is in flight; guarded by reconnectMu
	reconnectMu  sync.Mutex    // guards reconnecting, backoff, and reconnectCh
	backoff      time.Duration // current retry delay; grows up to maxBackoff
	reconnectCh  chan struct{} // closed to wake goroutines waiting on an ongoing reconnect
	onReconnect  func() // Callback when reconnected
}
// NewClient constructs a Client for the given configuration. The returned
// client starts disconnected; call Connect or ConnectWithRetry to dial.
func NewClient(cfg *config.Config) (*Client, error) {
	c := &Client{
		config:      cfg,
		backoff:     initialBackoff,
		reconnectCh: make(chan struct{}),
	}
	return c, nil
}
// SetOnReconnect registers fn to be invoked each time the client
// successfully re-establishes its connection to the server.
func (c *Client) SetOnReconnect(fn func()) {
	c.onReconnect = fn
}
// Connect establishes a connection to the server.
//
// On success it resets the retry backoff to its initial value so a later
// disconnect starts the backoff schedule from scratch.
//
// NOTE(review): grpc.DialContext is non-blocking by default, so "connected"
// here means the ClientConn was created, not that the transport is up —
// confirm whether a blocking dial or readiness check is intended.
func (c *Client) Connect(ctx context.Context) error {
	opts, err := c.dialOptions()
	if err != nil {
		return fmt.Errorf("failed to create dial options: %w", err)
	}
	log.Printf("Connecting to server: %s", c.config.Agent.ServerURL)
	conn, err := grpc.DialContext(ctx, c.config.Agent.ServerURL, opts...)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}
	c.conn = conn
	c.client = pb.NewAgentServiceClient(conn)
	c.connected = true
	c.backoff = initialBackoff // Reset backoff on successful connection
	log.Println("Connected to server")
	return nil
}
// ConnectWithRetry attempts to connect with exponential backoff. It keeps
// dialing until a connection succeeds or ctx is cancelled, in which case
// the context's error is returned.
func (c *Client) ConnectWithRetry(ctx context.Context) error {
	for {
		if err := c.Connect(ctx); err != nil {
			wait := c.nextBackoff()
			log.Printf("Connection failed: %v. Retrying in %s...", err, wait)
			select {
			case <-time.After(wait):
				continue
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		return nil
	}
}
// nextBackoff returns the delay to wait before the next connection attempt
// and advances the backoff state for the attempt after that.
func (c *Client) nextBackoff() time.Duration {
	c.reconnectMu.Lock()
	defer c.reconnectMu.Unlock()

	current := c.backoff
	// Add jitter: +/- 20% of current backoff
	spread := float64(current) * jitterFactor * (2*randFloat() - 1)
	delay := current + time.Duration(spread)

	// Grow the backoff for the following attempt, capped at maxBackoff.
	next := time.Duration(float64(current) * backoffFactor)
	if next > maxBackoff {
		next = maxBackoff
	}
	c.backoff = next

	return delay
}
// randFloat returns a pseudo-random float64 in [0, 1).
//
// Used to add jitter to the reconnection backoff so that many agents
// restarted at the same moment do not retry in lockstep. The previous
// implementation derived the value from time.Now().UnixNano()%1000, which
// is coarse, deterministic, and correlated across agents started together —
// defeating the purpose of jitter. math/rand is auto-seeded (Go 1.20+).
func randFloat() float64 {
	return rand.Float64()
}
// Reconnect attempts to reconnect after a disconnection.
//
// Only one reconnection runs at a time: the first caller performs the
// close-and-redial while concurrent callers block on reconnectCh until it
// finishes (or their ctx is cancelled). Waiters return nil once the channel
// closes, even if the reconnection attempt itself failed.
//
// NOTE(review): waiters read c.reconnectCh after releasing reconnectMu, so
// the field read races with its replacement in the deferred block — confirm
// whether a local copy taken under the lock is needed.
func (c *Client) Reconnect(ctx context.Context) error {
	c.reconnectMu.Lock()
	if c.reconnecting {
		c.reconnectMu.Unlock()
		// Wait for ongoing reconnection
		select {
		case <-c.reconnectCh:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	c.reconnecting = true
	c.reconnectMu.Unlock()
	defer func() {
		c.reconnectMu.Lock()
		c.reconnecting = false
		// Signal waiting goroutines, then replace the channel so a future
		// reconnection gets a fresh signal.
		close(c.reconnectCh)
		c.reconnectCh = make(chan struct{})
		c.reconnectMu.Unlock()
	}()
	// Close existing connection
	c.Close()
	// Reconnect with retry
	if err := c.ConnectWithRetry(ctx); err != nil {
		return err
	}
	// Notify callback
	if c.onReconnect != nil {
		c.onReconnect()
	}
	return nil
}
// dialOptions builds the grpc.DialOption set used by Connect: client-side
// keepalive plus either mTLS (when both a CA cert and an agent cert are
// configured) or plaintext for development.
//
// NOTE(review): grpc.WithInsecure is deprecated in current grpc-go in
// favor of credentials/insecure.NewCredentials() — consider migrating when
// bumping the dependency.
func (c *Client) dialOptions() ([]grpc.DialOption, error) {
	opts := []grpc.DialOption{
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                30 * time.Second, // ping every 30s when idle
			Timeout:             10 * time.Second, // give up on an unanswered ping after 10s
			PermitWithoutStream: true,
		}),
	}
	// Configure TLS if certificates are provided
	agentCfg := c.config.Agent
	if agentCfg.TLS.CACert != "" && agentCfg.TLS.AgentCert != "" {
		tlsConfig, err := c.loadTLSConfig()
		if err != nil {
			return nil, err
		}
		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	} else {
		// Insecure mode for development
		log.Println("Warning: Running without TLS (insecure mode)")
		opts = append(opts, grpc.WithInsecure())
	}
	return opts, nil
}
// loadTLSConfig reads the configured CA certificate and agent key pair from
// disk and assembles the tls.Config used for mTLS connections.
func (c *Client) loadTLSConfig() (*tls.Config, error) {
	tlsCfg := c.config.Agent.TLS

	// Build the root pool from the CA certificate.
	caPEM, err := os.ReadFile(tlsCfg.CACert)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA cert: %w", err)
	}
	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(caPEM); !ok {
		return nil, fmt.Errorf("failed to parse CA cert")
	}

	// Load the agent's client certificate and private key.
	keyPair, err := tls.LoadX509KeyPair(tlsCfg.AgentCert, tlsCfg.AgentKey)
	if err != nil {
		return nil, fmt.Errorf("failed to load agent cert: %w", err)
	}

	return &tls.Config{
		Certificates: []tls.Certificate{keyPair},
		RootCAs:      roots,
		MinVersion:   tls.VersionTLS12,
	}, nil
}
// Register sends a registration request to the server.
//
// Accepted, pending-approval, and already-registered responses are treated
// as success; a rejected response (or transport error) is returned as an
// error.
func (c *Client) Register(ctx context.Context, info *pb.AgentInfo) error {
	req := &pb.RegisterRequest{
		AgentId: c.config.Agent.ID,
		Info:    info,
	}
	resp, err := c.client.Register(ctx, req)
	if err != nil {
		return fmt.Errorf("registration failed: %w", err)
	}
	switch resp.Status {
	case pb.RegisterStatus_REGISTER_STATUS_ACCEPTED:
		log.Println("Registration accepted")
	case pb.RegisterStatus_REGISTER_STATUS_PENDING_APPROVAL:
		log.Println("Registration pending approval")
	case pb.RegisterStatus_REGISTER_STATUS_ALREADY_REGISTERED:
		log.Println("Already registered")
	case pb.RegisterStatus_REGISTER_STATUS_REJECTED:
		return fmt.Errorf("registration rejected: %s", resp.Message)
	default:
		// Tolerate enum values added by a newer server, but leave a trace
		// instead of silently treating them as success.
		log.Printf("Unknown registration status: %v", resp.Status)
	}
	return nil
}
// SendMetrics sends collected metrics to the server.
//
// The full payload is serialized as JSON, while a few summary fields (CPU,
// memory, GPU, worst-mount disk usage) are copied into the report for cheap
// server-side aggregation. Delivery prefers the bidirectional stream and
// falls back to a plain heartbeat RPC when streaming fails.
func (c *Client) SendMetrics(ctx context.Context, agentID string, metrics *models.AllMetrics) error {
	// Serialize metrics to JSON
	metricsJSON, err := SerializeMetrics(metrics)
	if err != nil {
		return fmt.Errorf("failed to serialize metrics: %w", err)
	}

	// Guard the percentage against a zero total: the unguarded division
	// would produce NaN in the report when total memory is unavailable.
	var memPercent float64
	if metrics.Memory.Total > 0 {
		memPercent = float64(metrics.Memory.Used) / float64(metrics.Memory.Total) * 100
	}

	// Create metrics report
	report := &pb.MetricsReport{
		AgentId:     agentID,
		TimestampMs: metrics.Timestamp.UnixMilli(),
		MetricsJson: metricsJSON,
		// Summary metrics for quick aggregation
		CpuUsage:       metrics.CPU.TotalUsage,
		MemoryPercent:  memPercent,
		GpuUtilization: int32(metrics.GPU.Utilization),
	}

	// Calculate disk usage percent (max across mounts)
	var maxDiskPercent float64
	for _, mount := range metrics.Disk.Mounts {
		if mount.UsedPercent > maxDiskPercent {
			maxDiskPercent = mount.UsedPercent
		}
	}
	report.DiskPercent = maxDiskPercent

	// Try to use stream first, fall back to heartbeat
	if err := c.sendViaStream(ctx, report); err != nil {
		log.Printf("Stream send failed, using heartbeat: %v", err)
		return c.sendViaHeartbeat(ctx, agentID)
	}
	return nil
}
// sendViaStream sends a metrics report over the bidirectional stream,
// lazily establishing the stream on first use and starting the goroutine
// that consumes server messages. On a send error the stream is dropped so
// the next call re-establishes it. All stream access is guarded by streamMu.
//
// NOTE(review): the stream is created with the caller's ctx, so the
// long-lived stream inherits a possibly short-lived per-send context —
// confirm this is intended.
func (c *Client) sendViaStream(ctx context.Context, report *pb.MetricsReport) error {
	c.streamMu.Lock()
	defer c.streamMu.Unlock()
	// Establish stream if not exists
	if c.stream == nil {
		stream, err := c.client.Stream(ctx)
		if err != nil {
			return fmt.Errorf("failed to create stream: %w", err)
		}
		c.stream = stream
		// Start goroutine to handle server messages
		go c.handleServerMessages()
	}
	// Send metrics via stream
	msg := &pb.AgentMessage{
		Payload: &pb.AgentMessage_Metrics{
			Metrics: report,
		},
	}
	if err := c.stream.Send(msg); err != nil {
		c.stream = nil // Reset stream on error
		return fmt.Errorf("failed to send metrics: %w", err)
	}
	return nil
}
// sendViaHeartbeat reports liveness via the unary Heartbeat RPC; used as a
// fallback when the metrics stream is unavailable.
func (c *Client) sendViaHeartbeat(ctx context.Context, agentID string) error {
	hb := &pb.HeartbeatRequest{
		AgentId:     agentID,
		TimestampMs: time.Now().UnixMilli(),
	}
	if _, err := c.client.Heartbeat(ctx, hb); err != nil {
		return err
	}
	return nil
}
// handleServerMessages is the receive loop for the bidirectional stream,
// started as a goroutine by sendViaStream. It exits when the stream has
// been cleared (nil) or when Recv fails, in which case it clears the stream
// so the next send re-establishes it.
func (c *Client) handleServerMessages() {
	for {
		// Snapshot the stream under the lock; it may be reset concurrently
		// by sendViaStream or Close.
		c.streamMu.Lock()
		stream := c.stream
		c.streamMu.Unlock()
		if stream == nil {
			return
		}
		msg, err := stream.Recv()
		if err != nil {
			log.Printf("Stream receive error: %v", err)
			c.streamMu.Lock()
			c.stream = nil
			c.streamMu.Unlock()
			return
		}
		c.handleMessage(msg)
	}
}
// handleMessage dispatches a single server message by payload type:
// acks (logging failures), config updates (currently log-only), and
// commands (delegated to handleCommand). Unrecognized payloads are ignored.
func (c *Client) handleMessage(msg *pb.ServerMessage) {
	switch payload := msg.Payload.(type) {
	case *pb.ServerMessage_Ack:
		// Message acknowledged
		if !payload.Ack.Success {
			log.Printf("Server error: %s", payload.Ack.Error)
		}
	case *pb.ServerMessage_Config:
		log.Printf("Received config update: interval=%ds",
			payload.Config.CollectionIntervalSeconds)
		// TODO: Apply config update
	case *pb.ServerMessage_Command:
		c.handleCommand(payload.Command)
	}
}
// handleCommand reacts to a server-issued command: disconnect closes the
// connection, collect-now is currently only logged (TODO), and anything
// else is logged as unknown.
func (c *Client) handleCommand(cmd *pb.Command) {
	switch cmd.Type {
	case pb.CommandType_COMMAND_TYPE_DISCONNECT:
		log.Println("Received disconnect command")
		c.Close()
	case pb.CommandType_COMMAND_TYPE_COLLECT_NOW:
		log.Println("Received collect-now command")
		// TODO: Trigger immediate collection
	default:
		log.Printf("Unknown command type: %v", cmd.Type)
	}
}
// Close closes the connection to the server.
//
// It half-closes any active metrics stream (which makes the receive loop
// exit on its next Recv) and then closes the underlying gRPC connection,
// returning that connection's close error if any.
//
// NOTE(review): the CloseSend error is dropped, and the connected flag is
// written without synchronization — see the race note on IsConnected.
func (c *Client) Close() error {
	c.connected = false
	c.streamMu.Lock()
	if c.stream != nil {
		c.stream.CloseSend()
		c.stream = nil
	}
	c.streamMu.Unlock()
	if c.conn != nil {
		return c.conn.Close()
	}
	return nil
}
// IsConnected returns true if connected to the server.
//
// NOTE(review): connected is read here without synchronization while
// Connect/Close write it from other goroutines — a data race under the race
// detector; consider atomic.Bool or guarding with reconnectMu.
func (c *Client) IsConnected() bool {
	return c.connected
}