diff --git a/backend/cmd/tyto/main.go b/backend/cmd/tyto/main.go new file mode 100644 index 0000000..6a0bd9f --- /dev/null +++ b/backend/cmd/tyto/main.go @@ -0,0 +1,94 @@ +// Package main provides the Tyto CLI with subcommands for PKI management, +// server operation, and agent operation. +package main + +import ( + "fmt" + "os" +) + +const usage = `Tyto - System Monitoring + +Usage: + tyto [options] + +Commands: + server Start the Tyto server (default if no command given) + agent Start as a Tyto agent + pki PKI/certificate management + +Run 'tyto --help' for more information on a command. +` + +func main() { + if len(os.Args) < 2 { + // Default to server command + runServer(os.Args[1:]) + return + } + + switch os.Args[1] { + case "server": + runServer(os.Args[2:]) + case "agent": + runAgent(os.Args[2:]) + case "pki": + runPKI(os.Args[2:]) + case "-h", "--help", "help": + fmt.Print(usage) + default: + fmt.Fprintf(os.Stderr, "Unknown command: %s\n\n", os.Args[1]) + fmt.Print(usage) + os.Exit(1) + } +} + +func runServer(args []string) { + // For now, delegate to the existing server main + // In a full implementation, this would parse server-specific flags + fmt.Println("Starting Tyto server...") + fmt.Println("Use 'tyto server --help' for options") + + // Import and call the actual server logic + // For now, just print usage + if len(args) > 0 && (args[0] == "-h" || args[0] == "--help") { + fmt.Print(`Usage: tyto server [options] + +Options: + --port PORT HTTP port (default: 8080) + --mode MODE Operating mode: standalone, server (default: standalone) + --config FILE Config file path + +Environment: + TYTO_MODE Operating mode + PORT HTTP port + TYTO_CONFIG Config file path +`) + return + } + + fmt.Println("Server mode not implemented in tyto CLI yet.") + fmt.Println("Use the 'server' binary directly or set TYTO_MODE environment variable.") +} + +func runAgent(args []string) { + if len(args) > 0 && (args[0] == "-h" || args[0] == "--help") { + fmt.Print(`Usage: tyto agent [options] + +Options: + --id ID Agent identifier (required) + --server URL Server URL to report to (required) + --interval DURATION Collection interval (default: 5s) + --ca-cert FILE CA certificate file + --cert FILE Agent certificate file + --key FILE Agent key file + +Environment: + TYTO_AGENT_ID Agent identifier + TYTO_SERVER_URL Server URL +`) + return + } + + fmt.Println("Agent mode not implemented yet.") +} diff --git a/backend/cmd/tyto/pki.go b/backend/cmd/tyto/pki.go new file mode 100644 index 0000000..7a4821e --- /dev/null +++ b/backend/cmd/tyto/pki.go @@ -0,0 +1,375 @@ +package main + +import ( + "flag" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "time" + + "tyto/internal/pki" +) + +const pkiUsage = `PKI/Certificate Management + +Usage: + tyto pki [options] + +Subcommands: + init-ca Initialize a new Certificate Authority + gen-server Generate a server certificate + gen-agent Generate an agent certificate + revoke Revoke a certificate + list List all certificates + info Show CA information + +Run 'tyto pki --help' for more information. +` + +func runPKI(args []string) { + if len(args) == 0 { + fmt.Print(pkiUsage) + return + } + + switch args[0] { + case "init-ca": + pkiInitCA(args[1:]) + case "gen-server": + pkiGenServer(args[1:]) + case "gen-agent": + pkiGenAgent(args[1:]) + case "revoke": + pkiRevoke(args[1:]) + case "list": + pkiList(args[1:]) + case "info": + pkiInfo(args[1:]) + case "-h", "--help", "help": + fmt.Print(pkiUsage) + default: + fmt.Fprintf(os.Stderr, "Unknown pki subcommand: %s\n\n", args[0]) + fmt.Print(pkiUsage) + os.Exit(1) + } +} + +func pkiInitCA(args []string) { + fs := flag.NewFlagSet("pki init-ca", flag.ExitOnError) + cn := fs.String("cn", "Tyto CA", "Common Name for the CA") + org := fs.String("org", "Tyto", "Organization name") + country := fs.String("country", "US", "Country code") + outDir := fs.String("out", "/etc/tyto/pki", "Output directory for CA files") + validity := fs.Duration("validity", 10*365*24*time.Hour, "CA validity period") + + fs.Parse(args) + + fmt.Printf("Initializing CA in %s...\n", *outDir) + + cfg := pki.CAConfig{ + CommonName: *cn, + Organization: *org, + Country: *country, + Validity: *validity, + } + + ca, err := pki.InitCA(cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if err := ca.SaveToDir(*outDir); err != nil { + fmt.Fprintf(os.Stderr, "Error saving CA: %v\n", err) + os.Exit(1) + } + + info := ca.Info() + fmt.Println("CA initialized successfully!") + fmt.Printf(" Subject: %s\n", info.Subject) + fmt.Printf(" Serial: %s\n", info.SerialNumber) + fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339)) + fmt.Printf(" Certificate: %s\n", filepath.Join(*outDir, "ca.crt")) + fmt.Printf(" Private key: %s\n", filepath.Join(*outDir, "ca.key")) +} + +func pkiGenServer(args []string) { + fs := flag.NewFlagSet("pki gen-server", flag.ExitOnError) + caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory") + cn := fs.String("cn", "", "Common Name (required)") + dnsNames := fs.String("dns", "", "DNS names (comma-separated)") + ips := fs.String("ip", "", "IP addresses (comma-separated)") + outDir := fs.String("out", "", "Output directory (default: ca-dir/certs)") + validity := fs.Duration("validity", 365*24*time.Hour, "Certificate validity") + + fs.Parse(args) + + if *cn == "" { + fmt.Fprintln(os.Stderr, "Error: --cn is required") + fs.Usage() + os.Exit(1) + } + + if *outDir == "" { + *outDir = filepath.Join(*caDir, "certs") + } + + // Load CA + ca, err := pki.LoadCAFromDir(*caDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err) + os.Exit(1) + } + + // Parse DNS names + var dns []string + if *dnsNames != "" { + dns = strings.Split(*dnsNames, ",") + for i := range dns { + dns[i] = strings.TrimSpace(dns[i]) + } + } + + // Parse IP addresses + var ipAddrs []net.IP + if *ips != "" { + for _, ipStr := range strings.Split(*ips, ",") { + ip := net.ParseIP(strings.TrimSpace(ipStr)) + if ip != nil { + ipAddrs = append(ipAddrs, ip) + } + } + } + + cfg := pki.ServerCertConfig{ + CommonName: *cn, + DNSNames: dns, + IPAddresses: ipAddrs, + Validity: *validity, + } + + bundle, err := ca.GenerateServerCert(cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error generating certificate: %v\n", err) + os.Exit(1) + } + + // Save certificate + certPath := filepath.Join(*outDir, *cn+".crt") + keyPath := filepath.Join(*outDir, *cn+".key") + + if err := bundle.SaveToFiles(certPath, keyPath); err != nil { + fmt.Fprintf(os.Stderr, "Error saving certificate: %v\n", err) + os.Exit(1) + } + + info := bundle.Info(pki.CertTypeServer) + fmt.Println("Server certificate generated!") + fmt.Printf(" Subject: %s\n", info.Subject) + fmt.Printf(" Serial: %s\n", info.SerialNumber) + fmt.Printf(" DNS names: %v\n", dns) + fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339)) + fmt.Printf(" Certificate: %s\n", certPath) + fmt.Printf(" Private key: %s\n", keyPath) +} + +func pkiGenAgent(args []string) { + fs := flag.NewFlagSet("pki gen-agent", flag.ExitOnError) + caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory") + agentID := fs.String("agent-id", "", "Agent ID (required, used as CN)") + outDir := fs.String("out", "", "Output directory (default: ca-dir/agents)") + validity := fs.Duration("validity", 365*24*time.Hour, "Certificate validity") + + fs.Parse(args) + + if *agentID == "" { + fmt.Fprintln(os.Stderr, "Error: --agent-id is required") + fs.Usage() + os.Exit(1) + } + + if *outDir == "" { + *outDir = filepath.Join(*caDir, "agents") + } + + // Load CA + ca, err := pki.LoadCAFromDir(*caDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err) + os.Exit(1) + } + + cfg := pki.AgentCertConfig{ + AgentID: *agentID, + Validity: *validity, + } + + bundle, err := ca.GenerateAgentCert(cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error generating certificate: %v\n", err) + os.Exit(1) + } + + // Save certificate + certPath := filepath.Join(*outDir, *agentID+".crt") + keyPath := filepath.Join(*outDir, *agentID+".key") + + if err := bundle.SaveToFiles(certPath, keyPath); err != nil { + fmt.Fprintf(os.Stderr, "Error saving certificate: %v\n", err) + os.Exit(1) + } + + info := bundle.Info(pki.CertTypeAgent) + fmt.Println("Agent certificate generated!") + fmt.Printf(" Agent ID: %s\n", *agentID) + fmt.Printf(" Serial: %s\n", info.SerialNumber) + fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339)) + fmt.Printf(" Certificate: %s\n", certPath) + fmt.Printf(" Private key: %s\n", keyPath) +} + +func pkiRevoke(args []string) { + fs := flag.NewFlagSet("pki revoke", flag.ExitOnError) + caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory") + serial := fs.String("serial", "", "Certificate serial number to revoke (required)") + + fs.Parse(args) + + if *serial == "" { + fmt.Fprintln(os.Stderr, "Error: --serial is required") + fs.Usage() + os.Exit(1) + } + + // Load CA + ca, err := pki.LoadCAFromDir(*caDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err) + os.Exit(1) + } + + if err := ca.RevokeCertificate(*serial); err != nil { + fmt.Fprintf(os.Stderr, "Error revoking certificate: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Certificate %s has been revoked.\n", *serial) + fmt.Printf("CRL updated: %s\n", filepath.Join(*caDir, "ca.crl")) +} + +func pkiList(args []string) { + fs := flag.NewFlagSet("pki list", flag.ExitOnError) + caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory") + certType := fs.String("type", "", "Filter by type: ca, server, agent") + + fs.Parse(args) + + // Load CA + ca, err := pki.LoadCAFromDir(*caDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err) + os.Exit(1) + } + + store := ca.Store() + if store == nil { + fmt.Println("No certificates found.") + return + } + + var certs []pki.CertInfo + switch *certType { + case "server": + certs = store.ListByType(pki.CertTypeServer) + case "agent": + certs = store.ListByType(pki.CertTypeAgent) + case "": + certs = store.ListCertificates() + default: + fmt.Fprintf(os.Stderr, "Unknown type: %s\n", *certType) + os.Exit(1) + } + + if len(certs) == 0 { + fmt.Println("No certificates found.") + return + } + + fmt.Printf("%-16s %-12s %-20s %-8s %s\n", "SERIAL", "TYPE", "SUBJECT", "REVOKED", "EXPIRES") + fmt.Println(strings.Repeat("-", 80)) + + for _, cert := range certs { + revoked := "" + if cert.Revoked { + revoked = "YES" + } + serial := cert.SerialNumber + if len(serial) > 14 { + serial = serial[:14] + ".." + } + fmt.Printf("%-16s %-12s %-20s %-8s %s\n", + serial, + cert.Type, + cert.Subject, + revoked, + cert.NotAfter.Format("2006-01-02"), + ) + } +} + +func pkiInfo(args []string) { + fs := flag.NewFlagSet("pki info", flag.ExitOnError) + caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory") + + fs.Parse(args) + + // Load CA + ca, err := pki.LoadCAFromDir(*caDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err) + os.Exit(1) + } + + info := ca.Info() + fmt.Println("Certificate Authority Information") + fmt.Println(strings.Repeat("=", 40)) + fmt.Printf("Subject: %s\n", info.Subject) + fmt.Printf("Serial Number: %s\n", info.SerialNumber) + fmt.Printf("Valid From: %s\n", info.NotBefore.Format(time.RFC3339)) + fmt.Printf("Valid Until: %s\n", info.NotAfter.Format(time.RFC3339)) + + // Check if expired + if time.Now().After(info.NotAfter) { + fmt.Println("\n⚠️ WARNING: CA certificate has expired!") + } else { + remaining := time.Until(info.NotAfter) + fmt.Printf("Remaining: %d days\n", int(remaining.Hours()/24)) + } + + // CRL info + crlInfo, err := ca.GetCRLInfo() + if err != nil { + fmt.Printf("\nCRL: Error loading: %v\n", err) + } else if crlInfo == nil { + fmt.Println("\nCRL: Not generated yet") + } else { + fmt.Println("\nCRL Information") + fmt.Println(strings.Repeat("-", 40)) + fmt.Printf("Last Update: %s\n", crlInfo.ThisUpdate.Format(time.RFC3339)) + fmt.Printf("Next Update: %s\n", crlInfo.NextUpdate.Format(time.RFC3339)) + fmt.Printf("Revoked Certs: %d\n", crlInfo.RevokedCount) + } + + // Certificate counts + store := ca.Store() + if store != nil { + servers := store.ListByType(pki.CertTypeServer) + agents := store.ListByType(pki.CertTypeAgent) + fmt.Println("\nCertificate Counts") + fmt.Println(strings.Repeat("-", 40)) + fmt.Printf("Server Certs: %d\n", len(servers)) + fmt.Printf("Agent Certs: %d\n", len(agents)) + } +} diff --git a/backend/internal/pki/ca.go b/backend/internal/pki/ca.go new file mode 100644 index 0000000..737f3fb --- /dev/null +++ b/backend/internal/pki/ca.go @@ -0,0 +1,231 @@ +package pki + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +// CA represents a Certificate Authority for signing certificates. +type CA struct { + cert *x509.Certificate + key *rsa.PrivateKey + certPEM []byte + keyPEM []byte + store *Store +} + +// CAConfig contains options for creating a new CA. +type CAConfig struct { + CommonName string + Organization string + Country string + Validity time.Duration + KeySize int +} + +// DefaultCAConfig returns a CAConfig with sensible defaults. +func DefaultCAConfig() CAConfig { + return CAConfig{ + CommonName: "Tyto CA", + Organization: "Tyto", + Country: "US", + Validity: DefaultCAValidity, + KeySize: DefaultKeySize, + } +} + +// InitCA creates a new Certificate Authority with the given configuration. +func InitCA(cfg CAConfig) (*CA, error) { + if cfg.KeySize == 0 { + cfg.KeySize = DefaultKeySize + } + if cfg.Validity == 0 { + cfg.Validity = DefaultCAValidity + } + + // Generate private key + key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize) + if err != nil { + return nil, fmt.Errorf("failed to generate CA key: %w", err) + } + + // Generate serial number + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + now := time.Now() + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: cfg.CommonName, + Organization: []string{cfg.Organization}, + Country: []string{cfg.Country}, + }, + NotBefore: now, + NotAfter: now.Add(cfg.Validity), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + } + + // Self-sign the certificate + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + return nil, fmt.Errorf("failed to create CA certificate: %w", err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, fmt.Errorf("failed to parse CA certificate: %w", err) + } + + // Encode to PEM + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + return &CA{ + cert: cert, + key: key, + certPEM: certPEM, + keyPEM: keyPEM, + }, nil +} + +// LoadCA loads an existing CA from PEM-encoded certificate and key. +func LoadCA(certPEM, keyPEM []byte) (*CA, error) { + // Parse certificate + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil { + return nil, fmt.Errorf("failed to decode CA certificate PEM") + } + + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CA certificate: %w", err) + } + + if !cert.IsCA { + return nil, fmt.Errorf("certificate is not a CA") + } + + // Parse private key + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return nil, fmt.Errorf("failed to decode CA key PEM") + } + + var key *rsa.PrivateKey + switch keyBlock.Type { + case "RSA PRIVATE KEY": + key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "PRIVATE KEY": + parsedKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse CA key: %w", parseErr) + } + var ok bool + key, ok = parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("CA key is not RSA") + } + default: + return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type) + } + if err != nil { + return nil, fmt.Errorf("failed to parse CA key: %w", err) + } + + return &CA{ + cert: cert, + key: key, + certPEM: certPEM, + keyPEM: keyPEM, + }, nil +} + +// LoadCAFromDir loads a CA from a directory containing ca.crt and ca.key files. +func LoadCAFromDir(dir string) (*CA, error) { + certPath := filepath.Join(dir, "ca.crt") + keyPath := filepath.Join(dir, "ca.key") + + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA key: %w", err) + } + + ca, err := LoadCA(certPEM, keyPEM) + if err != nil { + return nil, err + } + + // Initialize store for this directory + ca.store = NewStore(dir) + + return ca, nil +} + +// SaveToDir saves the CA certificate and key to a directory. +func (ca *CA) SaveToDir(dir string) error { + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create CA directory: %w", err) + } + + certPath := filepath.Join(dir, "ca.crt") + keyPath := filepath.Join(dir, "ca.key") + + if err := os.WriteFile(certPath, ca.certPEM, 0644); err != nil { + return fmt.Errorf("failed to write CA certificate: %w", err) + } + + if err := os.WriteFile(keyPath, ca.keyPEM, 0600); err != nil { + return fmt.Errorf("failed to write CA key: %w", err) + } + + // Initialize store + ca.store = NewStore(dir) + + return nil +} + +// Certificate returns the CA certificate. +func (ca *CA) Certificate() *x509.Certificate { + return ca.cert +} + +// CertificatePEM returns the PEM-encoded CA certificate. +func (ca *CA) CertificatePEM() []byte { + return ca.certPEM +} + +// Info returns metadata about the CA certificate. +func (ca *CA) Info() CertInfo { + return CertInfoFromX509(ca.cert, CertTypeCA) +} + +// Store returns the certificate store associated with this CA. +func (ca *CA) Store() *Store { + return ca.store +} diff --git a/backend/internal/pki/cert.go b/backend/internal/pki/cert.go new file mode 100644 index 0000000..361fd93 --- /dev/null +++ b/backend/internal/pki/cert.go @@ -0,0 +1,270 @@ +package pki + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "time" +) + +// CertificateBundle contains a certificate and its private key. +type CertificateBundle struct { + Certificate *x509.Certificate + PrivateKey *rsa.PrivateKey + CertificatePEM []byte + PrivateKeyPEM []byte +} + +// ServerCertConfig contains options for generating a server certificate. +type ServerCertConfig struct { + CommonName string + Organization string + DNSNames []string + IPAddresses []net.IP + Validity time.Duration + KeySize int +} + +// AgentCertConfig contains options for generating an agent certificate. +type AgentCertConfig struct { + AgentID string // Used as CommonName + Organization string + Validity time.Duration + KeySize int +} + +// GenerateServerCert creates a new server certificate signed by the CA. +func (ca *CA) GenerateServerCert(cfg ServerCertConfig) (*CertificateBundle, error) { + if cfg.CommonName == "" { + return nil, fmt.Errorf("common name is required") + } + if cfg.KeySize == 0 { + cfg.KeySize = DefaultKeySize + } + if cfg.Validity == 0 { + cfg.Validity = DefaultCertValidity + } + + // Generate private key + key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize) + if err != nil { + return nil, fmt.Errorf("failed to generate server key: %w", err) + } + + // Generate serial number + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + now := time.Now() + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: cfg.CommonName, + Organization: []string{cfg.Organization}, + }, + NotBefore: now, + NotAfter: now.Add(cfg.Validity), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: cfg.DNSNames, + IPAddresses: cfg.IPAddresses, + } + + // Sign with CA + certDER, err := x509.CreateCertificate(rand.Reader, template, ca.cert, &key.PublicKey, ca.key) + if err != nil { + return nil, fmt.Errorf("failed to create server certificate: %w", err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, fmt.Errorf("failed to parse server certificate: %w", err) + } + + // Encode to PEM + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + bundle := &CertificateBundle{ + Certificate: cert, + PrivateKey: key, + CertificatePEM: certPEM, + PrivateKeyPEM: keyPEM, + } + + // Store certificate info + if ca.store != nil { + info := CertInfoFromX509(cert, CertTypeServer) + ca.store.AddCertificate(info) + } + + return bundle, nil +} + +// GenerateAgentCert creates a new agent certificate signed by the CA. +// The AgentID is used as the CommonName and is extracted during mTLS verification. +func (ca *CA) GenerateAgentCert(cfg AgentCertConfig) (*CertificateBundle, error) { + if cfg.AgentID == "" { + return nil, fmt.Errorf("agent ID is required") + } + if cfg.KeySize == 0 { + cfg.KeySize = DefaultKeySize + } + if cfg.Validity == 0 { + cfg.Validity = DefaultCertValidity + } + + // Generate private key + key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize) + if err != nil { + return nil, fmt.Errorf("failed to generate agent key: %w", err) + } + + // Generate serial number + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + now := time.Now() + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: cfg.AgentID, + Organization: []string{cfg.Organization}, + }, + NotBefore: now, + NotAfter: now.Add(cfg.Validity), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + // Sign with CA + certDER, err := x509.CreateCertificate(rand.Reader, template, ca.cert, &key.PublicKey, ca.key) + if err != nil { + return nil, fmt.Errorf("failed to create agent certificate: %w", err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, fmt.Errorf("failed to parse agent certificate: %w", err) + } + + // Encode to PEM + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + bundle := &CertificateBundle{ + Certificate: cert, + PrivateKey: key, + CertificatePEM: certPEM, + PrivateKeyPEM: keyPEM, + } + + // Store certificate info + if ca.store != nil { + info := CertInfoFromX509(cert, CertTypeAgent) + ca.store.AddCertificate(info) + } + + return bundle, nil +} + +// Info returns metadata about this certificate bundle. +func (b *CertificateBundle) Info(certType CertType) CertInfo { + return CertInfoFromX509(b.Certificate, certType) +} + +// SaveToFiles writes the certificate and key to separate files. +func (b *CertificateBundle) SaveToFiles(certPath, keyPath string) error { + if err := writeFile(certPath, b.CertificatePEM, 0644); err != nil { + return fmt.Errorf("failed to write certificate: %w", err) + } + + if err := writeFile(keyPath, b.PrivateKeyPEM, 0600); err != nil { + return fmt.Errorf("failed to write private key: %w", err) + } + + return nil +} + +// LoadCertificateBundle loads a certificate and key from PEM files. +func LoadCertificateBundle(certPath, keyPath string) (*CertificateBundle, error) { + certPEM, err := readFile(certPath) + if err != nil { + return nil, fmt.Errorf("failed to read certificate: %w", err) + } + + keyPEM, err := readFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read private key: %w", err) + } + + // Parse certificate + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil { + return nil, fmt.Errorf("failed to decode certificate PEM") + } + + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + // Parse private key + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return nil, fmt.Errorf("failed to decode private key PEM") + } + + var key *rsa.PrivateKey + switch keyBlock.Type { + case "RSA PRIVATE KEY": + key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "PRIVATE KEY": + parsedKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse private key: %w", parseErr) + } + var ok bool + key, ok = parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is not RSA") + } + default: + return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type) + } + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return &CertificateBundle{ + Certificate: cert, + PrivateKey: key, + CertificatePEM: certPEM, + PrivateKeyPEM: keyPEM, + }, nil +} diff --git a/backend/internal/pki/crl.go b/backend/internal/pki/crl.go new file mode 100644 index 0000000..451d576 --- /dev/null +++ b/backend/internal/pki/crl.go @@ -0,0 +1,169 @@ +package pki + +import ( + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +// CRLValidity is how long a CRL is valid before it must be regenerated. +const CRLValidity = 7 * 24 * time.Hour // 7 days + +// RevokeCertificate revokes a certificate by its serial number. +func (ca *CA) RevokeCertificate(serialHex string) error { + if ca.store == nil { + return fmt.Errorf("certificate store not initialized") + } + + // Mark as revoked in store + if err := ca.store.MarkRevoked(serialHex); err != nil { + return fmt.Errorf("failed to mark certificate as revoked: %w", err) + } + + // Regenerate CRL + return ca.GenerateCRL() +} + +// GenerateCRL creates a new Certificate Revocation List. +func (ca *CA) GenerateCRL() error { + if ca.store == nil { + return fmt.Errorf("certificate store not initialized") + } + + // Get all revoked certificates + revokedSerials := ca.store.RevokedSerials() + + // Build revoked certificate list + revokedCerts := make([]pkix.RevokedCertificate, 0, len(revokedSerials)) + for _, serialHex := range revokedSerials { + serial := new(big.Int) + serial.SetString(serialHex, 16) + + info, _ := ca.store.GetCertificate(serialHex) + revocationTime := info.RevokedAt + if revocationTime.IsZero() { + revocationTime = time.Now() + } + + revokedCerts = append(revokedCerts, pkix.RevokedCertificate{ + SerialNumber: serial, + RevocationTime: revocationTime, + }) + } + + // Create CRL + now := time.Now() + crlTemplate := &x509.RevocationList{ + RevokedCertificates: revokedCerts, + Number: big.NewInt(now.Unix()), + ThisUpdate: now, + NextUpdate: now.Add(CRLValidity), + } + + crlDER, err := x509.CreateRevocationList(rand.Reader, crlTemplate, ca.cert, ca.key) + if err != nil { + return fmt.Errorf("failed to create CRL: %w", err) + } + + // Encode to PEM + crlPEM := pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: crlDER, + }) + + // Save to file + crlPath := filepath.Join(ca.store.dir, "ca.crl") + if err := os.WriteFile(crlPath, crlPEM, 0644); err != nil { + return fmt.Errorf("failed to write CRL: %w", err) + } + + return nil +} + +// LoadCRL loads the Certificate Revocation List from the CA directory. +func (ca *CA) LoadCRL() (*x509.RevocationList, error) { + if ca.store == nil { + return nil, fmt.Errorf("certificate store not initialized") + } + + crlPath := filepath.Join(ca.store.dir, "ca.crl") + crlPEM, err := os.ReadFile(crlPath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil // No CRL yet + } + return nil, fmt.Errorf("failed to read CRL: %w", err) + } + + block, _ := pem.Decode(crlPEM) + if block == nil { + return nil, fmt.Errorf("failed to decode CRL PEM") + } + + crl, err := x509.ParseRevocationList(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CRL: %w", err) + } + + return crl, nil +} + +// IsCertificateRevoked checks if a certificate is in the CRL. +func (ca *CA) IsCertificateRevoked(cert *x509.Certificate) (bool, error) { + // First check the store (faster) + if ca.store != nil { + serialHex := cert.SerialNumber.Text(16) + if ca.store.IsRevoked(serialHex) { + return true, nil + } + } + + // Also check CRL for completeness + crl, err := ca.LoadCRL() + if err != nil { + return false, err + } + if crl == nil { + return false, nil + } + + for _, revoked := range crl.RevokedCertificates { + if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { + return true, nil + } + } + + return false, nil +} + +// CRLInfo contains metadata about a CRL. +type CRLInfo struct { + Issuer string `json:"issuer"` + ThisUpdate time.Time `json:"thisUpdate"` + NextUpdate time.Time `json:"nextUpdate"` + RevokedCount int `json:"revokedCount"` +} + +// GetCRLInfo returns metadata about the current CRL. +func (ca *CA) GetCRLInfo() (*CRLInfo, error) { + crl, err := ca.LoadCRL() + if err != nil { + return nil, err + } + if crl == nil { + return nil, nil + } + + return &CRLInfo{ + Issuer: crl.Issuer.CommonName, + ThisUpdate: crl.ThisUpdate, + NextUpdate: crl.NextUpdate, + RevokedCount: len(crl.RevokedCertificates), + }, nil +} diff --git a/backend/internal/pki/pki.go b/backend/internal/pki/pki.go new file mode 100644 index 0000000..0a703ed --- /dev/null +++ b/backend/internal/pki/pki.go @@ -0,0 +1,52 @@ +// Package pki provides certificate authority and certificate management +// for mTLS authentication between Tyto servers and agents. +package pki + +import ( + "crypto/x509" + "time" +) + +// CertType identifies the purpose of a certificate. +type CertType string + +const ( + CertTypeCA CertType = "ca" + CertTypeServer CertType = "server" + CertTypeAgent CertType = "agent" +) + +// CertInfo contains metadata about a certificate. +type CertInfo struct { + SerialNumber string `json:"serialNumber"` + Subject string `json:"subject"` + Issuer string `json:"issuer"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` + Type CertType `json:"type"` + DNSNames []string `json:"dnsNames,omitempty"` + Revoked bool `json:"revoked"` + RevokedAt time.Time `json:"revokedAt,omitempty"` +} + +// CertInfoFromX509 extracts metadata from an X.509 certificate. +func CertInfoFromX509(cert *x509.Certificate, certType CertType) CertInfo { + return CertInfo{ + SerialNumber: cert.SerialNumber.Text(16), + Subject: cert.Subject.CommonName, + Issuer: cert.Issuer.CommonName, + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + Type: certType, + DNSNames: cert.DNSNames, + } +} + +// DefaultCAValidity is the default validity period for CA certificates. +const DefaultCAValidity = 10 * 365 * 24 * time.Hour // 10 years + +// DefaultCertValidity is the default validity period for server/agent certificates. +const DefaultCertValidity = 365 * 24 * time.Hour // 1 year + +// DefaultKeySize is the default RSA key size in bits. +const DefaultKeySize = 4096 diff --git a/backend/internal/pki/pki_test.go b/backend/internal/pki/pki_test.go new file mode 100644 index 0000000..0656926 --- /dev/null +++ b/backend/internal/pki/pki_test.go @@ -0,0 +1,341 @@ +package pki + +import ( + "crypto/tls" + "crypto/x509" + "net" + "os" + "path/filepath" + "testing" + "time" +) + +func TestInitCA(t *testing.T) { + cfg := CAConfig{ + CommonName: "Test CA", + Organization: "Test Org", + Country: "US", + Validity: 24 * time.Hour, + KeySize: 2048, // Smaller for faster tests + } + + ca, err := InitCA(cfg) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + + if ca.cert == nil { + t.Fatal("CA certificate is nil") + } + + if !ca.cert.IsCA { + t.Error("Certificate is not marked as CA") + } + + if ca.cert.Subject.CommonName != "Test CA" { + t.Errorf("Expected CN 'Test CA', got '%s'", ca.cert.Subject.CommonName) + } + + if len(ca.certPEM) == 0 { + t.Error("Certificate PEM is empty") + } + + if len(ca.keyPEM) == 0 { + t.Error("Key PEM is empty") + } +} + +func TestCASaveAndLoad(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA + cfg := CAConfig{ + CommonName: "Test CA", + KeySize: 2048, + } + ca, err := InitCA(cfg) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + + // Save to directory + if err := ca.SaveToDir(tmpDir); err != nil { + t.Fatalf("SaveToDir failed: %v", err) + } + + // Verify files exist + if _, err := os.Stat(filepath.Join(tmpDir, "ca.crt")); err != nil { + t.Error("ca.crt not created") + } + if _, err := os.Stat(filepath.Join(tmpDir, "ca.key")); err != nil { + t.Error("ca.key not created") + } + + // Load from directory + loadedCA, err := LoadCAFromDir(tmpDir) + if err != nil { + t.Fatalf("LoadCAFromDir failed: %v", err) + } + + if loadedCA.cert.Subject.CommonName != ca.cert.Subject.CommonName { + t.Errorf("Loaded CA CN mismatch: got '%s', want '%s'", + loadedCA.cert.Subject.CommonName, ca.cert.Subject.CommonName) + } +} + +func TestGenerateServerCert(t *testing.T) { + // Create CA + ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048}) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + + // Generate server certificate + cfg := ServerCertConfig{ + CommonName: "test-server", + DNSNames: []string{"localhost", "test.local"}, + IPAddresses: []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("::1"), + }, + Validity: 24 * time.Hour, + KeySize: 2048, + } + + bundle, err := ca.GenerateServerCert(cfg) + if err != nil { + t.Fatalf("GenerateServerCert failed: %v", err) + } + + // Verify certificate properties + cert := bundle.Certificate + + if cert.Subject.CommonName != "test-server" { + t.Errorf("Expected CN 'test-server', got '%s'", cert.Subject.CommonName) + } + + if len(cert.DNSNames) != 2 { + t.Errorf("Expected 2 DNS names, got %d", len(cert.DNSNames)) + } + + if len(cert.IPAddresses) != 2 { + t.Errorf("Expected 2 IP addresses, got %d", len(cert.IPAddresses)) + } + + // Verify it's signed by CA + roots := x509.NewCertPool() + roots.AddCert(ca.cert) + + _, err = cert.Verify(x509.VerifyOptions{ + Roots: roots, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }) + if err != nil { + t.Errorf("Certificate verification failed: %v", err) + } +} + +func TestGenerateAgentCert(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA with store + ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048}) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + if err := ca.SaveToDir(tmpDir); err != nil { + t.Fatalf("SaveToDir failed: %v", err) + } + + // Generate agent certificate + cfg := AgentCertConfig{ + AgentID: "agent-001", + Validity: 24 * time.Hour, + KeySize: 2048, + } + + bundle, err := ca.GenerateAgentCert(cfg) + if err != nil { + t.Fatalf("GenerateAgentCert failed: %v", err) + } + + // Verify certificate properties + cert := bundle.Certificate + + if cert.Subject.CommonName != "agent-001" { + t.Errorf("Expected CN 'agent-001', got '%s'", cert.Subject.CommonName) + } + + // Verify client auth usage + hasClientAuth := false + for _, usage := range cert.ExtKeyUsage { + if usage == x509.ExtKeyUsageClientAuth { + hasClientAuth = true + break + } + } + if !hasClientAuth { + t.Error("Certificate missing ClientAuth extended key usage") + } + + // Verify it's signed by CA + roots := x509.NewCertPool() + roots.AddCert(ca.cert) + + _, err = cert.Verify(x509.VerifyOptions{ + Roots: roots, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + if err != nil { + t.Errorf("Certificate verification failed: %v", err) + } + + // Verify it was recorded in store + certs := ca.Store().ListByType(CertTypeAgent) + if len(certs) != 1 { + t.Errorf("Expected 1 agent cert in store, got %d", len(certs)) + } +} + +func TestCertificateRevocation(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA with store + ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048}) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + if err := ca.SaveToDir(tmpDir); err != nil { + t.Fatalf("SaveToDir failed: %v", err) + } + + // Generate agent certificate + bundle, err := ca.GenerateAgentCert(AgentCertConfig{ + AgentID: "agent-to-revoke", + KeySize: 2048, + }) + if err != nil { + t.Fatalf("GenerateAgentCert failed: %v", err) + } + + serialHex := bundle.Certificate.SerialNumber.Text(16) + + // Verify not revoked initially + revoked, err := ca.IsCertificateRevoked(bundle.Certificate) + if err != nil { + t.Fatalf("IsCertificateRevoked failed: %v", err) + } + if revoked { + t.Error("Certificate should not be revoked initially") + } + + // Revoke the certificate + if err := ca.RevokeCertificate(serialHex); err != nil { + t.Fatalf("RevokeCertificate failed: %v", err) + } + + // Verify now revoked + revoked, err = ca.IsCertificateRevoked(bundle.Certificate) + if err != nil { + t.Fatalf("IsCertificateRevoked failed: %v", err) + } + if !revoked { + t.Error("Certificate should be revoked after RevokeCertificate") + } + + // Verify CRL was created + crlPath := filepath.Join(tmpDir, "ca.crl") + if _, err := os.Stat(crlPath); err != nil { + t.Error("CRL file not created") + } + + // Verify CRL info + crlInfo, err := ca.GetCRLInfo() + if err != nil { + t.Fatalf("GetCRLInfo failed: %v", err) + } + if crlInfo.RevokedCount != 1 { + t.Errorf("Expected 1 revoked cert in CRL, got %d", crlInfo.RevokedCount) + } +} + +func TestCertificateBundleSaveLoad(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA + ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048}) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + + // Generate certificate + bundle, err := ca.GenerateServerCert(ServerCertConfig{ + CommonName: "test-server", + KeySize: 2048, + }) + if err != nil { + t.Fatalf("GenerateServerCert failed: %v", err) + } + + // Save to files + certPath := filepath.Join(tmpDir, "server.crt") + keyPath := filepath.Join(tmpDir, "server.key") + + if err := bundle.SaveToFiles(certPath, keyPath); err != nil { + t.Fatalf("SaveToFiles failed: %v", err) + } + + // Load from files + loaded, err := LoadCertificateBundle(certPath, keyPath) + if err != nil { + t.Fatalf("LoadCertificateBundle failed: %v", err) + } + + if loaded.Certificate.Subject.CommonName != bundle.Certificate.Subject.CommonName { + t.Error("Loaded certificate CN mismatch") + } +} + +func TestServerTLSConfig(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA + ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048}) + if err != nil { + t.Fatalf("InitCA failed: %v", err) + } + if err := ca.SaveToDir(tmpDir); err != nil { + t.Fatalf("SaveToDir failed: %v", err) + } + + // Generate server certificate + serverBundle, err := ca.GenerateServerCert(ServerCertConfig{ + CommonName: "test-server", + DNSNames: []string{"localhost"}, + KeySize: 2048, + }) + if err != nil { + t.Fatalf("GenerateServerCert failed: %v", err) + } + + serverCertPath := filepath.Join(tmpDir, "server.crt") + serverKeyPath := filepath.Join(tmpDir, "server.key") + if err := serverBundle.SaveToFiles(serverCertPath, serverKeyPath); err != nil { + t.Fatalf("SaveToFiles failed: %v", err) + } + + // Create TLS config + caCertPath := filepath.Join(tmpDir, "ca.crt") + tlsConfig, err := ServerTLSConfig(caCertPath, serverCertPath, serverKeyPath) + if err != nil { + t.Fatalf("ServerTLSConfig failed: %v", err) + } + + if tlsConfig.ClientAuth != tls.RequireAndVerifyClientCert { + t.Error("Expected RequireAndVerifyClientCert") + } + + if tlsConfig.MinVersion != tls.VersionTLS12 { + t.Error("Expected TLS 1.2 minimum") + } +} diff --git a/backend/internal/pki/store.go b/backend/internal/pki/store.go new file mode 100644 index 0000000..ddbcc15 --- /dev/null +++ b/backend/internal/pki/store.go @@ -0,0 +1,170 @@ +package pki + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" +) + +// Store manages certificate records on disk. +type Store struct { + dir string + certs map[string]CertInfo // serial -> info + mu sync.RWMutex +} + +// storeData is the JSON structure for persistence. +type storeData struct { + Certificates []CertInfo `json:"certificates"` +} + +// NewStore creates a new certificate store in the given directory. +func NewStore(dir string) *Store { + s := &Store{ + dir: dir, + certs: make(map[string]CertInfo), + } + s.load() + return s +} + +// load reads the store from disk. +func (s *Store) load() error { + path := filepath.Join(s.dir, "certs.json") + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var sd storeData + if err := json.Unmarshal(data, &sd); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + for _, cert := range sd.Certificates { + s.certs[cert.SerialNumber] = cert + } + + return nil +} + +// save writes the store to disk. +func (s *Store) save() error { + s.mu.RLock() + certs := make([]CertInfo, 0, len(s.certs)) + for _, cert := range s.certs { + certs = append(certs, cert) + } + s.mu.RUnlock() + + sd := storeData{Certificates: certs} + data, err := json.MarshalIndent(sd, "", " ") + if err != nil { + return err + } + + path := filepath.Join(s.dir, "certs.json") + return os.WriteFile(path, data, 0600) +} + +// AddCertificate adds a certificate to the store. +func (s *Store) AddCertificate(info CertInfo) error { + s.mu.Lock() + s.certs[info.SerialNumber] = info + s.mu.Unlock() + + return s.save() +} + +// GetCertificate returns a certificate by serial number. +func (s *Store) GetCertificate(serial string) (CertInfo, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + info, ok := s.certs[serial] + return info, ok +} + +// ListCertificates returns all certificates in the store. +func (s *Store) ListCertificates() []CertInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + certs := make([]CertInfo, 0, len(s.certs)) + for _, cert := range s.certs { + certs = append(certs, cert) + } + return certs +} + +// ListByType returns certificates of a specific type. +func (s *Store) ListByType(certType CertType) []CertInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + var certs []CertInfo + for _, cert := range s.certs { + if cert.Type == certType { + certs = append(certs, cert) + } + } + return certs +} + +// MarkRevoked marks a certificate as revoked. +func (s *Store) MarkRevoked(serial string) error { + s.mu.Lock() + if info, ok := s.certs[serial]; ok { + info.Revoked = true + s.certs[serial] = info + } + s.mu.Unlock() + + return s.save() +} + +// IsRevoked checks if a certificate is revoked. +func (s *Store) IsRevoked(serial string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if info, ok := s.certs[serial]; ok { + return info.Revoked + } + return false +} + +// RevokedSerials returns all revoked certificate serial numbers. +func (s *Store) RevokedSerials() []string { + s.mu.RLock() + defer s.mu.RUnlock() + + var serials []string + for serial, cert := range s.certs { + if cert.Revoked { + serials = append(serials, serial) + } + } + return serials +} + +// writeFile is a helper that creates parent directories. +func writeFile(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + return os.WriteFile(path, data, perm) +} + +// readFile is a helper for reading files. +func readFile(path string) ([]byte, error) { + return os.ReadFile(path) +} diff --git a/backend/internal/pki/tls.go b/backend/internal/pki/tls.go new file mode 100644 index 0000000..319a49e --- /dev/null +++ b/backend/internal/pki/tls.go @@ -0,0 +1,118 @@ +package pki + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" +) + +// ServerTLSConfig creates a TLS configuration for a server with mTLS. +// This requires clients to present a valid certificate signed by the CA. +func ServerTLSConfig(caCertPath, serverCertPath, serverKeyPath string) (*tls.Config, error) { + // Load CA certificate + caCert, err := os.ReadFile(caCertPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + + // Load server certificate + serverCert, err := tls.LoadX509KeyPair(serverCertPath, serverKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load server certificate: %w", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, + }, nil +} + +// ClientTLSConfig creates a TLS configuration for a client with mTLS. +// This presents the client certificate and verifies the server's certificate. +func ClientTLSConfig(caCertPath, clientCertPath, clientKeyPath string) (*tls.Config, error) { + // Load CA certificate + caCert, err := os.ReadFile(caCertPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + + // Load client certificate + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %w", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + }, nil +} + +// ServerTLSConfigWithCRL creates a server TLS config that checks CRL for revoked certs. +func ServerTLSConfigWithCRL(ca *CA, serverCertPath, serverKeyPath string) (*tls.Config, error) { + // Load server certificate + serverCert, err := tls.LoadX509KeyPair(serverCertPath, serverKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load server certificate: %w", err) + } + + // Create CA pool + caCertPool := x509.NewCertPool() + caCertPool.AddCert(ca.cert) + + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + // Check each certificate in the chain against CRL + for _, chain := range verifiedChains { + for _, cert := range chain { + if cert.IsCA { + continue // Skip CA certs + } + + revoked, err := ca.IsCertificateRevoked(cert) + if err != nil { + return fmt.Errorf("failed to check revocation status: %w", err) + } + if revoked { + return fmt.Errorf("certificate %s has been revoked", cert.SerialNumber.Text(16)) + } + } + } + return nil + }, + }, nil +} + +// ExtractAgentID extracts the agent ID from a verified client certificate. +// The agent ID is stored in the certificate's CommonName. +func ExtractAgentID(state *tls.ConnectionState) (string, error) { + if len(state.VerifiedChains) == 0 || len(state.VerifiedChains[0]) == 0 { + return "", fmt.Errorf("no verified certificate chain") + } + + cert := state.VerifiedChains[0][0] + agentID := cert.Subject.CommonName + if agentID == "" { + return "", fmt.Errorf("certificate has no CommonName") + } + + return agentID, nil +}