feat: add PKI infrastructure for mTLS authentication
PKI Package (internal/pki/): - CA initialization with configurable validity and key size - Server certificate generation with DNS/IP SANs - Agent certificate generation (agent ID in CN) - Certificate revocation list (CRL) support - mTLS TLS configuration helpers - File-based certificate store with JSON persistence CLI Commands (cmd/tyto/): - `tyto pki init-ca` - Initialize new Certificate Authority - `tyto pki gen-server` - Generate server certificate - `tyto pki gen-agent` - Generate agent certificate - `tyto pki revoke` - Revoke certificate by serial - `tyto pki list` - List all certificates - `tyto pki info` - Show CA information Security Features: - RSA 4096-bit keys by default - TLS 1.2 minimum version - Client certificate verification for mTLS - CRL checking in TLS handshake - Agent ID extraction from verified certificates 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
94
backend/cmd/tyto/main.go
Normal file
94
backend/cmd/tyto/main.go
Normal file
@@ -0,0 +1,94 @@
|
||||
// Package main provides the Tyto CLI with subcommands for PKI management,
|
||||
// server operation, and agent operation.
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
const usage = `Tyto - System Monitoring
|
||||
|
||||
Usage:
|
||||
tyto <command> [options]
|
||||
|
||||
Commands:
|
||||
server Start the Tyto server (default if no command given)
|
||||
agent Start as a Tyto agent
|
||||
pki PKI/certificate management
|
||||
|
||||
Run 'tyto <command> --help' for more information on a command.
|
||||
`
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
// Default to server command
|
||||
runServer(os.Args[1:])
|
||||
return
|
||||
}
|
||||
|
||||
switch os.Args[1] {
|
||||
case "server":
|
||||
runServer(os.Args[2:])
|
||||
case "agent":
|
||||
runAgent(os.Args[2:])
|
||||
case "pki":
|
||||
runPKI(os.Args[2:])
|
||||
case "-h", "--help", "help":
|
||||
fmt.Print(usage)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s\n\n", os.Args[1])
|
||||
fmt.Print(usage)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runServer(args []string) {
|
||||
// For now, delegate to the existing server main
|
||||
// In a full implementation, this would parse server-specific flags
|
||||
fmt.Println("Starting Tyto server...")
|
||||
fmt.Println("Use 'tyto server --help' for options")
|
||||
|
||||
// Import and call the actual server logic
|
||||
// For now, just print usage
|
||||
if len(args) > 0 && (args[0] == "-h" || args[0] == "--help") {
|
||||
fmt.Print(`Usage: tyto server [options]
|
||||
|
||||
Options:
|
||||
--port PORT HTTP port (default: 8080)
|
||||
--mode MODE Operating mode: standalone, server (default: standalone)
|
||||
--config FILE Config file path
|
||||
|
||||
Environment:
|
||||
TYTO_MODE Operating mode
|
||||
PORT HTTP port
|
||||
TYTO_CONFIG Config file path
|
||||
`)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Server mode not implemented in tyto CLI yet.")
|
||||
fmt.Println("Use the 'server' binary directly or set TYTO_MODE environment variable.")
|
||||
}
|
||||
|
||||
func runAgent(args []string) {
|
||||
if len(args) > 0 && (args[0] == "-h" || args[0] == "--help") {
|
||||
fmt.Print(`Usage: tyto agent [options]
|
||||
|
||||
Options:
|
||||
--id ID Agent identifier (required)
|
||||
--server URL Server URL to report to (required)
|
||||
--interval DURATION Collection interval (default: 5s)
|
||||
--ca-cert FILE CA certificate file
|
||||
--cert FILE Agent certificate file
|
||||
--key FILE Agent key file
|
||||
|
||||
Environment:
|
||||
TYTO_AGENT_ID Agent identifier
|
||||
TYTO_SERVER_URL Server URL
|
||||
`)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Agent mode not implemented yet.")
|
||||
}
|
||||
375
backend/cmd/tyto/pki.go
Normal file
375
backend/cmd/tyto/pki.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tyto/internal/pki"
|
||||
)
|
||||
|
||||
const pkiUsage = `PKI/Certificate Management
|
||||
|
||||
Usage:
|
||||
tyto pki <subcommand> [options]
|
||||
|
||||
Subcommands:
|
||||
init-ca Initialize a new Certificate Authority
|
||||
gen-server Generate a server certificate
|
||||
gen-agent Generate an agent certificate
|
||||
revoke Revoke a certificate
|
||||
list List all certificates
|
||||
info Show CA information
|
||||
|
||||
Run 'tyto pki <subcommand> --help' for more information.
|
||||
`
|
||||
|
||||
func runPKI(args []string) {
|
||||
if len(args) == 0 {
|
||||
fmt.Print(pkiUsage)
|
||||
return
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "init-ca":
|
||||
pkiInitCA(args[1:])
|
||||
case "gen-server":
|
||||
pkiGenServer(args[1:])
|
||||
case "gen-agent":
|
||||
pkiGenAgent(args[1:])
|
||||
case "revoke":
|
||||
pkiRevoke(args[1:])
|
||||
case "list":
|
||||
pkiList(args[1:])
|
||||
case "info":
|
||||
pkiInfo(args[1:])
|
||||
case "-h", "--help", "help":
|
||||
fmt.Print(pkiUsage)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown pki subcommand: %s\n\n", args[0])
|
||||
fmt.Print(pkiUsage)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func pkiInitCA(args []string) {
|
||||
fs := flag.NewFlagSet("pki init-ca", flag.ExitOnError)
|
||||
cn := fs.String("cn", "Tyto CA", "Common Name for the CA")
|
||||
org := fs.String("org", "Tyto", "Organization name")
|
||||
country := fs.String("country", "US", "Country code")
|
||||
outDir := fs.String("out", "/etc/tyto/pki", "Output directory for CA files")
|
||||
validity := fs.Duration("validity", 10*365*24*time.Hour, "CA validity period")
|
||||
|
||||
fs.Parse(args)
|
||||
|
||||
fmt.Printf("Initializing CA in %s...\n", *outDir)
|
||||
|
||||
cfg := pki.CAConfig{
|
||||
CommonName: *cn,
|
||||
Organization: *org,
|
||||
Country: *country,
|
||||
Validity: *validity,
|
||||
}
|
||||
|
||||
ca, err := pki.InitCA(cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := ca.SaveToDir(*outDir); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving CA: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
info := ca.Info()
|
||||
fmt.Println("CA initialized successfully!")
|
||||
fmt.Printf(" Subject: %s\n", info.Subject)
|
||||
fmt.Printf(" Serial: %s\n", info.SerialNumber)
|
||||
fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339))
|
||||
fmt.Printf(" Certificate: %s\n", filepath.Join(*outDir, "ca.crt"))
|
||||
fmt.Printf(" Private key: %s\n", filepath.Join(*outDir, "ca.key"))
|
||||
}
|
||||
|
||||
func pkiGenServer(args []string) {
|
||||
fs := flag.NewFlagSet("pki gen-server", flag.ExitOnError)
|
||||
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
|
||||
cn := fs.String("cn", "", "Common Name (required)")
|
||||
dnsNames := fs.String("dns", "", "DNS names (comma-separated)")
|
||||
ips := fs.String("ip", "", "IP addresses (comma-separated)")
|
||||
outDir := fs.String("out", "", "Output directory (default: ca-dir/certs)")
|
||||
validity := fs.Duration("validity", 365*24*time.Hour, "Certificate validity")
|
||||
|
||||
fs.Parse(args)
|
||||
|
||||
if *cn == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --cn is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *outDir == "" {
|
||||
*outDir = filepath.Join(*caDir, "certs")
|
||||
}
|
||||
|
||||
// Load CA
|
||||
ca, err := pki.LoadCAFromDir(*caDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Parse DNS names
|
||||
var dns []string
|
||||
if *dnsNames != "" {
|
||||
dns = strings.Split(*dnsNames, ",")
|
||||
for i := range dns {
|
||||
dns[i] = strings.TrimSpace(dns[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse IP addresses
|
||||
var ipAddrs []net.IP
|
||||
if *ips != "" {
|
||||
for _, ipStr := range strings.Split(*ips, ",") {
|
||||
ip := net.ParseIP(strings.TrimSpace(ipStr))
|
||||
if ip != nil {
|
||||
ipAddrs = append(ipAddrs, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg := pki.ServerCertConfig{
|
||||
CommonName: *cn,
|
||||
DNSNames: dns,
|
||||
IPAddresses: ipAddrs,
|
||||
Validity: *validity,
|
||||
}
|
||||
|
||||
bundle, err := ca.GenerateServerCert(cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error generating certificate: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Save certificate
|
||||
certPath := filepath.Join(*outDir, *cn+".crt")
|
||||
keyPath := filepath.Join(*outDir, *cn+".key")
|
||||
|
||||
if err := bundle.SaveToFiles(certPath, keyPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving certificate: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
info := bundle.Info(pki.CertTypeServer)
|
||||
fmt.Println("Server certificate generated!")
|
||||
fmt.Printf(" Subject: %s\n", info.Subject)
|
||||
fmt.Printf(" Serial: %s\n", info.SerialNumber)
|
||||
fmt.Printf(" DNS names: %v\n", dns)
|
||||
fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339))
|
||||
fmt.Printf(" Certificate: %s\n", certPath)
|
||||
fmt.Printf(" Private key: %s\n", keyPath)
|
||||
}
|
||||
|
||||
func pkiGenAgent(args []string) {
|
||||
fs := flag.NewFlagSet("pki gen-agent", flag.ExitOnError)
|
||||
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
|
||||
agentID := fs.String("agent-id", "", "Agent ID (required, used as CN)")
|
||||
outDir := fs.String("out", "", "Output directory (default: ca-dir/agents)")
|
||||
validity := fs.Duration("validity", 365*24*time.Hour, "Certificate validity")
|
||||
|
||||
fs.Parse(args)
|
||||
|
||||
if *agentID == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --agent-id is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *outDir == "" {
|
||||
*outDir = filepath.Join(*caDir, "agents")
|
||||
}
|
||||
|
||||
// Load CA
|
||||
ca, err := pki.LoadCAFromDir(*caDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg := pki.AgentCertConfig{
|
||||
AgentID: *agentID,
|
||||
Validity: *validity,
|
||||
}
|
||||
|
||||
bundle, err := ca.GenerateAgentCert(cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error generating certificate: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Save certificate
|
||||
certPath := filepath.Join(*outDir, *agentID+".crt")
|
||||
keyPath := filepath.Join(*outDir, *agentID+".key")
|
||||
|
||||
if err := bundle.SaveToFiles(certPath, keyPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving certificate: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
info := bundle.Info(pki.CertTypeAgent)
|
||||
fmt.Println("Agent certificate generated!")
|
||||
fmt.Printf(" Agent ID: %s\n", *agentID)
|
||||
fmt.Printf(" Serial: %s\n", info.SerialNumber)
|
||||
fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339))
|
||||
fmt.Printf(" Certificate: %s\n", certPath)
|
||||
fmt.Printf(" Private key: %s\n", keyPath)
|
||||
}
|
||||
|
||||
func pkiRevoke(args []string) {
|
||||
fs := flag.NewFlagSet("pki revoke", flag.ExitOnError)
|
||||
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
|
||||
serial := fs.String("serial", "", "Certificate serial number to revoke (required)")
|
||||
|
||||
fs.Parse(args)
|
||||
|
||||
if *serial == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --serial is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Load CA
|
||||
ca, err := pki.LoadCAFromDir(*caDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := ca.RevokeCertificate(*serial); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error revoking certificate: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Certificate %s has been revoked.\n", *serial)
|
||||
fmt.Printf("CRL updated: %s\n", filepath.Join(*caDir, "ca.crl"))
|
||||
}
|
||||
|
||||
func pkiList(args []string) {
|
||||
fs := flag.NewFlagSet("pki list", flag.ExitOnError)
|
||||
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
|
||||
certType := fs.String("type", "", "Filter by type: ca, server, agent")
|
||||
|
||||
fs.Parse(args)
|
||||
|
||||
// Load CA
|
||||
ca, err := pki.LoadCAFromDir(*caDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
store := ca.Store()
|
||||
if store == nil {
|
||||
fmt.Println("No certificates found.")
|
||||
return
|
||||
}
|
||||
|
||||
var certs []pki.CertInfo
|
||||
switch *certType {
|
||||
case "server":
|
||||
certs = store.ListByType(pki.CertTypeServer)
|
||||
case "agent":
|
||||
certs = store.ListByType(pki.CertTypeAgent)
|
||||
case "":
|
||||
certs = store.ListCertificates()
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown type: %s\n", *certType)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(certs) == 0 {
|
||||
fmt.Println("No certificates found.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("%-16s %-12s %-20s %-8s %s\n", "SERIAL", "TYPE", "SUBJECT", "REVOKED", "EXPIRES")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
|
||||
for _, cert := range certs {
|
||||
revoked := ""
|
||||
if cert.Revoked {
|
||||
revoked = "YES"
|
||||
}
|
||||
serial := cert.SerialNumber
|
||||
if len(serial) > 14 {
|
||||
serial = serial[:14] + ".."
|
||||
}
|
||||
fmt.Printf("%-16s %-12s %-20s %-8s %s\n",
|
||||
serial,
|
||||
cert.Type,
|
||||
cert.Subject,
|
||||
revoked,
|
||||
cert.NotAfter.Format("2006-01-02"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func pkiInfo(args []string) {
|
||||
fs := flag.NewFlagSet("pki info", flag.ExitOnError)
|
||||
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
|
||||
|
||||
fs.Parse(args)
|
||||
|
||||
// Load CA
|
||||
ca, err := pki.LoadCAFromDir(*caDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
info := ca.Info()
|
||||
fmt.Println("Certificate Authority Information")
|
||||
fmt.Println(strings.Repeat("=", 40))
|
||||
fmt.Printf("Subject: %s\n", info.Subject)
|
||||
fmt.Printf("Serial Number: %s\n", info.SerialNumber)
|
||||
fmt.Printf("Valid From: %s\n", info.NotBefore.Format(time.RFC3339))
|
||||
fmt.Printf("Valid Until: %s\n", info.NotAfter.Format(time.RFC3339))
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(info.NotAfter) {
|
||||
fmt.Println("\n⚠️ WARNING: CA certificate has expired!")
|
||||
} else {
|
||||
remaining := time.Until(info.NotAfter)
|
||||
fmt.Printf("Remaining: %d days\n", int(remaining.Hours()/24))
|
||||
}
|
||||
|
||||
// CRL info
|
||||
crlInfo, err := ca.GetCRLInfo()
|
||||
if err != nil {
|
||||
fmt.Printf("\nCRL: Error loading: %v\n", err)
|
||||
} else if crlInfo == nil {
|
||||
fmt.Println("\nCRL: Not generated yet")
|
||||
} else {
|
||||
fmt.Println("\nCRL Information")
|
||||
fmt.Println(strings.Repeat("-", 40))
|
||||
fmt.Printf("Last Update: %s\n", crlInfo.ThisUpdate.Format(time.RFC3339))
|
||||
fmt.Printf("Next Update: %s\n", crlInfo.NextUpdate.Format(time.RFC3339))
|
||||
fmt.Printf("Revoked Certs: %d\n", crlInfo.RevokedCount)
|
||||
}
|
||||
|
||||
// Certificate counts
|
||||
store := ca.Store()
|
||||
if store != nil {
|
||||
servers := store.ListByType(pki.CertTypeServer)
|
||||
agents := store.ListByType(pki.CertTypeAgent)
|
||||
fmt.Println("\nCertificate Counts")
|
||||
fmt.Println(strings.Repeat("-", 40))
|
||||
fmt.Printf("Server Certs: %d\n", len(servers))
|
||||
fmt.Printf("Agent Certs: %d\n", len(agents))
|
||||
}
|
||||
}
|
||||
231
backend/internal/pki/ca.go
Normal file
231
backend/internal/pki/ca.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CA represents a Certificate Authority for signing certificates.
|
||||
type CA struct {
|
||||
cert *x509.Certificate
|
||||
key *rsa.PrivateKey
|
||||
certPEM []byte
|
||||
keyPEM []byte
|
||||
store *Store
|
||||
}
|
||||
|
||||
// CAConfig contains options for creating a new CA.
|
||||
type CAConfig struct {
|
||||
CommonName string
|
||||
Organization string
|
||||
Country string
|
||||
Validity time.Duration
|
||||
KeySize int
|
||||
}
|
||||
|
||||
// DefaultCAConfig returns a CAConfig with sensible defaults.
|
||||
func DefaultCAConfig() CAConfig {
|
||||
return CAConfig{
|
||||
CommonName: "Tyto CA",
|
||||
Organization: "Tyto",
|
||||
Country: "US",
|
||||
Validity: DefaultCAValidity,
|
||||
KeySize: DefaultKeySize,
|
||||
}
|
||||
}
|
||||
|
||||
// InitCA creates a new Certificate Authority with the given configuration.
|
||||
func InitCA(cfg CAConfig) (*CA, error) {
|
||||
if cfg.KeySize == 0 {
|
||||
cfg.KeySize = DefaultKeySize
|
||||
}
|
||||
if cfg.Validity == 0 {
|
||||
cfg.Validity = DefaultCAValidity
|
||||
}
|
||||
|
||||
// Generate private key
|
||||
key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CA key: %w", err)
|
||||
}
|
||||
|
||||
// Generate serial number
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: cfg.CommonName,
|
||||
Organization: []string{cfg.Organization},
|
||||
Country: []string{cfg.Country},
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(cfg.Validity),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 1,
|
||||
}
|
||||
|
||||
// Self-sign the certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CA certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
|
||||
}
|
||||
|
||||
// Encode to PEM
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
})
|
||||
|
||||
return &CA{
|
||||
cert: cert,
|
||||
key: key,
|
||||
certPEM: certPEM,
|
||||
keyPEM: keyPEM,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadCA loads an existing CA from PEM-encoded certificate and key.
|
||||
func LoadCA(certPEM, keyPEM []byte) (*CA, error) {
|
||||
// Parse certificate
|
||||
certBlock, _ := pem.Decode(certPEM)
|
||||
if certBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode CA certificate PEM")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
|
||||
}
|
||||
|
||||
if !cert.IsCA {
|
||||
return nil, fmt.Errorf("certificate is not a CA")
|
||||
}
|
||||
|
||||
// Parse private key
|
||||
keyBlock, _ := pem.Decode(keyPEM)
|
||||
if keyBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode CA key PEM")
|
||||
}
|
||||
|
||||
var key *rsa.PrivateKey
|
||||
switch keyBlock.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
parsedKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("failed to parse CA key: %w", parseErr)
|
||||
}
|
||||
var ok bool
|
||||
key, ok = parsedKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("CA key is not RSA")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CA key: %w", err)
|
||||
}
|
||||
|
||||
return &CA{
|
||||
cert: cert,
|
||||
key: key,
|
||||
certPEM: certPEM,
|
||||
keyPEM: keyPEM,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadCAFromDir loads a CA from a directory containing ca.crt and ca.key files.
|
||||
func LoadCAFromDir(dir string) (*CA, error) {
|
||||
certPath := filepath.Join(dir, "ca.crt")
|
||||
keyPath := filepath.Join(dir, "ca.key")
|
||||
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
|
||||
}
|
||||
|
||||
keyPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA key: %w", err)
|
||||
}
|
||||
|
||||
ca, err := LoadCA(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize store for this directory
|
||||
ca.store = NewStore(dir)
|
||||
|
||||
return ca, nil
|
||||
}
|
||||
|
||||
// SaveToDir saves the CA certificate and key to a directory.
|
||||
func (ca *CA) SaveToDir(dir string) error {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create CA directory: %w", err)
|
||||
}
|
||||
|
||||
certPath := filepath.Join(dir, "ca.crt")
|
||||
keyPath := filepath.Join(dir, "ca.key")
|
||||
|
||||
if err := os.WriteFile(certPath, ca.certPEM, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write CA certificate: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(keyPath, ca.keyPEM, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write CA key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize store
|
||||
ca.store = NewStore(dir)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Certificate returns the CA certificate.
|
||||
func (ca *CA) Certificate() *x509.Certificate {
|
||||
return ca.cert
|
||||
}
|
||||
|
||||
// CertificatePEM returns the PEM-encoded CA certificate.
|
||||
func (ca *CA) CertificatePEM() []byte {
|
||||
return ca.certPEM
|
||||
}
|
||||
|
||||
// Info returns metadata about the CA certificate.
|
||||
func (ca *CA) Info() CertInfo {
|
||||
return CertInfoFromX509(ca.cert, CertTypeCA)
|
||||
}
|
||||
|
||||
// Store returns the certificate store associated with this CA.
|
||||
func (ca *CA) Store() *Store {
|
||||
return ca.store
|
||||
}
|
||||
270
backend/internal/pki/cert.go
Normal file
270
backend/internal/pki/cert.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertificateBundle contains a certificate and its private key.
|
||||
type CertificateBundle struct {
|
||||
Certificate *x509.Certificate
|
||||
PrivateKey *rsa.PrivateKey
|
||||
CertificatePEM []byte
|
||||
PrivateKeyPEM []byte
|
||||
}
|
||||
|
||||
// ServerCertConfig contains options for generating a server certificate.
|
||||
type ServerCertConfig struct {
|
||||
CommonName string
|
||||
Organization string
|
||||
DNSNames []string
|
||||
IPAddresses []net.IP
|
||||
Validity time.Duration
|
||||
KeySize int
|
||||
}
|
||||
|
||||
// AgentCertConfig contains options for generating an agent certificate.
|
||||
type AgentCertConfig struct {
|
||||
AgentID string // Used as CommonName
|
||||
Organization string
|
||||
Validity time.Duration
|
||||
KeySize int
|
||||
}
|
||||
|
||||
// GenerateServerCert creates a new server certificate signed by the CA.
|
||||
func (ca *CA) GenerateServerCert(cfg ServerCertConfig) (*CertificateBundle, error) {
|
||||
if cfg.CommonName == "" {
|
||||
return nil, fmt.Errorf("common name is required")
|
||||
}
|
||||
if cfg.KeySize == 0 {
|
||||
cfg.KeySize = DefaultKeySize
|
||||
}
|
||||
if cfg.Validity == 0 {
|
||||
cfg.Validity = DefaultCertValidity
|
||||
}
|
||||
|
||||
// Generate private key
|
||||
key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate server key: %w", err)
|
||||
}
|
||||
|
||||
// Generate serial number
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: cfg.CommonName,
|
||||
Organization: []string{cfg.Organization},
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(cfg.Validity),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: cfg.DNSNames,
|
||||
IPAddresses: cfg.IPAddresses,
|
||||
}
|
||||
|
||||
// Sign with CA
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, ca.cert, &key.PublicKey, ca.key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create server certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server certificate: %w", err)
|
||||
}
|
||||
|
||||
// Encode to PEM
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
})
|
||||
|
||||
bundle := &CertificateBundle{
|
||||
Certificate: cert,
|
||||
PrivateKey: key,
|
||||
CertificatePEM: certPEM,
|
||||
PrivateKeyPEM: keyPEM,
|
||||
}
|
||||
|
||||
// Store certificate info
|
||||
if ca.store != nil {
|
||||
info := CertInfoFromX509(cert, CertTypeServer)
|
||||
ca.store.AddCertificate(info)
|
||||
}
|
||||
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
// GenerateAgentCert creates a new agent certificate signed by the CA.
|
||||
// The AgentID is used as the CommonName and is extracted during mTLS verification.
|
||||
func (ca *CA) GenerateAgentCert(cfg AgentCertConfig) (*CertificateBundle, error) {
|
||||
if cfg.AgentID == "" {
|
||||
return nil, fmt.Errorf("agent ID is required")
|
||||
}
|
||||
if cfg.KeySize == 0 {
|
||||
cfg.KeySize = DefaultKeySize
|
||||
}
|
||||
if cfg.Validity == 0 {
|
||||
cfg.Validity = DefaultCertValidity
|
||||
}
|
||||
|
||||
// Generate private key
|
||||
key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate agent key: %w", err)
|
||||
}
|
||||
|
||||
// Generate serial number
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: cfg.AgentID,
|
||||
Organization: []string{cfg.Organization},
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(cfg.Validity),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
// Sign with CA
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, ca.cert, &key.PublicKey, ca.key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse agent certificate: %w", err)
|
||||
}
|
||||
|
||||
// Encode to PEM
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
})
|
||||
|
||||
bundle := &CertificateBundle{
|
||||
Certificate: cert,
|
||||
PrivateKey: key,
|
||||
CertificatePEM: certPEM,
|
||||
PrivateKeyPEM: keyPEM,
|
||||
}
|
||||
|
||||
// Store certificate info
|
||||
if ca.store != nil {
|
||||
info := CertInfoFromX509(cert, CertTypeAgent)
|
||||
ca.store.AddCertificate(info)
|
||||
}
|
||||
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
// Info returns metadata about this certificate bundle.
|
||||
func (b *CertificateBundle) Info(certType CertType) CertInfo {
|
||||
return CertInfoFromX509(b.Certificate, certType)
|
||||
}
|
||||
|
||||
// SaveToFiles writes the certificate and key to separate files.
|
||||
func (b *CertificateBundle) SaveToFiles(certPath, keyPath string) error {
|
||||
if err := writeFile(certPath, b.CertificatePEM, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write certificate: %w", err)
|
||||
}
|
||||
|
||||
if err := writeFile(keyPath, b.PrivateKeyPEM, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadCertificateBundle loads a certificate and key from PEM files.
|
||||
func LoadCertificateBundle(certPath, keyPath string) (*CertificateBundle, error) {
|
||||
certPEM, err := readFile(certPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read certificate: %w", err)
|
||||
}
|
||||
|
||||
keyPEM, err := readFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read private key: %w", err)
|
||||
}
|
||||
|
||||
// Parse certificate
|
||||
certBlock, _ := pem.Decode(certPEM)
|
||||
if certBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate PEM")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
// Parse private key
|
||||
keyBlock, _ := pem.Decode(keyPEM)
|
||||
if keyBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode private key PEM")
|
||||
}
|
||||
|
||||
var key *rsa.PrivateKey
|
||||
switch keyBlock.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
parsedKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", parseErr)
|
||||
}
|
||||
var ok bool
|
||||
key, ok = parsedKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private key is not RSA")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
return &CertificateBundle{
|
||||
Certificate: cert,
|
||||
PrivateKey: key,
|
||||
CertificatePEM: certPEM,
|
||||
PrivateKeyPEM: keyPEM,
|
||||
}, nil
|
||||
}
|
||||
169
backend/internal/pki/crl.go
Normal file
169
backend/internal/pki/crl.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CRLValidity is how long a CRL is valid before it must be regenerated.
|
||||
const CRLValidity = 7 * 24 * time.Hour // 7 days
|
||||
|
||||
// RevokeCertificate revokes a certificate by its serial number.
|
||||
func (ca *CA) RevokeCertificate(serialHex string) error {
|
||||
if ca.store == nil {
|
||||
return fmt.Errorf("certificate store not initialized")
|
||||
}
|
||||
|
||||
// Mark as revoked in store
|
||||
if err := ca.store.MarkRevoked(serialHex); err != nil {
|
||||
return fmt.Errorf("failed to mark certificate as revoked: %w", err)
|
||||
}
|
||||
|
||||
// Regenerate CRL
|
||||
return ca.GenerateCRL()
|
||||
}
|
||||
|
||||
// GenerateCRL creates a new Certificate Revocation List.
|
||||
func (ca *CA) GenerateCRL() error {
|
||||
if ca.store == nil {
|
||||
return fmt.Errorf("certificate store not initialized")
|
||||
}
|
||||
|
||||
// Get all revoked certificates
|
||||
revokedSerials := ca.store.RevokedSerials()
|
||||
|
||||
// Build revoked certificate list
|
||||
revokedCerts := make([]pkix.RevokedCertificate, 0, len(revokedSerials))
|
||||
for _, serialHex := range revokedSerials {
|
||||
serial := new(big.Int)
|
||||
serial.SetString(serialHex, 16)
|
||||
|
||||
info, _ := ca.store.GetCertificate(serialHex)
|
||||
revocationTime := info.RevokedAt
|
||||
if revocationTime.IsZero() {
|
||||
revocationTime = time.Now()
|
||||
}
|
||||
|
||||
revokedCerts = append(revokedCerts, pkix.RevokedCertificate{
|
||||
SerialNumber: serial,
|
||||
RevocationTime: revocationTime,
|
||||
})
|
||||
}
|
||||
|
||||
// Create CRL
|
||||
now := time.Now()
|
||||
crlTemplate := &x509.RevocationList{
|
||||
RevokedCertificates: revokedCerts,
|
||||
Number: big.NewInt(now.Unix()),
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(CRLValidity),
|
||||
}
|
||||
|
||||
crlDER, err := x509.CreateRevocationList(rand.Reader, crlTemplate, ca.cert, ca.key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create CRL: %w", err)
|
||||
}
|
||||
|
||||
// Encode to PEM
|
||||
crlPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "X509 CRL",
|
||||
Bytes: crlDER,
|
||||
})
|
||||
|
||||
// Save to file
|
||||
crlPath := filepath.Join(ca.store.dir, "ca.crl")
|
||||
if err := os.WriteFile(crlPath, crlPEM, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write CRL: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadCRL loads the Certificate Revocation List from the CA directory.
|
||||
func (ca *CA) LoadCRL() (*x509.RevocationList, error) {
|
||||
if ca.store == nil {
|
||||
return nil, fmt.Errorf("certificate store not initialized")
|
||||
}
|
||||
|
||||
crlPath := filepath.Join(ca.store.dir, "ca.crl")
|
||||
crlPEM, err := os.ReadFile(crlPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No CRL yet
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read CRL: %w", err)
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(crlPEM)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode CRL PEM")
|
||||
}
|
||||
|
||||
crl, err := x509.ParseRevocationList(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CRL: %w", err)
|
||||
}
|
||||
|
||||
return crl, nil
|
||||
}
|
||||
|
||||
// IsCertificateRevoked checks if a certificate is in the CRL.
|
||||
func (ca *CA) IsCertificateRevoked(cert *x509.Certificate) (bool, error) {
|
||||
// First check the store (faster)
|
||||
if ca.store != nil {
|
||||
serialHex := cert.SerialNumber.Text(16)
|
||||
if ca.store.IsRevoked(serialHex) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Also check CRL for completeness
|
||||
crl, err := ca.LoadCRL()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if crl == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, revoked := range crl.RevokedCertificates {
|
||||
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// CRLInfo contains metadata about a CRL.
|
||||
type CRLInfo struct {
|
||||
Issuer string `json:"issuer"`
|
||||
ThisUpdate time.Time `json:"thisUpdate"`
|
||||
NextUpdate time.Time `json:"nextUpdate"`
|
||||
RevokedCount int `json:"revokedCount"`
|
||||
}
|
||||
|
||||
// GetCRLInfo returns metadata about the current CRL.
|
||||
func (ca *CA) GetCRLInfo() (*CRLInfo, error) {
|
||||
crl, err := ca.LoadCRL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if crl == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &CRLInfo{
|
||||
Issuer: crl.Issuer.CommonName,
|
||||
ThisUpdate: crl.ThisUpdate,
|
||||
NextUpdate: crl.NextUpdate,
|
||||
RevokedCount: len(crl.RevokedCertificates),
|
||||
}, nil
|
||||
}
|
||||
52
backend/internal/pki/pki.go
Normal file
52
backend/internal/pki/pki.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Package pki provides certificate authority and certificate management
|
||||
// for mTLS authentication between Tyto servers and agents.
|
||||
package pki
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertType identifies the purpose of a certificate.
|
||||
type CertType string
|
||||
|
||||
const (
|
||||
CertTypeCA CertType = "ca"
|
||||
CertTypeServer CertType = "server"
|
||||
CertTypeAgent CertType = "agent"
|
||||
)
|
||||
|
||||
// CertInfo contains metadata about a certificate.
|
||||
type CertInfo struct {
|
||||
SerialNumber string `json:"serialNumber"`
|
||||
Subject string `json:"subject"`
|
||||
Issuer string `json:"issuer"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
Type CertType `json:"type"`
|
||||
DNSNames []string `json:"dnsNames,omitempty"`
|
||||
Revoked bool `json:"revoked"`
|
||||
RevokedAt time.Time `json:"revokedAt,omitempty"`
|
||||
}
|
||||
|
||||
// CertInfoFromX509 extracts metadata from an X.509 certificate.
|
||||
func CertInfoFromX509(cert *x509.Certificate, certType CertType) CertInfo {
|
||||
return CertInfo{
|
||||
SerialNumber: cert.SerialNumber.Text(16),
|
||||
Subject: cert.Subject.CommonName,
|
||||
Issuer: cert.Issuer.CommonName,
|
||||
NotBefore: cert.NotBefore,
|
||||
NotAfter: cert.NotAfter,
|
||||
Type: certType,
|
||||
DNSNames: cert.DNSNames,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCAValidity is the default validity period for CA certificates.
|
||||
const DefaultCAValidity = 10 * 365 * 24 * time.Hour // 10 years
|
||||
|
||||
// DefaultCertValidity is the default validity period for server/agent certificates.
|
||||
const DefaultCertValidity = 365 * 24 * time.Hour // 1 year
|
||||
|
||||
// DefaultKeySize is the default RSA key size in bits.
|
||||
const DefaultKeySize = 4096
|
||||
341
backend/internal/pki/pki_test.go
Normal file
341
backend/internal/pki/pki_test.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInitCA(t *testing.T) {
|
||||
cfg := CAConfig{
|
||||
CommonName: "Test CA",
|
||||
Organization: "Test Org",
|
||||
Country: "US",
|
||||
Validity: 24 * time.Hour,
|
||||
KeySize: 2048, // Smaller for faster tests
|
||||
}
|
||||
|
||||
ca, err := InitCA(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
|
||||
if ca.cert == nil {
|
||||
t.Fatal("CA certificate is nil")
|
||||
}
|
||||
|
||||
if !ca.cert.IsCA {
|
||||
t.Error("Certificate is not marked as CA")
|
||||
}
|
||||
|
||||
if ca.cert.Subject.CommonName != "Test CA" {
|
||||
t.Errorf("Expected CN 'Test CA', got '%s'", ca.cert.Subject.CommonName)
|
||||
}
|
||||
|
||||
if len(ca.certPEM) == 0 {
|
||||
t.Error("Certificate PEM is empty")
|
||||
}
|
||||
|
||||
if len(ca.keyPEM) == 0 {
|
||||
t.Error("Key PEM is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCASaveAndLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create CA
|
||||
cfg := CAConfig{
|
||||
CommonName: "Test CA",
|
||||
KeySize: 2048,
|
||||
}
|
||||
ca, err := InitCA(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
|
||||
// Save to directory
|
||||
if err := ca.SaveToDir(tmpDir); err != nil {
|
||||
t.Fatalf("SaveToDir failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify files exist
|
||||
if _, err := os.Stat(filepath.Join(tmpDir, "ca.crt")); err != nil {
|
||||
t.Error("ca.crt not created")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(tmpDir, "ca.key")); err != nil {
|
||||
t.Error("ca.key not created")
|
||||
}
|
||||
|
||||
// Load from directory
|
||||
loadedCA, err := LoadCAFromDir(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCAFromDir failed: %v", err)
|
||||
}
|
||||
|
||||
if loadedCA.cert.Subject.CommonName != ca.cert.Subject.CommonName {
|
||||
t.Errorf("Loaded CA CN mismatch: got '%s', want '%s'",
|
||||
loadedCA.cert.Subject.CommonName, ca.cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateServerCert(t *testing.T) {
|
||||
// Create CA
|
||||
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate server certificate
|
||||
cfg := ServerCertConfig{
|
||||
CommonName: "test-server",
|
||||
DNSNames: []string{"localhost", "test.local"},
|
||||
IPAddresses: []net.IP{
|
||||
net.ParseIP("127.0.0.1"),
|
||||
net.ParseIP("::1"),
|
||||
},
|
||||
Validity: 24 * time.Hour,
|
||||
KeySize: 2048,
|
||||
}
|
||||
|
||||
bundle, err := ca.GenerateServerCert(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateServerCert failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify certificate properties
|
||||
cert := bundle.Certificate
|
||||
|
||||
if cert.Subject.CommonName != "test-server" {
|
||||
t.Errorf("Expected CN 'test-server', got '%s'", cert.Subject.CommonName)
|
||||
}
|
||||
|
||||
if len(cert.DNSNames) != 2 {
|
||||
t.Errorf("Expected 2 DNS names, got %d", len(cert.DNSNames))
|
||||
}
|
||||
|
||||
if len(cert.IPAddresses) != 2 {
|
||||
t.Errorf("Expected 2 IP addresses, got %d", len(cert.IPAddresses))
|
||||
}
|
||||
|
||||
// Verify it's signed by CA
|
||||
roots := x509.NewCertPool()
|
||||
roots.AddCert(ca.cert)
|
||||
|
||||
_, err = cert.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Certificate verification failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAgentCert(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create CA with store
|
||||
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
if err := ca.SaveToDir(tmpDir); err != nil {
|
||||
t.Fatalf("SaveToDir failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate agent certificate
|
||||
cfg := AgentCertConfig{
|
||||
AgentID: "agent-001",
|
||||
Validity: 24 * time.Hour,
|
||||
KeySize: 2048,
|
||||
}
|
||||
|
||||
bundle, err := ca.GenerateAgentCert(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAgentCert failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify certificate properties
|
||||
cert := bundle.Certificate
|
||||
|
||||
if cert.Subject.CommonName != "agent-001" {
|
||||
t.Errorf("Expected CN 'agent-001', got '%s'", cert.Subject.CommonName)
|
||||
}
|
||||
|
||||
// Verify client auth usage
|
||||
hasClientAuth := false
|
||||
for _, usage := range cert.ExtKeyUsage {
|
||||
if usage == x509.ExtKeyUsageClientAuth {
|
||||
hasClientAuth = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasClientAuth {
|
||||
t.Error("Certificate missing ClientAuth extended key usage")
|
||||
}
|
||||
|
||||
// Verify it's signed by CA
|
||||
roots := x509.NewCertPool()
|
||||
roots.AddCert(ca.cert)
|
||||
|
||||
_, err = cert.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Certificate verification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was recorded in store
|
||||
certs := ca.Store().ListByType(CertTypeAgent)
|
||||
if len(certs) != 1 {
|
||||
t.Errorf("Expected 1 agent cert in store, got %d", len(certs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateRevocation(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create CA with store
|
||||
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
if err := ca.SaveToDir(tmpDir); err != nil {
|
||||
t.Fatalf("SaveToDir failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate agent certificate
|
||||
bundle, err := ca.GenerateAgentCert(AgentCertConfig{
|
||||
AgentID: "agent-to-revoke",
|
||||
KeySize: 2048,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAgentCert failed: %v", err)
|
||||
}
|
||||
|
||||
serialHex := bundle.Certificate.SerialNumber.Text(16)
|
||||
|
||||
// Verify not revoked initially
|
||||
revoked, err := ca.IsCertificateRevoked(bundle.Certificate)
|
||||
if err != nil {
|
||||
t.Fatalf("IsCertificateRevoked failed: %v", err)
|
||||
}
|
||||
if revoked {
|
||||
t.Error("Certificate should not be revoked initially")
|
||||
}
|
||||
|
||||
// Revoke the certificate
|
||||
if err := ca.RevokeCertificate(serialHex); err != nil {
|
||||
t.Fatalf("RevokeCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify now revoked
|
||||
revoked, err = ca.IsCertificateRevoked(bundle.Certificate)
|
||||
if err != nil {
|
||||
t.Fatalf("IsCertificateRevoked failed: %v", err)
|
||||
}
|
||||
if !revoked {
|
||||
t.Error("Certificate should be revoked after RevokeCertificate")
|
||||
}
|
||||
|
||||
// Verify CRL was created
|
||||
crlPath := filepath.Join(tmpDir, "ca.crl")
|
||||
if _, err := os.Stat(crlPath); err != nil {
|
||||
t.Error("CRL file not created")
|
||||
}
|
||||
|
||||
// Verify CRL info
|
||||
crlInfo, err := ca.GetCRLInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("GetCRLInfo failed: %v", err)
|
||||
}
|
||||
if crlInfo.RevokedCount != 1 {
|
||||
t.Errorf("Expected 1 revoked cert in CRL, got %d", crlInfo.RevokedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateBundleSaveLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create CA
|
||||
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate certificate
|
||||
bundle, err := ca.GenerateServerCert(ServerCertConfig{
|
||||
CommonName: "test-server",
|
||||
KeySize: 2048,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateServerCert failed: %v", err)
|
||||
}
|
||||
|
||||
// Save to files
|
||||
certPath := filepath.Join(tmpDir, "server.crt")
|
||||
keyPath := filepath.Join(tmpDir, "server.key")
|
||||
|
||||
if err := bundle.SaveToFiles(certPath, keyPath); err != nil {
|
||||
t.Fatalf("SaveToFiles failed: %v", err)
|
||||
}
|
||||
|
||||
// Load from files
|
||||
loaded, err := LoadCertificateBundle(certPath, keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCertificateBundle failed: %v", err)
|
||||
}
|
||||
|
||||
if loaded.Certificate.Subject.CommonName != bundle.Certificate.Subject.CommonName {
|
||||
t.Error("Loaded certificate CN mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerTLSConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create CA
|
||||
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
|
||||
if err != nil {
|
||||
t.Fatalf("InitCA failed: %v", err)
|
||||
}
|
||||
if err := ca.SaveToDir(tmpDir); err != nil {
|
||||
t.Fatalf("SaveToDir failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate server certificate
|
||||
serverBundle, err := ca.GenerateServerCert(ServerCertConfig{
|
||||
CommonName: "test-server",
|
||||
DNSNames: []string{"localhost"},
|
||||
KeySize: 2048,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateServerCert failed: %v", err)
|
||||
}
|
||||
|
||||
serverCertPath := filepath.Join(tmpDir, "server.crt")
|
||||
serverKeyPath := filepath.Join(tmpDir, "server.key")
|
||||
if err := serverBundle.SaveToFiles(serverCertPath, serverKeyPath); err != nil {
|
||||
t.Fatalf("SaveToFiles failed: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS config
|
||||
caCertPath := filepath.Join(tmpDir, "ca.crt")
|
||||
tlsConfig, err := ServerTLSConfig(caCertPath, serverCertPath, serverKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ServerTLSConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if tlsConfig.ClientAuth != tls.RequireAndVerifyClientCert {
|
||||
t.Error("Expected RequireAndVerifyClientCert")
|
||||
}
|
||||
|
||||
if tlsConfig.MinVersion != tls.VersionTLS12 {
|
||||
t.Error("Expected TLS 1.2 minimum")
|
||||
}
|
||||
}
|
||||
170
backend/internal/pki/store.go
Normal file
170
backend/internal/pki/store.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store manages certificate records on disk.
|
||||
type Store struct {
|
||||
dir string
|
||||
certs map[string]CertInfo // serial -> info
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// storeData is the JSON structure for persistence.
|
||||
type storeData struct {
|
||||
Certificates []CertInfo `json:"certificates"`
|
||||
}
|
||||
|
||||
// NewStore creates a new certificate store in the given directory.
|
||||
func NewStore(dir string) *Store {
|
||||
s := &Store{
|
||||
dir: dir,
|
||||
certs: make(map[string]CertInfo),
|
||||
}
|
||||
s.load()
|
||||
return s
|
||||
}
|
||||
|
||||
// load reads the store from disk.
|
||||
func (s *Store) load() error {
|
||||
path := filepath.Join(s.dir, "certs.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var sd storeData
|
||||
if err := json.Unmarshal(data, &sd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, cert := range sd.Certificates {
|
||||
s.certs[cert.SerialNumber] = cert
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// save writes the store to disk.
|
||||
func (s *Store) save() error {
|
||||
s.mu.RLock()
|
||||
certs := make([]CertInfo, 0, len(s.certs))
|
||||
for _, cert := range s.certs {
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
sd := storeData{Certificates: certs}
|
||||
data, err := json.MarshalIndent(sd, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := filepath.Join(s.dir, "certs.json")
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
// AddCertificate adds a certificate to the store.
|
||||
func (s *Store) AddCertificate(info CertInfo) error {
|
||||
s.mu.Lock()
|
||||
s.certs[info.SerialNumber] = info
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// GetCertificate returns a certificate by serial number.
|
||||
func (s *Store) GetCertificate(serial string) (CertInfo, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
info, ok := s.certs[serial]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// ListCertificates returns all certificates in the store.
|
||||
func (s *Store) ListCertificates() []CertInfo {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
certs := make([]CertInfo, 0, len(s.certs))
|
||||
for _, cert := range s.certs {
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
return certs
|
||||
}
|
||||
|
||||
// ListByType returns certificates of a specific type.
|
||||
func (s *Store) ListByType(certType CertType) []CertInfo {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var certs []CertInfo
|
||||
for _, cert := range s.certs {
|
||||
if cert.Type == certType {
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
}
|
||||
return certs
|
||||
}
|
||||
|
||||
// MarkRevoked marks a certificate as revoked.
|
||||
func (s *Store) MarkRevoked(serial string) error {
|
||||
s.mu.Lock()
|
||||
if info, ok := s.certs[serial]; ok {
|
||||
info.Revoked = true
|
||||
s.certs[serial] = info
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// IsRevoked checks if a certificate is revoked.
|
||||
func (s *Store) IsRevoked(serial string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if info, ok := s.certs[serial]; ok {
|
||||
return info.Revoked
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RevokedSerials returns all revoked certificate serial numbers.
|
||||
func (s *Store) RevokedSerials() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var serials []string
|
||||
for serial, cert := range s.certs {
|
||||
if cert.Revoked {
|
||||
serials = append(serials, serial)
|
||||
}
|
||||
}
|
||||
return serials
|
||||
}
|
||||
|
||||
// writeFile is a helper that creates parent directories.
|
||||
func writeFile(path string, data []byte, perm os.FileMode) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
// readFile is a helper for reading files.
|
||||
func readFile(path string) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
118
backend/internal/pki/tls.go
Normal file
118
backend/internal/pki/tls.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ServerTLSConfig creates a TLS configuration for a server with mTLS.
|
||||
// This requires clients to present a valid certificate signed by the CA.
|
||||
func ServerTLSConfig(caCertPath, serverCertPath, serverKeyPath string) (*tls.Config, error) {
|
||||
// Load CA certificate
|
||||
caCert, err := os.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||
}
|
||||
|
||||
// Load server certificate
|
||||
serverCert, err := tls.LoadX509KeyPair(serverCertPath, serverKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load server certificate: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ClientTLSConfig creates a TLS configuration for a client with mTLS.
|
||||
// This presents the client certificate and verifies the server's certificate.
|
||||
func ClientTLSConfig(caCertPath, clientCertPath, clientKeyPath string) (*tls.Config, error) {
|
||||
// Load CA certificate
|
||||
caCert, err := os.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||
}
|
||||
|
||||
// Load client certificate
|
||||
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ServerTLSConfigWithCRL creates a server TLS config that checks CRL for revoked certs.
|
||||
func ServerTLSConfigWithCRL(ca *CA, serverCertPath, serverKeyPath string) (*tls.Config, error) {
|
||||
// Load server certificate
|
||||
serverCert, err := tls.LoadX509KeyPair(serverCertPath, serverKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load server certificate: %w", err)
|
||||
}
|
||||
|
||||
// Create CA pool
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(ca.cert)
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
// Check each certificate in the chain against CRL
|
||||
for _, chain := range verifiedChains {
|
||||
for _, cert := range chain {
|
||||
if cert.IsCA {
|
||||
continue // Skip CA certs
|
||||
}
|
||||
|
||||
revoked, err := ca.IsCertificateRevoked(cert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check revocation status: %w", err)
|
||||
}
|
||||
if revoked {
|
||||
return fmt.Errorf("certificate %s has been revoked", cert.SerialNumber.Text(16))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtractAgentID extracts the agent ID from a verified client certificate.
|
||||
// The agent ID is stored in the certificate's CommonName.
|
||||
func ExtractAgentID(state *tls.ConnectionState) (string, error) {
|
||||
if len(state.VerifiedChains) == 0 || len(state.VerifiedChains[0]) == 0 {
|
||||
return "", fmt.Errorf("no verified certificate chain")
|
||||
}
|
||||
|
||||
cert := state.VerifiedChains[0][0]
|
||||
agentID := cert.Subject.CommonName
|
||||
if agentID == "" {
|
||||
return "", fmt.Errorf("certificate has no CommonName")
|
||||
}
|
||||
|
||||
return agentID, nil
|
||||
}
|
||||
Reference in New Issue
Block a user