-
Notifications
You must be signed in to change notification settings - Fork 2
/
cert_management.go
127 lines (108 loc) · 3.01 KB
/
cert_management.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package amp
import (
"crypto/tls"
"crypto/x509"
"sync"
"time"
"go.uber.org/zap"
)
// KeypairReloader keeps a TLS certificate/key pair in memory and replaces it
// with a freshly loaded pair from certPath/keyPath when the current
// certificate is close to expiry. The current pair is handed to the TLS stack
// through the callback returned by GetCertificateFunc.
//
// Adapted from
// https://stackoverflow.com/a/40883377/265026
type KeypairReloader struct {
// logger receives the progress and error messages emitted while loading,
// checking and reloading the certificate.
logger *zap.Logger
// certMu guards cert: maybeReload replaces it under the write lock and the
// GetCertificateFunc callback reads it under the read lock.
certMu sync.RWMutex
// cert is the most recently loaded certificate/key pair.
cert *tls.Certificate
// certPath and keyPath are the file paths passed to tls.LoadX509KeyPair on
// the initial load and on every reload.
certPath string
keyPath string
}
// NewKeypairReloader loads the certificate/key pair stored at certPath and
// keyPath and returns a reloader that serves it.
//
// If the initial tls.LoadX509KeyPair fails, its error is returned unchanged
// and nothing else happens. On success a background goroutine is started that
// runs certExpChecker; should the checker ever return an error, it is
// reported through logger.Fatal together with both paths.
func NewKeypairReloader(certPath, keyPath string, logger *zap.Logger) (*KeypairReloader, error) {
	logger.Info("NewKeypairReloader loading",
		zap.String("certPath", certPath),
		zap.String("keyPath", keyPath))

	pair, loadErr := tls.LoadX509KeyPair(certPath, keyPath)
	if loadErr != nil {
		return nil, loadErr
	}

	reloader := &KeypairReloader{
		logger:   logger,
		cert:     &pair,
		certPath: certPath,
		keyPath:  keyPath,
	}

	// Watch the expiry date in the background; certExpChecker only returns
	// when it hits an error.
	go func() {
		if checkErr := reloader.certExpChecker(); checkErr != nil {
			reloader.logger.Fatal("certExpChecker",
				zap.String("certPath", certPath),
				zap.String("keyPath", keyPath),
				zap.Error(checkErr),
			)
		}
	}()

	// Disabled alternative: trigger a reload on SIGHUP.
	//go func() {
	//	c := make(chan os.Signal, 1)
	//	signal.Notify(c, syscall.SIGHUP)
	//	for range c {
	//		logger.Info("Reloading TLS certificate and key",
	//			zap.String("cert", certPath), zap.String("key", keyPath))
	//		if err := reloader.maybeReload(); err != nil {
	//			logger.Error("Keeping old TLS certificate because the new one could not be loaded", zap.Error(err))
	//		}
	//	}
	//}()

	return reloader, nil
}
// certExpChecker loops forever, checking how long the currently held
// certificate remains valid and attempting a reload from disk once fewer than
// 10 minutes (600 seconds) are left or the certificate has already expired.
//
// Each iteration parses the first entry of the certificate chain, logs the
// remaining lifetime, optionally reloads, and then sleeps for the duration
// chosen by nextCertCheckWait. A failed reload is logged and the old
// certificate is kept; the loop then retries after the short wait.
//
// It returns only when the held certificate cannot be parsed, in which case
// the parse error is returned.
//
// NOTE(review): kpr.cert is read here without taking certMu. Within this file
// the only writer is maybeReload, which is called from this same loop —
// confirm no other goroutine calls maybeReload before relying on that.
func (kpr *KeypairReloader) certExpChecker() error {
	for {
		// Parse the first certificate of the chain to read its expiry date.
		parsedCert, err := x509.ParseCertificate(kpr.cert.Certificate[0])
		if err != nil {
			return err
		}

		now := time.Now()
		expSecs := parsedCert.NotAfter.Unix() - now.Unix()
		kpr.logger.Info("Checking cert",
			zap.Int64("expiresInSec", expSecs))

		// If the cert expires in less than 10 minutes attempt a reload.
		if expSecs < 600 {
			kpr.logger.Warn("TLS Certificate is about to expire or has expired, attempting reload",
				zap.Int64("expiresInSec", expSecs),
			)
			if err := kpr.maybeReload(); err != nil {
				kpr.logger.Error("Keeping old TLS certificate because the new one could not be loaded",
					zap.Error(err))
			}
		}

		waitTime := nextCertCheckWait(expSecs)
		kpr.logger.Info("Setting next certificate check",
			zap.Int64("expiresInSec", expSecs),
			zap.Duration("waitTime", waitTime))
		time.Sleep(waitTime)
	}
}

// nextCertCheckWait returns how long to sleep before the next expiry check,
// given the number of seconds until the certificate expires.
//
// With 600 seconds or less remaining it returns 10 seconds so reloads are
// retried quickly. Otherwise it returns half of the remaining lifetime,
// truncated to whole seconds. The result saturates at the largest
// time.Duration instead of overflowing: far-future expiry dates (for example
// 9999-12-31) would otherwise wrap to a garbage or negative duration, and a
// negative duration makes time.Sleep return immediately.
func nextCertCheckWait(expSecs int64) time.Duration {
	const (
		retryWait = 10 * time.Second
		maxWait   = time.Duration(1<<63 - 1)
	)

	if expSecs <= 600 {
		return retryWait
	}

	halfSecs := expSecs / 2
	if halfSecs > int64(maxWait/time.Second) {
		return maxWait
	}
	return time.Duration(halfSecs) * time.Second
}
// maybeReload re-reads the certificate/key pair from kpr.certPath and
// kpr.keyPath and, if loading succeeds, makes it the pair served from now on.
//
// When tls.LoadX509KeyPair fails, its error is returned unchanged and the
// previously loaded pair stays in place.
func (kpr *KeypairReloader) maybeReload() error {
	kpr.logger.Info("Attempting certificate reload",
		zap.String("certPath", kpr.certPath),
		zap.String("keyPath", kpr.keyPath),
	)

	reloaded, loadErr := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath)
	if loadErr != nil {
		return loadErr
	}

	// The file I/O is already done; hold the write lock only for the swap.
	kpr.certMu.Lock()
	kpr.cert = &reloaded
	kpr.certMu.Unlock()

	return nil
}
// GetCertificateFunc returns a callback whose signature matches the
// GetCertificate field of tls.Config. Every call of the callback returns the
// certificate the reloader currently holds, read under the read lock, and a
// nil error; the ClientHelloInfo argument is ignored.
func (kpr *KeypairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
	return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
		// Snapshot the pointer so the lock is released before returning.
		kpr.certMu.RLock()
		current := kpr.cert
		kpr.certMu.RUnlock()

		return current, nil
	}
}