From 544b3f73217522ff28c6f4ed99470290f60749d4 Mon Sep 17 00:00:00 2001 From: Moritz Marquardt Date: Wed, 1 Dec 2021 22:49:48 +0100 Subject: [PATCH] (Ab)use CSR field to store try-again date for renewals (instead of showing a mock cert), must be tested when the first renewals are due --- certificates.go | 92 ++++++++++++++++++++----------------------------- helpers.go | 37 +++++++++++++++++++- 2 files changed, 73 insertions(+), 56 deletions(-) diff --git a/certificates.go b/certificates.go index f85109d..36f0df0 100644 --- a/certificates.go +++ b/certificates.go @@ -24,6 +24,7 @@ import ( "log" "math/big" "os" + "strconv" "strings" "sync" "time" @@ -207,21 +208,9 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error { func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) { // parse certificate from database - resBytes, err := keyDatabase.Get(sni) - if err != nil { - // key database is not working - panic(err) - } - if resBytes == nil { - return tls.Certificate{}, false - } - - resGob := bytes.NewBuffer(resBytes) - resDec := gob.NewDecoder(resGob) res := &certificate.Resource{} - err = resDec.Decode(res) - if err != nil { - panic(err) + if !PogrebGet(keyDatabase, sni, res) { + return tls.Certificate{}, false } tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) @@ -237,7 +226,15 @@ func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) { // renew certificates 7 days before they expire if !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-7 * 24 * time.Hour)) { + if res.CSR != nil && len(res.CSR) > 0 { + // CSR stores the time when the renewal shall be tried again + nextTryUnix, err := strconv.ParseInt(string(res.CSR), 10, 64) + if err == nil && time.Now().Before(time.Unix(nextTryUnix, 0)) { + return tlsCertificate, true + } + } go (func() { + res.CSR = nil // acme client doesn't like CSR to be set tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "") if err != nil { log.Printf("Couldn't renew certificate for %s: %s", sni, err) @@ -310,18 +307,21 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re } if err != nil { log.Printf("Couldn't obtain certificate for %v: %s", domains, err) - return mockCert(domains[0], err.Error()), err + if renew != nil && renew.CertURL != "" { + tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey) + if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) { + // avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at + renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6 * time.Hour).Unix(), 10)) + PogrebPut(keyDatabase, []byte(name), renew) + return tlsCertificate, nil + } + } else { + return mockCert(domains[0], err.Error()), err + } } log.Printf("Obtained certificate for %v", domains) - var resGob bytes.Buffer - resEnc := gob.NewEncoder(&resGob) - err = resEnc.Encode(res) - if err != nil { - panic(err) - } - err = keyDatabase.Put([]byte(name), resGob.Bytes()) - + PogrebPut(keyDatabase, []byte(name), res) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) if err != nil { return tls.Certificate{}, err @@ -382,20 +382,11 @@ func mockCert(domain string, msg string) tls.Certificate { IssuerCertificate: outBytes, Domain: domain, } - var resGob bytes.Buffer - resEnc := gob.NewEncoder(&resGob) - err = resEnc.Encode(res) - if err != nil { - panic(err) - } databaseName := domain if domain == "*" + string(MainDomainSuffix) || domain == string(MainDomainSuffix[1:]) { databaseName = string(MainDomainSuffix) } - err = keyDatabase.Put([]byte(databaseName), resGob.Bytes()) - if err != nil { - panic(err) - } + PogrebPut(keyDatabase, []byte(databaseName), res) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) if err != nil { @@ -585,30 +576,21 @@ func setupCertificates() { } // update main cert - resBytes, err = keyDatabase.Get(MainDomainSuffix) - if err != nil { - // key database is not working - panic(err) - } - - resGob := bytes.NewBuffer(resBytes) - resDec := gob.NewDecoder(resGob) res := &certificate.Resource{} - err = resDec.Decode(res) - if err != nil { - panic(err) - } + if !PogrebGet(keyDatabase, MainDomainSuffix, res) { + log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted") + } else { + tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) - tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) - - // renew main certificate 30 days before it expires - if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) { - go (func() { - _, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])}, res, "") - if err != nil { - log.Printf("Couldn't renew certificate for *%s: %s", MainDomainSuffix, err) - } - })() + // renew main certificate 30 days before it expires + if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) { + go (func() { + _, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])}, res, "") + if err != nil { + log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err) + } + })() + } } time.Sleep(12 * time.Hour) diff --git a/helpers.go b/helpers.go index 506573b..46a1492 100644 --- a/helpers.go +++ b/helpers.go @@ -1,6 +1,10 @@ package main -import "bytes" +import ( + "bytes" + "encoding/gob" + "github.com/akrylysov/pogreb" +) // GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty // string for custom domains. @@ -19,3 +23,34 @@ func TrimHostPort(host []byte) []byte { } return host } + +func PogrebPut(db *pogreb.DB, name []byte, obj interface{}) { + var resGob bytes.Buffer + resEnc := gob.NewEncoder(&resGob) + err := resEnc.Encode(obj) + if err != nil { + panic(err) + } + err = db.Put(name, resGob.Bytes()) + if err != nil { + panic(err) + } +} + +func PogrebGet(db *pogreb.DB, name []byte, obj interface{}) bool { + resBytes, err := db.Get(name) + if err != nil { + panic(err) + } + if resBytes == nil { + return false + } + + resGob := bytes.NewBuffer(resBytes) + resDec := gob.NewDecoder(resGob) + err = resDec.Decode(obj) + if err != nil { + panic(err) + } + return true +}