package server

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"log"
	"net"
	"os"
	"strings"

	"tyto/internal/config"
	"tyto/internal/pki"
	pb "tyto/internal/proto"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

// GRPCServer exposes the AgentService over gRPC and delegates every RPC to a
// Hub. Its interceptors resolve the calling agent's ID and attach it to the
// request context before the handler runs.
type GRPCServer struct {
	pb.UnimplementedAgentServiceServer

	hub      *Hub
	server   *grpc.Server
	listener net.Listener
	config   *config.ServerConfig
}

// NewGRPCServer builds a GRPCServer for hub using cfg, configures transport
// credentials and interceptors, and registers it as the AgentService
// implementation. It does not start listening; call Start for that.
//
// It returns an error if the server options cannot be built (for example,
// unreadable or invalid TLS material, or an incomplete TLS configuration).
func NewGRPCServer(hub *Hub, cfg *config.ServerConfig) (*GRPCServer, error) {
	s := &GRPCServer{
		hub:    hub,
		config: cfg,
	}

	opts, err := s.serverOptions()
	if err != nil {
		return nil, fmt.Errorf("failed to create server options: %w", err)
	}

	s.server = grpc.NewServer(opts...)
	pb.RegisterAgentServiceServer(s.server, s)

	return s, nil
}

// serverOptions assembles the grpc.ServerOption list: transport credentials
// (mutual TLS when configured, plaintext otherwise) plus the unary and stream
// interceptors.
//
// Insecure mode is used only when none of CACert, ServerCert and ServerKey is
// set. If some but not all of them are set, an error is returned instead of
// silently falling back to plaintext, so a missing path cannot quietly
// downgrade the server's security.
func (s *GRPCServer) serverOptions() ([]grpc.ServerOption, error) {
	var opts []grpc.ServerOption

	tlsCfg := s.config.TLS
	switch {
	case tlsCfg.CACert == "" && tlsCfg.ServerCert == "" && tlsCfg.ServerKey == "":
		log.Println("Warning: gRPC server running without TLS (insecure mode)")
	case tlsCfg.CACert == "" || tlsCfg.ServerCert == "" || tlsCfg.ServerKey == "":
		return nil, fmt.Errorf("incomplete TLS configuration: CA cert, server cert and server key must all be set (or none for insecure mode)")
	default:
		tlsConfig, err := s.loadTLSConfig()
		if err != nil {
			return nil, err
		}
		opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
	}

	// Interceptors resolve the agent identity and log each call.
	opts = append(opts,
		grpc.UnaryInterceptor(s.unaryInterceptor),
		grpc.StreamInterceptor(s.streamInterceptor),
	)

	return opts, nil
}

// loadTLSConfig reads the CA bundle and the server key pair named in the TLS
// configuration. It returns a tls.Config that requires and verifies a client
// certificate signed by that CA, with TLS 1.2 as the minimum version.
func (s *GRPCServer) loadTLSConfig() (*tls.Config, error) {
	tlsCfg := s.config.TLS

	caCert, err := os.ReadFile(tlsCfg.CACert)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA cert: %w", err)
	}

	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("failed to parse CA cert")
	}

	cert, err := tls.LoadX509KeyPair(tlsCfg.ServerCert, tlsCfg.ServerKey)
	if err != nil {
		return nil, fmt.Errorf("failed to load server cert: %w", err)
	}

	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientCAs:    caCertPool,
		ClientAuth:   tls.RequireAndVerifyClientCert,
		MinVersion:   tls.VersionTLS12,
	}, nil
}

// Start listens on TCP port `port` on all interfaces and serves gRPC on that
// listener. It blocks until the server stops, returning the listen error or
// the error from grpc.Server.Serve.
func (s *GRPCServer) Start(port int) error {
	addr := fmt.Sprintf(":%d", port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}
	s.listener = listener

	log.Printf("gRPC server listening on %s", addr)
	return s.server.Serve(listener)
}

// Stop gracefully stops the server if it was constructed.
//
// NOTE(review): grpc-go's GracefulStop waits for in-flight RPCs to finish,
// which includes open Stream calls; if agents hold streams open indefinitely
// this may block until they disconnect — confirm the Hub closes streams first
// or consider a timeout falling back to Stop.
func (s *GRPCServer) Stop() {
	if s.server != nil {
		s.server.GracefulStop()
	}
}

// Interceptors

// unaryInterceptor resolves the caller's agent ID, stores it in the context
// via ContextWithAgentID, logs the call and invokes the handler.
//
// The ID normally comes from the client certificate. Only the Register RPC
// may proceed when that lookup fails; in that case the ID is taken from the
// RegisterRequest itself (and is empty if req is not a *pb.RegisterRequest).
// For every other method the lookup error is returned to the client.
func (s *GRPCServer) unaryInterceptor(
	ctx context.Context,
	req any,
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (any, error) {
	agentID, err := s.extractAgentID(ctx)
	if err != nil {
		// FullMethod has the form "/<package>.<Service>/<Method>". Match the
		// whole method segment: a substring check would also exempt any other
		// method whose name merely contains "Register" (e.g. "Deregister").
		if !strings.HasSuffix(info.FullMethod, "/Register") {
			return nil, err
		}
		// NOTE(review): this ID is client-supplied and unauthenticated;
		// confirm Hub.Register validates it before trusting it.
		if regReq, ok := req.(*pb.RegisterRequest); ok {
			agentID = regReq.AgentId
		}
	}

	ctx = ContextWithAgentID(ctx, agentID)

	log.Printf("gRPC %s from agent %s", info.FullMethod, agentID)

	return handler(ctx, req)
}

// streamInterceptor resolves the caller's agent ID and runs the handler with
// a stream whose Context carries that ID. Unlike the unary path there is no
// fallback: if the ID cannot be determined, the stream is rejected.
func (s *GRPCServer) streamInterceptor(
	srv any,
	ss grpc.ServerStream,
	info *grpc.StreamServerInfo,
	handler grpc.StreamHandler,
) error {
	agentID, err := s.extractAgentID(ss.Context())
	if err != nil {
		return err
	}

	wrapped := &wrappedServerStream{
		ServerStream: ss,
		ctx:          ContextWithAgentID(ss.Context(), agentID),
	}

	log.Printf("gRPC stream %s from agent %s", info.FullMethod, agentID)

	return handler(srv, wrapped)
}

// extractAgentID returns the agent ID for the peer in ctx.
//
// With TLS, the ID is read from the leaf client certificate via
// pki.ExtractAgentID. Without TLS auth info, it falls back to an ID already
// stored in ctx (AgentIDFromContext). All failures are reported as
// codes.Unauthenticated.
//
// NOTE(review): in insecure mode only this file's interceptors run, and they
// call this before setting any ID, so the fallback appears to find nothing and
// every non-Register call is rejected — confirm this is intended.
func (s *GRPCServer) extractAgentID(ctx context.Context) (string, error) {
	p, ok := peer.FromContext(ctx)
	if !ok {
		return "", status.Error(codes.Unauthenticated, "no peer info")
	}

	tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
	if !ok {
		// No TLS auth info (insecure mode): accept an ID already in context.
		if id, ok := AgentIDFromContext(ctx); ok {
			return id, nil
		}
		return "", status.Error(codes.Unauthenticated, "no TLS info")
	}

	if len(tlsInfo.State.PeerCertificates) == 0 {
		return "", status.Error(codes.Unauthenticated, "no client certificate")
	}

	// PeerCertificates[0] is the client's leaf certificate.
	cert := tlsInfo.State.PeerCertificates[0]
	agentID := pki.ExtractAgentID(cert)
	if agentID == "" {
		return "", status.Error(codes.Unauthenticated, "no agent ID in certificate")
	}

	return agentID, nil
}

// wrappedServerStream wraps a ServerStream so Context() returns a context
// enriched by the interceptor.
type wrappedServerStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the interceptor-supplied context instead of the original.
func (w *wrappedServerStream) Context() context.Context {
	return w.ctx
}

// Service methods - delegate to Hub

// Register forwards the registration RPC to the Hub.
func (s *GRPCServer) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
	return s.hub.Register(ctx, req)
}

// Stream forwards the streaming RPC to the Hub.
func (s *GRPCServer) Stream(stream pb.AgentService_StreamServer) error {
	return s.hub.Stream(stream)
}

// Heartbeat forwards the heartbeat RPC to the Hub.
func (s *GRPCServer) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) {
	return s.hub.Heartbeat(ctx, req)
}