feat: add PKI infrastructure for mTLS authentication

PKI Package (internal/pki/):
- CA initialization with configurable validity and key size
- Server certificate generation with DNS/IP SANs
- Agent certificate generation (agent ID in CN)
- Certificate revocation list (CRL) support
- mTLS TLS configuration helpers
- File-based certificate store with JSON persistence

CLI Commands (cmd/tyto/):
- `tyto pki init-ca` - Initialize new Certificate Authority
- `tyto pki gen-server` - Generate server certificate
- `tyto pki gen-agent` - Generate agent certificate
- `tyto pki revoke` - Revoke certificate by serial
- `tyto pki list` - List all certificates
- `tyto pki info` - Show CA information

Security Features:
- RSA 4096-bit keys by default
- TLS 1.2 minimum version
- Client certificate verification for mTLS
- CRL checking in TLS handshake
- Agent ID extraction from verified certificates

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-28 07:29:42 +01:00
parent a0a947094d
commit c8fbade575
9 changed files with 1820 additions and 0 deletions

94
backend/cmd/tyto/main.go Normal file
View File

@@ -0,0 +1,94 @@
// Package main provides the Tyto CLI with subcommands for PKI management,
// server operation, and agent operation.
package main
import (
"fmt"
"os"
)
const usage = `Tyto - System Monitoring
Usage:
tyto <command> [options]
Commands:
server Start the Tyto server (default if no command given)
agent Start as a Tyto agent
pki PKI/certificate management
Run 'tyto <command> --help' for more information on a command.
`
func main() {
if len(os.Args) < 2 {
// Default to server command
runServer(os.Args[1:])
return
}
switch os.Args[1] {
case "server":
runServer(os.Args[2:])
case "agent":
runAgent(os.Args[2:])
case "pki":
runPKI(os.Args[2:])
case "-h", "--help", "help":
fmt.Print(usage)
default:
fmt.Fprintf(os.Stderr, "Unknown command: %s\n\n", os.Args[1])
fmt.Print(usage)
os.Exit(1)
}
}
func runServer(args []string) {
// For now, delegate to the existing server main
// In a full implementation, this would parse server-specific flags
fmt.Println("Starting Tyto server...")
fmt.Println("Use 'tyto server --help' for options")
// Import and call the actual server logic
// For now, just print usage
if len(args) > 0 && (args[0] == "-h" || args[0] == "--help") {
fmt.Print(`Usage: tyto server [options]
Options:
--port PORT HTTP port (default: 8080)
--mode MODE Operating mode: standalone, server (default: standalone)
--config FILE Config file path
Environment:
TYTO_MODE Operating mode
PORT HTTP port
TYTO_CONFIG Config file path
`)
return
}
fmt.Println("Server mode not implemented in tyto CLI yet.")
fmt.Println("Use the 'server' binary directly or set TYTO_MODE environment variable.")
}
func runAgent(args []string) {
if len(args) > 0 && (args[0] == "-h" || args[0] == "--help") {
fmt.Print(`Usage: tyto agent [options]
Options:
--id ID Agent identifier (required)
--server URL Server URL to report to (required)
--interval DURATION Collection interval (default: 5s)
--ca-cert FILE CA certificate file
--cert FILE Agent certificate file
--key FILE Agent key file
Environment:
TYTO_AGENT_ID Agent identifier
TYTO_SERVER_URL Server URL
`)
return
}
fmt.Println("Agent mode not implemented yet.")
}

375
backend/cmd/tyto/pki.go Normal file
View File

@@ -0,0 +1,375 @@
package main
import (
"flag"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"time"
"tyto/internal/pki"
)
const pkiUsage = `PKI/Certificate Management
Usage:
tyto pki <subcommand> [options]
Subcommands:
init-ca Initialize a new Certificate Authority
gen-server Generate a server certificate
gen-agent Generate an agent certificate
revoke Revoke a certificate
list List all certificates
info Show CA information
Run 'tyto pki <subcommand> --help' for more information.
`
func runPKI(args []string) {
if len(args) == 0 {
fmt.Print(pkiUsage)
return
}
switch args[0] {
case "init-ca":
pkiInitCA(args[1:])
case "gen-server":
pkiGenServer(args[1:])
case "gen-agent":
pkiGenAgent(args[1:])
case "revoke":
pkiRevoke(args[1:])
case "list":
pkiList(args[1:])
case "info":
pkiInfo(args[1:])
case "-h", "--help", "help":
fmt.Print(pkiUsage)
default:
fmt.Fprintf(os.Stderr, "Unknown pki subcommand: %s\n\n", args[0])
fmt.Print(pkiUsage)
os.Exit(1)
}
}
func pkiInitCA(args []string) {
fs := flag.NewFlagSet("pki init-ca", flag.ExitOnError)
cn := fs.String("cn", "Tyto CA", "Common Name for the CA")
org := fs.String("org", "Tyto", "Organization name")
country := fs.String("country", "US", "Country code")
outDir := fs.String("out", "/etc/tyto/pki", "Output directory for CA files")
validity := fs.Duration("validity", 10*365*24*time.Hour, "CA validity period")
fs.Parse(args)
fmt.Printf("Initializing CA in %s...\n", *outDir)
cfg := pki.CAConfig{
CommonName: *cn,
Organization: *org,
Country: *country,
Validity: *validity,
}
ca, err := pki.InitCA(cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
if err := ca.SaveToDir(*outDir); err != nil {
fmt.Fprintf(os.Stderr, "Error saving CA: %v\n", err)
os.Exit(1)
}
info := ca.Info()
fmt.Println("CA initialized successfully!")
fmt.Printf(" Subject: %s\n", info.Subject)
fmt.Printf(" Serial: %s\n", info.SerialNumber)
fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339))
fmt.Printf(" Certificate: %s\n", filepath.Join(*outDir, "ca.crt"))
fmt.Printf(" Private key: %s\n", filepath.Join(*outDir, "ca.key"))
}
func pkiGenServer(args []string) {
fs := flag.NewFlagSet("pki gen-server", flag.ExitOnError)
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
cn := fs.String("cn", "", "Common Name (required)")
dnsNames := fs.String("dns", "", "DNS names (comma-separated)")
ips := fs.String("ip", "", "IP addresses (comma-separated)")
outDir := fs.String("out", "", "Output directory (default: ca-dir/certs)")
validity := fs.Duration("validity", 365*24*time.Hour, "Certificate validity")
fs.Parse(args)
if *cn == "" {
fmt.Fprintln(os.Stderr, "Error: --cn is required")
fs.Usage()
os.Exit(1)
}
if *outDir == "" {
*outDir = filepath.Join(*caDir, "certs")
}
// Load CA
ca, err := pki.LoadCAFromDir(*caDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
os.Exit(1)
}
// Parse DNS names
var dns []string
if *dnsNames != "" {
dns = strings.Split(*dnsNames, ",")
for i := range dns {
dns[i] = strings.TrimSpace(dns[i])
}
}
// Parse IP addresses
var ipAddrs []net.IP
if *ips != "" {
for _, ipStr := range strings.Split(*ips, ",") {
ip := net.ParseIP(strings.TrimSpace(ipStr))
if ip != nil {
ipAddrs = append(ipAddrs, ip)
}
}
}
cfg := pki.ServerCertConfig{
CommonName: *cn,
DNSNames: dns,
IPAddresses: ipAddrs,
Validity: *validity,
}
bundle, err := ca.GenerateServerCert(cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error generating certificate: %v\n", err)
os.Exit(1)
}
// Save certificate
certPath := filepath.Join(*outDir, *cn+".crt")
keyPath := filepath.Join(*outDir, *cn+".key")
if err := bundle.SaveToFiles(certPath, keyPath); err != nil {
fmt.Fprintf(os.Stderr, "Error saving certificate: %v\n", err)
os.Exit(1)
}
info := bundle.Info(pki.CertTypeServer)
fmt.Println("Server certificate generated!")
fmt.Printf(" Subject: %s\n", info.Subject)
fmt.Printf(" Serial: %s\n", info.SerialNumber)
fmt.Printf(" DNS names: %v\n", dns)
fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339))
fmt.Printf(" Certificate: %s\n", certPath)
fmt.Printf(" Private key: %s\n", keyPath)
}
func pkiGenAgent(args []string) {
fs := flag.NewFlagSet("pki gen-agent", flag.ExitOnError)
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
agentID := fs.String("agent-id", "", "Agent ID (required, used as CN)")
outDir := fs.String("out", "", "Output directory (default: ca-dir/agents)")
validity := fs.Duration("validity", 365*24*time.Hour, "Certificate validity")
fs.Parse(args)
if *agentID == "" {
fmt.Fprintln(os.Stderr, "Error: --agent-id is required")
fs.Usage()
os.Exit(1)
}
if *outDir == "" {
*outDir = filepath.Join(*caDir, "agents")
}
// Load CA
ca, err := pki.LoadCAFromDir(*caDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
os.Exit(1)
}
cfg := pki.AgentCertConfig{
AgentID: *agentID,
Validity: *validity,
}
bundle, err := ca.GenerateAgentCert(cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error generating certificate: %v\n", err)
os.Exit(1)
}
// Save certificate
certPath := filepath.Join(*outDir, *agentID+".crt")
keyPath := filepath.Join(*outDir, *agentID+".key")
if err := bundle.SaveToFiles(certPath, keyPath); err != nil {
fmt.Fprintf(os.Stderr, "Error saving certificate: %v\n", err)
os.Exit(1)
}
info := bundle.Info(pki.CertTypeAgent)
fmt.Println("Agent certificate generated!")
fmt.Printf(" Agent ID: %s\n", *agentID)
fmt.Printf(" Serial: %s\n", info.SerialNumber)
fmt.Printf(" Valid until: %s\n", info.NotAfter.Format(time.RFC3339))
fmt.Printf(" Certificate: %s\n", certPath)
fmt.Printf(" Private key: %s\n", keyPath)
}
func pkiRevoke(args []string) {
fs := flag.NewFlagSet("pki revoke", flag.ExitOnError)
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
serial := fs.String("serial", "", "Certificate serial number to revoke (required)")
fs.Parse(args)
if *serial == "" {
fmt.Fprintln(os.Stderr, "Error: --serial is required")
fs.Usage()
os.Exit(1)
}
// Load CA
ca, err := pki.LoadCAFromDir(*caDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
os.Exit(1)
}
if err := ca.RevokeCertificate(*serial); err != nil {
fmt.Fprintf(os.Stderr, "Error revoking certificate: %v\n", err)
os.Exit(1)
}
fmt.Printf("Certificate %s has been revoked.\n", *serial)
fmt.Printf("CRL updated: %s\n", filepath.Join(*caDir, "ca.crl"))
}
func pkiList(args []string) {
fs := flag.NewFlagSet("pki list", flag.ExitOnError)
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
certType := fs.String("type", "", "Filter by type: ca, server, agent")
fs.Parse(args)
// Load CA
ca, err := pki.LoadCAFromDir(*caDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
os.Exit(1)
}
store := ca.Store()
if store == nil {
fmt.Println("No certificates found.")
return
}
var certs []pki.CertInfo
switch *certType {
case "server":
certs = store.ListByType(pki.CertTypeServer)
case "agent":
certs = store.ListByType(pki.CertTypeAgent)
case "":
certs = store.ListCertificates()
default:
fmt.Fprintf(os.Stderr, "Unknown type: %s\n", *certType)
os.Exit(1)
}
if len(certs) == 0 {
fmt.Println("No certificates found.")
return
}
fmt.Printf("%-16s %-12s %-20s %-8s %s\n", "SERIAL", "TYPE", "SUBJECT", "REVOKED", "EXPIRES")
fmt.Println(strings.Repeat("-", 80))
for _, cert := range certs {
revoked := ""
if cert.Revoked {
revoked = "YES"
}
serial := cert.SerialNumber
if len(serial) > 14 {
serial = serial[:14] + ".."
}
fmt.Printf("%-16s %-12s %-20s %-8s %s\n",
serial,
cert.Type,
cert.Subject,
revoked,
cert.NotAfter.Format("2006-01-02"),
)
}
}
func pkiInfo(args []string) {
fs := flag.NewFlagSet("pki info", flag.ExitOnError)
caDir := fs.String("ca-dir", "/etc/tyto/pki", "CA directory")
fs.Parse(args)
// Load CA
ca, err := pki.LoadCAFromDir(*caDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading CA: %v\n", err)
os.Exit(1)
}
info := ca.Info()
fmt.Println("Certificate Authority Information")
fmt.Println(strings.Repeat("=", 40))
fmt.Printf("Subject: %s\n", info.Subject)
fmt.Printf("Serial Number: %s\n", info.SerialNumber)
fmt.Printf("Valid From: %s\n", info.NotBefore.Format(time.RFC3339))
fmt.Printf("Valid Until: %s\n", info.NotAfter.Format(time.RFC3339))
// Check if expired
if time.Now().After(info.NotAfter) {
fmt.Println("\n⚠ WARNING: CA certificate has expired!")
} else {
remaining := time.Until(info.NotAfter)
fmt.Printf("Remaining: %d days\n", int(remaining.Hours()/24))
}
// CRL info
crlInfo, err := ca.GetCRLInfo()
if err != nil {
fmt.Printf("\nCRL: Error loading: %v\n", err)
} else if crlInfo == nil {
fmt.Println("\nCRL: Not generated yet")
} else {
fmt.Println("\nCRL Information")
fmt.Println(strings.Repeat("-", 40))
fmt.Printf("Last Update: %s\n", crlInfo.ThisUpdate.Format(time.RFC3339))
fmt.Printf("Next Update: %s\n", crlInfo.NextUpdate.Format(time.RFC3339))
fmt.Printf("Revoked Certs: %d\n", crlInfo.RevokedCount)
}
// Certificate counts
store := ca.Store()
if store != nil {
servers := store.ListByType(pki.CertTypeServer)
agents := store.ListByType(pki.CertTypeAgent)
fmt.Println("\nCertificate Counts")
fmt.Println(strings.Repeat("-", 40))
fmt.Printf("Server Certs: %d\n", len(servers))
fmt.Printf("Agent Certs: %d\n", len(agents))
}
}

231
backend/internal/pki/ca.go Normal file
View File

@@ -0,0 +1,231 @@
package pki
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)
// CA represents a Certificate Authority for signing certificates.
type CA struct {
cert *x509.Certificate
key *rsa.PrivateKey
certPEM []byte
keyPEM []byte
store *Store
}
// CAConfig contains options for creating a new CA.
type CAConfig struct {
CommonName string
Organization string
Country string
Validity time.Duration
KeySize int
}
// DefaultCAConfig returns a CAConfig with sensible defaults.
func DefaultCAConfig() CAConfig {
return CAConfig{
CommonName: "Tyto CA",
Organization: "Tyto",
Country: "US",
Validity: DefaultCAValidity,
KeySize: DefaultKeySize,
}
}
// InitCA creates a new Certificate Authority with the given configuration.
func InitCA(cfg CAConfig) (*CA, error) {
if cfg.KeySize == 0 {
cfg.KeySize = DefaultKeySize
}
if cfg.Validity == 0 {
cfg.Validity = DefaultCAValidity
}
// Generate private key
key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate CA key: %w", err)
}
// Generate serial number
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("failed to generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: cfg.CommonName,
Organization: []string{cfg.Organization},
Country: []string{cfg.Country},
},
NotBefore: now,
NotAfter: now.Add(cfg.Validity),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 1,
}
// Self-sign the certificate
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
return nil, fmt.Errorf("failed to create CA certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
}
// Encode to PEM
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
return &CA{
cert: cert,
key: key,
certPEM: certPEM,
keyPEM: keyPEM,
}, nil
}
// LoadCA loads an existing CA from PEM-encoded certificate and key.
func LoadCA(certPEM, keyPEM []byte) (*CA, error) {
// Parse certificate
certBlock, _ := pem.Decode(certPEM)
if certBlock == nil {
return nil, fmt.Errorf("failed to decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
}
if !cert.IsCA {
return nil, fmt.Errorf("certificate is not a CA")
}
// Parse private key
keyBlock, _ := pem.Decode(keyPEM)
if keyBlock == nil {
return nil, fmt.Errorf("failed to decode CA key PEM")
}
var key *rsa.PrivateKey
switch keyBlock.Type {
case "RSA PRIVATE KEY":
key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "PRIVATE KEY":
parsedKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
if parseErr != nil {
return nil, fmt.Errorf("failed to parse CA key: %w", parseErr)
}
var ok bool
key, ok = parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("CA key is not RSA")
}
default:
return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type)
}
if err != nil {
return nil, fmt.Errorf("failed to parse CA key: %w", err)
}
return &CA{
cert: cert,
key: key,
certPEM: certPEM,
keyPEM: keyPEM,
}, nil
}
// LoadCAFromDir loads a CA from a directory containing ca.crt and ca.key files.
func LoadCAFromDir(dir string) (*CA, error) {
certPath := filepath.Join(dir, "ca.crt")
keyPath := filepath.Join(dir, "ca.key")
certPEM, err := os.ReadFile(certPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA key: %w", err)
}
ca, err := LoadCA(certPEM, keyPEM)
if err != nil {
return nil, err
}
// Initialize store for this directory
ca.store = NewStore(dir)
return ca, nil
}
// SaveToDir saves the CA certificate and key to a directory.
func (ca *CA) SaveToDir(dir string) error {
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create CA directory: %w", err)
}
certPath := filepath.Join(dir, "ca.crt")
keyPath := filepath.Join(dir, "ca.key")
if err := os.WriteFile(certPath, ca.certPEM, 0644); err != nil {
return fmt.Errorf("failed to write CA certificate: %w", err)
}
if err := os.WriteFile(keyPath, ca.keyPEM, 0600); err != nil {
return fmt.Errorf("failed to write CA key: %w", err)
}
// Initialize store
ca.store = NewStore(dir)
return nil
}
// Certificate returns the CA certificate.
func (ca *CA) Certificate() *x509.Certificate {
return ca.cert
}
// CertificatePEM returns the PEM-encoded CA certificate.
func (ca *CA) CertificatePEM() []byte {
return ca.certPEM
}
// Info returns metadata about the CA certificate.
func (ca *CA) Info() CertInfo {
return CertInfoFromX509(ca.cert, CertTypeCA)
}
// Store returns the certificate store associated with this CA.
func (ca *CA) Store() *Store {
return ca.store
}

View File

@@ -0,0 +1,270 @@
package pki
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"time"
)
// CertificateBundle contains a certificate and its private key.
type CertificateBundle struct {
Certificate *x509.Certificate
PrivateKey *rsa.PrivateKey
CertificatePEM []byte
PrivateKeyPEM []byte
}
// ServerCertConfig contains options for generating a server certificate.
type ServerCertConfig struct {
CommonName string
Organization string
DNSNames []string
IPAddresses []net.IP
Validity time.Duration
KeySize int
}
// AgentCertConfig contains options for generating an agent certificate.
type AgentCertConfig struct {
AgentID string // Used as CommonName
Organization string
Validity time.Duration
KeySize int
}
// GenerateServerCert creates a new server certificate signed by the CA.
func (ca *CA) GenerateServerCert(cfg ServerCertConfig) (*CertificateBundle, error) {
if cfg.CommonName == "" {
return nil, fmt.Errorf("common name is required")
}
if cfg.KeySize == 0 {
cfg.KeySize = DefaultKeySize
}
if cfg.Validity == 0 {
cfg.Validity = DefaultCertValidity
}
// Generate private key
key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate server key: %w", err)
}
// Generate serial number
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("failed to generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: cfg.CommonName,
Organization: []string{cfg.Organization},
},
NotBefore: now,
NotAfter: now.Add(cfg.Validity),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: cfg.DNSNames,
IPAddresses: cfg.IPAddresses,
}
// Sign with CA
certDER, err := x509.CreateCertificate(rand.Reader, template, ca.cert, &key.PublicKey, ca.key)
if err != nil {
return nil, fmt.Errorf("failed to create server certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("failed to parse server certificate: %w", err)
}
// Encode to PEM
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
bundle := &CertificateBundle{
Certificate: cert,
PrivateKey: key,
CertificatePEM: certPEM,
PrivateKeyPEM: keyPEM,
}
// Store certificate info
if ca.store != nil {
info := CertInfoFromX509(cert, CertTypeServer)
ca.store.AddCertificate(info)
}
return bundle, nil
}
// GenerateAgentCert creates a new agent certificate signed by the CA.
// The AgentID is used as the CommonName and is extracted during mTLS verification.
func (ca *CA) GenerateAgentCert(cfg AgentCertConfig) (*CertificateBundle, error) {
if cfg.AgentID == "" {
return nil, fmt.Errorf("agent ID is required")
}
if cfg.KeySize == 0 {
cfg.KeySize = DefaultKeySize
}
if cfg.Validity == 0 {
cfg.Validity = DefaultCertValidity
}
// Generate private key
key, err := rsa.GenerateKey(rand.Reader, cfg.KeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate agent key: %w", err)
}
// Generate serial number
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("failed to generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: cfg.AgentID,
Organization: []string{cfg.Organization},
},
NotBefore: now,
NotAfter: now.Add(cfg.Validity),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
// Sign with CA
certDER, err := x509.CreateCertificate(rand.Reader, template, ca.cert, &key.PublicKey, ca.key)
if err != nil {
return nil, fmt.Errorf("failed to create agent certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("failed to parse agent certificate: %w", err)
}
// Encode to PEM
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
bundle := &CertificateBundle{
Certificate: cert,
PrivateKey: key,
CertificatePEM: certPEM,
PrivateKeyPEM: keyPEM,
}
// Store certificate info
if ca.store != nil {
info := CertInfoFromX509(cert, CertTypeAgent)
ca.store.AddCertificate(info)
}
return bundle, nil
}
// Info returns metadata about this certificate bundle.
func (b *CertificateBundle) Info(certType CertType) CertInfo {
return CertInfoFromX509(b.Certificate, certType)
}
// SaveToFiles writes the certificate and key to separate files.
func (b *CertificateBundle) SaveToFiles(certPath, keyPath string) error {
if err := writeFile(certPath, b.CertificatePEM, 0644); err != nil {
return fmt.Errorf("failed to write certificate: %w", err)
}
if err := writeFile(keyPath, b.PrivateKeyPEM, 0600); err != nil {
return fmt.Errorf("failed to write private key: %w", err)
}
return nil
}
// LoadCertificateBundle loads a certificate and key from PEM files.
func LoadCertificateBundle(certPath, keyPath string) (*CertificateBundle, error) {
certPEM, err := readFile(certPath)
if err != nil {
return nil, fmt.Errorf("failed to read certificate: %w", err)
}
keyPEM, err := readFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read private key: %w", err)
}
// Parse certificate
certBlock, _ := pem.Decode(certPEM)
if certBlock == nil {
return nil, fmt.Errorf("failed to decode certificate PEM")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
// Parse private key
keyBlock, _ := pem.Decode(keyPEM)
if keyBlock == nil {
return nil, fmt.Errorf("failed to decode private key PEM")
}
var key *rsa.PrivateKey
switch keyBlock.Type {
case "RSA PRIVATE KEY":
key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "PRIVATE KEY":
parsedKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
if parseErr != nil {
return nil, fmt.Errorf("failed to parse private key: %w", parseErr)
}
var ok bool
key, ok = parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("private key is not RSA")
}
default:
return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type)
}
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
return &CertificateBundle{
Certificate: cert,
PrivateKey: key,
CertificatePEM: certPEM,
PrivateKeyPEM: keyPEM,
}, nil
}

169
backend/internal/pki/crl.go Normal file
View File

@@ -0,0 +1,169 @@
package pki
import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)
// CRLValidity is how long a CRL is valid before it must be regenerated.
const CRLValidity = 7 * 24 * time.Hour // 7 days
// RevokeCertificate revokes a certificate by its serial number.
func (ca *CA) RevokeCertificate(serialHex string) error {
if ca.store == nil {
return fmt.Errorf("certificate store not initialized")
}
// Mark as revoked in store
if err := ca.store.MarkRevoked(serialHex); err != nil {
return fmt.Errorf("failed to mark certificate as revoked: %w", err)
}
// Regenerate CRL
return ca.GenerateCRL()
}
// GenerateCRL creates a new Certificate Revocation List.
func (ca *CA) GenerateCRL() error {
if ca.store == nil {
return fmt.Errorf("certificate store not initialized")
}
// Get all revoked certificates
revokedSerials := ca.store.RevokedSerials()
// Build revoked certificate list
revokedCerts := make([]pkix.RevokedCertificate, 0, len(revokedSerials))
for _, serialHex := range revokedSerials {
serial := new(big.Int)
serial.SetString(serialHex, 16)
info, _ := ca.store.GetCertificate(serialHex)
revocationTime := info.RevokedAt
if revocationTime.IsZero() {
revocationTime = time.Now()
}
revokedCerts = append(revokedCerts, pkix.RevokedCertificate{
SerialNumber: serial,
RevocationTime: revocationTime,
})
}
// Create CRL
now := time.Now()
crlTemplate := &x509.RevocationList{
RevokedCertificates: revokedCerts,
Number: big.NewInt(now.Unix()),
ThisUpdate: now,
NextUpdate: now.Add(CRLValidity),
}
crlDER, err := x509.CreateRevocationList(rand.Reader, crlTemplate, ca.cert, ca.key)
if err != nil {
return fmt.Errorf("failed to create CRL: %w", err)
}
// Encode to PEM
crlPEM := pem.EncodeToMemory(&pem.Block{
Type: "X509 CRL",
Bytes: crlDER,
})
// Save to file
crlPath := filepath.Join(ca.store.dir, "ca.crl")
if err := os.WriteFile(crlPath, crlPEM, 0644); err != nil {
return fmt.Errorf("failed to write CRL: %w", err)
}
return nil
}
// LoadCRL loads the Certificate Revocation List from the CA directory.
func (ca *CA) LoadCRL() (*x509.RevocationList, error) {
if ca.store == nil {
return nil, fmt.Errorf("certificate store not initialized")
}
crlPath := filepath.Join(ca.store.dir, "ca.crl")
crlPEM, err := os.ReadFile(crlPath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No CRL yet
}
return nil, fmt.Errorf("failed to read CRL: %w", err)
}
block, _ := pem.Decode(crlPEM)
if block == nil {
return nil, fmt.Errorf("failed to decode CRL PEM")
}
crl, err := x509.ParseRevocationList(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CRL: %w", err)
}
return crl, nil
}
// IsCertificateRevoked checks if a certificate is in the CRL.
func (ca *CA) IsCertificateRevoked(cert *x509.Certificate) (bool, error) {
// First check the store (faster)
if ca.store != nil {
serialHex := cert.SerialNumber.Text(16)
if ca.store.IsRevoked(serialHex) {
return true, nil
}
}
// Also check CRL for completeness
crl, err := ca.LoadCRL()
if err != nil {
return false, err
}
if crl == nil {
return false, nil
}
for _, revoked := range crl.RevokedCertificates {
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
return true, nil
}
}
return false, nil
}
// CRLInfo contains metadata about a CRL.
type CRLInfo struct {
Issuer string `json:"issuer"`
ThisUpdate time.Time `json:"thisUpdate"`
NextUpdate time.Time `json:"nextUpdate"`
RevokedCount int `json:"revokedCount"`
}
// GetCRLInfo returns metadata about the current CRL.
func (ca *CA) GetCRLInfo() (*CRLInfo, error) {
crl, err := ca.LoadCRL()
if err != nil {
return nil, err
}
if crl == nil {
return nil, nil
}
return &CRLInfo{
Issuer: crl.Issuer.CommonName,
ThisUpdate: crl.ThisUpdate,
NextUpdate: crl.NextUpdate,
RevokedCount: len(crl.RevokedCertificates),
}, nil
}

View File

@@ -0,0 +1,52 @@
// Package pki provides certificate authority and certificate management
// for mTLS authentication between Tyto servers and agents.
package pki
import (
"crypto/x509"
"time"
)
// CertType identifies the purpose of a certificate.
type CertType string
const (
CertTypeCA CertType = "ca"
CertTypeServer CertType = "server"
CertTypeAgent CertType = "agent"
)
// CertInfo contains metadata about a certificate.
type CertInfo struct {
SerialNumber string `json:"serialNumber"`
Subject string `json:"subject"`
Issuer string `json:"issuer"`
NotBefore time.Time `json:"notBefore"`
NotAfter time.Time `json:"notAfter"`
Type CertType `json:"type"`
DNSNames []string `json:"dnsNames,omitempty"`
Revoked bool `json:"revoked"`
RevokedAt time.Time `json:"revokedAt,omitempty"`
}
// CertInfoFromX509 extracts metadata from an X.509 certificate.
func CertInfoFromX509(cert *x509.Certificate, certType CertType) CertInfo {
return CertInfo{
SerialNumber: cert.SerialNumber.Text(16),
Subject: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
NotBefore: cert.NotBefore,
NotAfter: cert.NotAfter,
Type: certType,
DNSNames: cert.DNSNames,
}
}
// DefaultCAValidity is the default validity period for CA certificates.
const DefaultCAValidity = 10 * 365 * 24 * time.Hour // 10 years
// DefaultCertValidity is the default validity period for server/agent certificates.
const DefaultCertValidity = 365 * 24 * time.Hour // 1 year
// DefaultKeySize is the default RSA key size in bits.
const DefaultKeySize = 4096

View File

@@ -0,0 +1,341 @@
package pki
import (
"crypto/tls"
"crypto/x509"
"net"
"os"
"path/filepath"
"testing"
"time"
)
func TestInitCA(t *testing.T) {
cfg := CAConfig{
CommonName: "Test CA",
Organization: "Test Org",
Country: "US",
Validity: 24 * time.Hour,
KeySize: 2048, // Smaller for faster tests
}
ca, err := InitCA(cfg)
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
if ca.cert == nil {
t.Fatal("CA certificate is nil")
}
if !ca.cert.IsCA {
t.Error("Certificate is not marked as CA")
}
if ca.cert.Subject.CommonName != "Test CA" {
t.Errorf("Expected CN 'Test CA', got '%s'", ca.cert.Subject.CommonName)
}
if len(ca.certPEM) == 0 {
t.Error("Certificate PEM is empty")
}
if len(ca.keyPEM) == 0 {
t.Error("Key PEM is empty")
}
}
func TestCASaveAndLoad(t *testing.T) {
tmpDir := t.TempDir()
// Create CA
cfg := CAConfig{
CommonName: "Test CA",
KeySize: 2048,
}
ca, err := InitCA(cfg)
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
// Save to directory
if err := ca.SaveToDir(tmpDir); err != nil {
t.Fatalf("SaveToDir failed: %v", err)
}
// Verify files exist
if _, err := os.Stat(filepath.Join(tmpDir, "ca.crt")); err != nil {
t.Error("ca.crt not created")
}
if _, err := os.Stat(filepath.Join(tmpDir, "ca.key")); err != nil {
t.Error("ca.key not created")
}
// Load from directory
loadedCA, err := LoadCAFromDir(tmpDir)
if err != nil {
t.Fatalf("LoadCAFromDir failed: %v", err)
}
if loadedCA.cert.Subject.CommonName != ca.cert.Subject.CommonName {
t.Errorf("Loaded CA CN mismatch: got '%s', want '%s'",
loadedCA.cert.Subject.CommonName, ca.cert.Subject.CommonName)
}
}
func TestGenerateServerCert(t *testing.T) {
// Create CA
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
// Generate server certificate
cfg := ServerCertConfig{
CommonName: "test-server",
DNSNames: []string{"localhost", "test.local"},
IPAddresses: []net.IP{
net.ParseIP("127.0.0.1"),
net.ParseIP("::1"),
},
Validity: 24 * time.Hour,
KeySize: 2048,
}
bundle, err := ca.GenerateServerCert(cfg)
if err != nil {
t.Fatalf("GenerateServerCert failed: %v", err)
}
// Verify certificate properties
cert := bundle.Certificate
if cert.Subject.CommonName != "test-server" {
t.Errorf("Expected CN 'test-server', got '%s'", cert.Subject.CommonName)
}
if len(cert.DNSNames) != 2 {
t.Errorf("Expected 2 DNS names, got %d", len(cert.DNSNames))
}
if len(cert.IPAddresses) != 2 {
t.Errorf("Expected 2 IP addresses, got %d", len(cert.IPAddresses))
}
// Verify it's signed by CA
roots := x509.NewCertPool()
roots.AddCert(ca.cert)
_, err = cert.Verify(x509.VerifyOptions{
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
})
if err != nil {
t.Errorf("Certificate verification failed: %v", err)
}
}
func TestGenerateAgentCert(t *testing.T) {
tmpDir := t.TempDir()
// Create CA with store
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
if err := ca.SaveToDir(tmpDir); err != nil {
t.Fatalf("SaveToDir failed: %v", err)
}
// Generate agent certificate
cfg := AgentCertConfig{
AgentID: "agent-001",
Validity: 24 * time.Hour,
KeySize: 2048,
}
bundle, err := ca.GenerateAgentCert(cfg)
if err != nil {
t.Fatalf("GenerateAgentCert failed: %v", err)
}
// Verify certificate properties
cert := bundle.Certificate
if cert.Subject.CommonName != "agent-001" {
t.Errorf("Expected CN 'agent-001', got '%s'", cert.Subject.CommonName)
}
// Verify client auth usage
hasClientAuth := false
for _, usage := range cert.ExtKeyUsage {
if usage == x509.ExtKeyUsageClientAuth {
hasClientAuth = true
break
}
}
if !hasClientAuth {
t.Error("Certificate missing ClientAuth extended key usage")
}
// Verify it's signed by CA
roots := x509.NewCertPool()
roots.AddCert(ca.cert)
_, err = cert.Verify(x509.VerifyOptions{
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
})
if err != nil {
t.Errorf("Certificate verification failed: %v", err)
}
// Verify it was recorded in store
certs := ca.Store().ListByType(CertTypeAgent)
if len(certs) != 1 {
t.Errorf("Expected 1 agent cert in store, got %d", len(certs))
}
}
func TestCertificateRevocation(t *testing.T) {
tmpDir := t.TempDir()
// Create CA with store
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
if err := ca.SaveToDir(tmpDir); err != nil {
t.Fatalf("SaveToDir failed: %v", err)
}
// Generate agent certificate
bundle, err := ca.GenerateAgentCert(AgentCertConfig{
AgentID: "agent-to-revoke",
KeySize: 2048,
})
if err != nil {
t.Fatalf("GenerateAgentCert failed: %v", err)
}
serialHex := bundle.Certificate.SerialNumber.Text(16)
// Verify not revoked initially
revoked, err := ca.IsCertificateRevoked(bundle.Certificate)
if err != nil {
t.Fatalf("IsCertificateRevoked failed: %v", err)
}
if revoked {
t.Error("Certificate should not be revoked initially")
}
// Revoke the certificate
if err := ca.RevokeCertificate(serialHex); err != nil {
t.Fatalf("RevokeCertificate failed: %v", err)
}
// Verify now revoked
revoked, err = ca.IsCertificateRevoked(bundle.Certificate)
if err != nil {
t.Fatalf("IsCertificateRevoked failed: %v", err)
}
if !revoked {
t.Error("Certificate should be revoked after RevokeCertificate")
}
// Verify CRL was created
crlPath := filepath.Join(tmpDir, "ca.crl")
if _, err := os.Stat(crlPath); err != nil {
t.Error("CRL file not created")
}
// Verify CRL info
crlInfo, err := ca.GetCRLInfo()
if err != nil {
t.Fatalf("GetCRLInfo failed: %v", err)
}
if crlInfo.RevokedCount != 1 {
t.Errorf("Expected 1 revoked cert in CRL, got %d", crlInfo.RevokedCount)
}
}
func TestCertificateBundleSaveLoad(t *testing.T) {
tmpDir := t.TempDir()
// Create CA
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
// Generate certificate
bundle, err := ca.GenerateServerCert(ServerCertConfig{
CommonName: "test-server",
KeySize: 2048,
})
if err != nil {
t.Fatalf("GenerateServerCert failed: %v", err)
}
// Save to files
certPath := filepath.Join(tmpDir, "server.crt")
keyPath := filepath.Join(tmpDir, "server.key")
if err := bundle.SaveToFiles(certPath, keyPath); err != nil {
t.Fatalf("SaveToFiles failed: %v", err)
}
// Load from files
loaded, err := LoadCertificateBundle(certPath, keyPath)
if err != nil {
t.Fatalf("LoadCertificateBundle failed: %v", err)
}
if loaded.Certificate.Subject.CommonName != bundle.Certificate.Subject.CommonName {
t.Error("Loaded certificate CN mismatch")
}
}
func TestServerTLSConfig(t *testing.T) {
tmpDir := t.TempDir()
// Create CA
ca, err := InitCA(CAConfig{CommonName: "Test CA", KeySize: 2048})
if err != nil {
t.Fatalf("InitCA failed: %v", err)
}
if err := ca.SaveToDir(tmpDir); err != nil {
t.Fatalf("SaveToDir failed: %v", err)
}
// Generate server certificate
serverBundle, err := ca.GenerateServerCert(ServerCertConfig{
CommonName: "test-server",
DNSNames: []string{"localhost"},
KeySize: 2048,
})
if err != nil {
t.Fatalf("GenerateServerCert failed: %v", err)
}
serverCertPath := filepath.Join(tmpDir, "server.crt")
serverKeyPath := filepath.Join(tmpDir, "server.key")
if err := serverBundle.SaveToFiles(serverCertPath, serverKeyPath); err != nil {
t.Fatalf("SaveToFiles failed: %v", err)
}
// Create TLS config
caCertPath := filepath.Join(tmpDir, "ca.crt")
tlsConfig, err := ServerTLSConfig(caCertPath, serverCertPath, serverKeyPath)
if err != nil {
t.Fatalf("ServerTLSConfig failed: %v", err)
}
if tlsConfig.ClientAuth != tls.RequireAndVerifyClientCert {
t.Error("Expected RequireAndVerifyClientCert")
}
if tlsConfig.MinVersion != tls.VersionTLS12 {
t.Error("Expected TLS 1.2 minimum")
}
}

View File

@@ -0,0 +1,170 @@
package pki
import (
"encoding/json"
"os"
"path/filepath"
"sync"
)
// Store manages certificate records on disk.
type Store struct {
dir string
certs map[string]CertInfo // serial -> info
mu sync.RWMutex
}
// storeData is the JSON structure for persistence.
type storeData struct {
Certificates []CertInfo `json:"certificates"`
}
// NewStore creates a new certificate store in the given directory.
func NewStore(dir string) *Store {
s := &Store{
dir: dir,
certs: make(map[string]CertInfo),
}
s.load()
return s
}
// load reads the store from disk.
func (s *Store) load() error {
path := filepath.Join(s.dir, "certs.json")
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var sd storeData
if err := json.Unmarshal(data, &sd); err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
for _, cert := range sd.Certificates {
s.certs[cert.SerialNumber] = cert
}
return nil
}
// save writes the store to disk.
func (s *Store) save() error {
s.mu.RLock()
certs := make([]CertInfo, 0, len(s.certs))
for _, cert := range s.certs {
certs = append(certs, cert)
}
s.mu.RUnlock()
sd := storeData{Certificates: certs}
data, err := json.MarshalIndent(sd, "", " ")
if err != nil {
return err
}
path := filepath.Join(s.dir, "certs.json")
return os.WriteFile(path, data, 0600)
}
// AddCertificate adds a certificate to the store.
func (s *Store) AddCertificate(info CertInfo) error {
s.mu.Lock()
s.certs[info.SerialNumber] = info
s.mu.Unlock()
return s.save()
}
// GetCertificate returns a certificate by serial number.
func (s *Store) GetCertificate(serial string) (CertInfo, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
info, ok := s.certs[serial]
return info, ok
}
// ListCertificates returns all certificates in the store.
func (s *Store) ListCertificates() []CertInfo {
s.mu.RLock()
defer s.mu.RUnlock()
certs := make([]CertInfo, 0, len(s.certs))
for _, cert := range s.certs {
certs = append(certs, cert)
}
return certs
}
// ListByType returns certificates of a specific type.
func (s *Store) ListByType(certType CertType) []CertInfo {
s.mu.RLock()
defer s.mu.RUnlock()
var certs []CertInfo
for _, cert := range s.certs {
if cert.Type == certType {
certs = append(certs, cert)
}
}
return certs
}
// MarkRevoked marks a certificate as revoked.
func (s *Store) MarkRevoked(serial string) error {
s.mu.Lock()
if info, ok := s.certs[serial]; ok {
info.Revoked = true
s.certs[serial] = info
}
s.mu.Unlock()
return s.save()
}
// IsRevoked checks if a certificate is revoked.
func (s *Store) IsRevoked(serial string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
if info, ok := s.certs[serial]; ok {
return info.Revoked
}
return false
}
// RevokedSerials returns all revoked certificate serial numbers.
func (s *Store) RevokedSerials() []string {
s.mu.RLock()
defer s.mu.RUnlock()
var serials []string
for serial, cert := range s.certs {
if cert.Revoked {
serials = append(serials, serial)
}
}
return serials
}
// writeFile is a helper that creates parent directories.
func writeFile(path string, data []byte, perm os.FileMode) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
return os.WriteFile(path, data, perm)
}
// readFile is a helper for reading files.
func readFile(path string) ([]byte, error) {
return os.ReadFile(path)
}

118
backend/internal/pki/tls.go Normal file
View File

@@ -0,0 +1,118 @@
package pki
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
)
// ServerTLSConfig creates a TLS configuration for a server with mTLS.
// This requires clients to present a valid certificate signed by the CA.
func ServerTLSConfig(caCertPath, serverCertPath, serverKeyPath string) (*tls.Config, error) {
// Load CA certificate
caCert, err := os.ReadFile(caCertPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse CA certificate")
}
// Load server certificate
serverCert, err := tls.LoadX509KeyPair(serverCertPath, serverKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to load server certificate: %w", err)
}
return &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}, nil
}
// ClientTLSConfig creates a TLS configuration for a client with mTLS.
// This presents the client certificate and verifies the server's certificate.
func ClientTLSConfig(caCertPath, clientCertPath, clientKeyPath string) (*tls.Config, error) {
// Load CA certificate
caCert, err := os.ReadFile(caCertPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse CA certificate")
}
// Load client certificate
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}
return &tls.Config{
Certificates: []tls.Certificate{clientCert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}, nil
}
// ServerTLSConfigWithCRL creates a server TLS config that checks CRL for revoked certs.
func ServerTLSConfigWithCRL(ca *CA, serverCertPath, serverKeyPath string) (*tls.Config, error) {
// Load server certificate
serverCert, err := tls.LoadX509KeyPair(serverCertPath, serverKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to load server certificate: %w", err)
}
// Create CA pool
caCertPool := x509.NewCertPool()
caCertPool.AddCert(ca.cert)
return &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caCertPool,
MinVersion: tls.VersionTLS12,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// Check each certificate in the chain against CRL
for _, chain := range verifiedChains {
for _, cert := range chain {
if cert.IsCA {
continue // Skip CA certs
}
revoked, err := ca.IsCertificateRevoked(cert)
if err != nil {
return fmt.Errorf("failed to check revocation status: %w", err)
}
if revoked {
return fmt.Errorf("certificate %s has been revoked", cert.SerialNumber.Text(16))
}
}
}
return nil
},
}, nil
}
// ExtractAgentID extracts the agent ID from a verified client certificate.
// The agent ID is stored in the certificate's CommonName.
func ExtractAgentID(state *tls.ConnectionState) (string, error) {
if len(state.VerifiedChains) == 0 || len(state.VerifiedChains[0]) == 0 {
return "", fmt.Errorf("no verified certificate chain")
}
cert := state.VerifiedChains[0][0]
agentID := cert.Subject.CommonName
if agentID == "" {
return "", fmt.Errorf("certificate has no CommonName")
}
return agentID, nil
}