diff --git a/certificates.go b/certificates.go index 80edc34..d066ed4 100644 --- a/certificates.go +++ b/certificates.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "crypto/tls" "crypto/x509" + "encoding/json" "errors" "github.com/OrlovEvgeny/go-mcache" "github.com/akrylysov/pogreb/fs" @@ -16,9 +17,11 @@ import ( "github.com/go-acme/lego/v4/challenge/resolver" "github.com/go-acme/lego/v4/challenge/tlsalpn01" "github.com/go-acme/lego/v4/providers/dns" + "io/ioutil" "log" "os" "strings" + "sync" "time" "github.com/akrylysov/pogreb" @@ -83,38 +86,30 @@ var tlsConfig = &tls.Config{ } var tlsCertificate tls.Certificate - if ok, err := keyDatabase.Has(sniBytes); err != nil { - // key database is not working - panic(err) - } else if ok { - // parse certificate from database - certPem, err := keyDatabase.Get(sniBytes) - if err != nil { - // key database is not working - panic(err) - } - keyPem, err := keyDatabase.Get(append(sniBytes, '/', 'k', 'e', 'y')) - if err != nil { - // key database is not working or key doesn't exist - panic(err) - } - - tlsCertificate, err = tls.X509KeyPair(certPem, keyPem) - if err != nil { - panic(err) - } + var err error + var ok bool + if tlsCertificate, ok = retrieveCertFromDB(sniBytes); ok { tlsCertificate.Leaf, err = x509.ParseCertificate(tlsCertificate.Certificate[0]) if err != nil { panic(err) } + + if !bytes.Equal(sniBytes, MainDomainSuffix) && !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-7 * 24 * time.Hour)) { + go (func() { + tlsCertificate, err = obtainCert(acmeClient, []string{sni}) + if err != nil { + log.Printf("Couldn't renew certificate.") + } + })() + } } - if tlsCertificate.Certificate == nil || !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-24 * time.Hour)) { + if tlsCertificate.Certificate == nil || !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-5 * time.Minute)) { // request a new certificate if bytes.Equal(sniBytes, MainDomainSuffix) { return nil, errors.New("won't request certificate for main domain, something really bad has happened") } - err := CheckUserLimit(targetOwner) + err = CheckUserLimit(targetOwner) if err != nil { return nil, err } @@ -125,7 +120,7 @@ var tlsConfig = &tls.Config{ } } - err := keyCache.Set(sni, &tlsCertificate, 15 * time.Minute) + err = keyCache.Set(sni, &tlsCertificate, 15 * time.Minute) if err != nil { panic(err) } @@ -133,6 +128,7 @@ var tlsConfig = &tls.Config{ }, PreferServerCipherSuites: true, NextProtos: []string{ + "http/1.1", tlsalpn01.ACMETLS1Protocol, }, @@ -166,11 +162,14 @@ func CheckUserLimit(user string) (error) { return nil } +var myAcmeAccount AcmeAccount +var myAcmeConfig *lego.Config + type AcmeAccount struct { Email string Registration *registration.Resource - key crypto.PrivateKey - limit equalizer.Limiter + Key crypto.PrivateKey `json:"-"` + KeyPEM string `json:"Key"` } func (u *AcmeAccount) GetEmail() string { return u.Email @@ -179,22 +178,11 @@ func (u AcmeAccount) GetRegistration() *registration.Resource { return u.Registration } func (u *AcmeAccount) GetPrivateKey() crypto.PrivateKey { - return u.key + return u.Key } func newAcmeClient(configureChallenge func(*resolver.SolverManager) error) *lego.Client { - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - panic(err) - } - myUser := AcmeAccount{ - Email: envOr("ACME_EMAIL", "noreply@example.email"), - key: privateKey, - } - config := lego.NewConfig(&myUser) - config.CADirURL = envOr("ACME_API", "https://acme.zerossl.com/v2/DV90") - config.Certificate.KeyType = certcrypto.RSA2048 - acmeClient, err := lego.NewClient(config) + acmeClient, err := lego.NewClient(myAcmeConfig) if err != nil { panic(err) } @@ -202,46 +190,12 @@ func newAcmeClient(configureChallenge func(*resolver.SolverManager) error) *lego if err != nil { panic(err) } - - // accept terms - if os.Getenv("ACME_EAB_KID") == "" || os.Getenv("ACME_EAB_HMAC") == "" { - reg, err := acmeClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true"}) - if err != nil { - panic(err) - } - myUser.Registration = reg - } else { - reg, err := acmeClient.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{ - TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true", - Kid: os.Getenv("ACME_EAB_KID"), - HmacEncoded: os.Getenv("ACME_EAB_HMAC"), - }) - if err != nil { - panic(err) - } - myUser.Registration = reg - } - return acmeClient } -var acmeClient = newAcmeClient(func(challenge *resolver.SolverManager) error { - return challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{}) -}) +var acmeClient, mainDomainAcmeClient *lego.Client var acmeClientCertificateLimitPerUser = map[string]*equalizer.TokenBucket{} -var mainDomainAcmeClient = newAcmeClient(func(challenge *resolver.SolverManager) error { - if os.Getenv("DNS_PROVIDER") == "" { - // using mock server, don't use wildcard certs - return challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{}) - } - provider, err := dns.NewDNSChallengeProviderByName(os.Getenv("DNS_PROVIDER")) - if err != nil { - return err - } - return challenge.SetDNS01Provider(provider) -}) - type AcmeTLSChallengeProvider struct{} var _ challenge.Provider = AcmeTLSChallengeProvider{} func (a AcmeTLSChallengeProvider) Present(domain, _, keyAuth string) error { @@ -252,13 +206,52 @@ func (a AcmeTLSChallengeProvider) CleanUp(domain, _, _ string) error { return nil } +func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) { + // parse certificate from database + certPem, err := keyDatabase.Get(sni) + if err != nil { + // key database is not working + panic(err) + } + if certPem == nil { + return tls.Certificate{}, false + } + keyPem, err := keyDatabase.Get(append(sni, '/', 'k', 'e', 'y')) + if err != nil { + // key database is not working or key doesn't exist + panic(err) + } + + tlsCertificate, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + panic(err) + } + return tlsCertificate, true +} + +var obtainLocks = sync.Map{} func obtainCert(acmeClient *lego.Client, domains []string) (tls.Certificate, error) { - name := domains[0] + name := strings.TrimPrefix(domains[0], "*") if os.Getenv("DNS_PROVIDER") == "" && len(domains[0]) > 0 && domains[0][0] == '*' { domains = domains[1:] } - log.Printf("Requesting new certificate for %v", domains) + // lock to avoid simultaneous requests + _, working := obtainLocks.LoadOrStore(name, struct{}{}) + if working { + for working { + time.Sleep(100 * time.Millisecond) + _, working = obtainLocks.Load(name) + } + cert, ok := retrieveCertFromDB([]byte(name)) + if !ok { + return tls.Certificate{}, errors.New("certificate failed in synchronous request") + } + return cert, nil + } + defer obtainLocks.Delete(name) + + log.Printf("Requesting new certificate for %v", domains) res, err := acmeClient.Certificate.Obtain(certificate.ObtainRequest{ Domains: domains, Bundle: true, @@ -272,22 +265,24 @@ func obtainCert(acmeClient *lego.Client, domains []string) (tls.Certificate, err err = keyDatabase.Put([]byte(name + "/key"), res.PrivateKey) if err != nil { + obtainLocks.Delete(name) panic(err) } err = keyDatabase.Put([]byte(name), res.Certificate) if err != nil { _ = keyDatabase.Delete([]byte(name + "/key")) + obtainLocks.Delete(name) panic(err) } tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) if err != nil { - panic(err) + return tls.Certificate{}, err } return tlsCertificate, nil } -func init() { +func setupCertificates() { var err error keyDatabase, err = pogreb.Open("key-database.pogreb", &pogreb.Options{ BackgroundSyncInterval: 30 * time.Second, @@ -302,6 +297,80 @@ func init() { panic(errors.New("you must set ACME_ACCEPT_TERMS and DNS_PROVIDER, unless ACME_API is set to https://acme.mock.directory")) } + if account, err := ioutil.ReadFile("acme-account.json"); err == nil { + err = json.Unmarshal(account, &myAcmeAccount) + if err != nil { + panic(err) + } + myAcmeAccount.Key, err = certcrypto.ParsePEMPrivateKey([]byte(myAcmeAccount.KeyPEM)) + if err != nil { + panic(err) + } + myAcmeConfig = lego.NewConfig(&myAcmeAccount) + myAcmeConfig.CADirURL = envOr("ACME_API", "https://acme.zerossl.com/v2/DV90") + myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048 + newAcmeClient(func(manager *resolver.SolverManager) error { return nil }) + } else if os.IsNotExist(err) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + myAcmeAccount = AcmeAccount{ + Email: envOr("ACME_EMAIL", "noreply@example.email"), + Key: privateKey, + KeyPEM: string(certcrypto.PEMEncode(privateKey)), + } + myAcmeConfig = lego.NewConfig(&myAcmeAccount) + myAcmeConfig.CADirURL = envOr("ACME_API", "https://acme.zerossl.com/v2/DV90") + myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048 + tempClient := newAcmeClient(func(manager *resolver.SolverManager) error { return nil }) + + // accept terms & log in to EAB + if os.Getenv("ACME_EAB_KID") == "" || os.Getenv("ACME_EAB_HMAC") == "" { + reg, err := tempClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true"}) + if err != nil { + panic(err) + } + myAcmeAccount.Registration = reg + } else { + reg, err := tempClient.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{ + TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true", + Kid: os.Getenv("ACME_EAB_KID"), + HmacEncoded: os.Getenv("ACME_EAB_HMAC"), + }) + if err != nil { + panic(err) + } + myAcmeAccount.Registration = reg + } + + acmeAccountJson, err := json.Marshal(myAcmeAccount) + if err != nil { + panic(err) + } + err = ioutil.WriteFile("acme-account.json", acmeAccountJson, 0600) + if err != nil { + panic(err) + } + } else { + panic(err) + } + + acmeClient = newAcmeClient(func(challenge *resolver.SolverManager) error { + return challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{}) + }) + mainDomainAcmeClient = newAcmeClient(func(challenge *resolver.SolverManager) error { + if os.Getenv("DNS_PROVIDER") == "" { + // using mock server, don't use wildcard certs + return challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{}) + } + provider, err := dns.NewDNSChallengeProviderByName(os.Getenv("DNS_PROVIDER")) + if err != nil { + return err + } + return challenge.SetDNS01Provider(provider) + }) + go (func() { for { err := keyDatabase.Sync() diff --git a/domains.go b/domains.go index 0731672..0bf2605 100644 --- a/domains.go +++ b/domains.go @@ -46,10 +46,10 @@ func getTargetFromDNS(domain string) (targetOwner, targetRepo, targetBranch stri cnameParts := strings.Split(strings.TrimSuffix(cname, string(MainDomainSuffix)), ".") targetOwner = cnameParts[len(cnameParts)-1] if len(cnameParts) > 1 { - targetRepo = cnameParts[len(cnameParts)-1] + targetRepo = cnameParts[len(cnameParts)-2] } if len(cnameParts) > 2 { - targetBranch = cnameParts[len(cnameParts)-2] + targetBranch = cnameParts[len(cnameParts)-3] } if targetRepo == "" { targetRepo = "pages" diff --git a/main.go b/main.go index 61ff168..c795f1f 100644 --- a/main.go +++ b/main.go @@ -102,6 +102,8 @@ func main() { } listener = tls.NewListener(listener, tlsConfig) + setupCertificates() + // Start the web server err = server.Serve(listener) if err != nil {