package main

import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/docopt/docopt-go"
	"github.com/go-ini/ini"
	"github.com/tidwall/gjson"
)

// Package-level state shared by main and the helper functions.
//
//   - appVersion, buildTime: printed by main for --version/--build. They are
//     never assigned in this file (presumably injected at build time, e.g.
//     with -ldflags — TODO confirm).
//   - CertBase, KeyBase, GroupName: default install directories and owning
//     group, chosen in main from the ID read out of /etc/os-release.
//   - GroupID: numeric id of GroupName, resolved in main (0 on Windows).
//   - RedisBaseURL, VaultBaseURL: service endpoints, assigned in main.
//   - certificateDestination, fullchainDestination, keyDestination,
//     caDestination: final install paths, resolved in main from the command
//     line or from the defaults above.
//   - Type: value of the --type option.
//   - tmpCertificateDestination, tmpFullchainDestination, tmpCaDestination,
//     tmpKeyDestination: staging paths the downloads are written to before
//     they are validated and moved into place.
//   - certTmpDir: staging directory removed by appExit.
var (
	appVersion                string
	buildTime                 string
	CertBase                  string
	KeyBase                   string
	GroupName                 string
	GroupID                   int
	RedisBaseURL              string
	VaultBaseURL              string
	certificateDestination    string
	fullchainDestination      string
	keyDestination            string
	caDestination             string
	Type                      string
	tmpCertificateDestination string
	tmpFullchainDestination   string
	tmpCaDestination          string
	tmpKeyDestination         string
	certTmpDir                string
)

// appExit removes the temporary staging directory and then terminates the
// process with the given status code. It never returns.
//
// The directory is /tmp/acme-downloader (C:\tmp\acme-downloader\ on Windows);
// the chosen path is stored in the package-level certTmpDir. A failed cleanup
// is reported as a warning but does not alter the exit status.
func appExit(status int) {
	if runtime.GOOS == "windows" {
		certTmpDir = "C:\\tmp\\acme-downloader\\"
	} else {
		certTmpDir = "/tmp/acme-downloader"
	}
	// The staging directory may contain a private key, so a cleanup failure
	// must be visible to the operator instead of being silently dropped.
	if err := os.RemoveAll(certTmpDir); err != nil {
		fmt.Printf("[WARN] Fail to remove %v: %v\n", certTmpDir, err)
	}
	os.Exit(status)
}

// checkCerificates reports whether the certificate and the full chain stored
// on disk are valid for dnsname and will still be valid `days` days from now.
//
// certificate, fullchain and ca are paths to PEM files. Every certificate
// found in ca is added to the trusted root pool; only the first PEM block of
// certificate and of fullchain is parsed and verified. key is accepted for
// symmetry with the callers but is not inspected.
//
// When fail is false, any problem simply yields false. When fail is true, the
// problem is printed and the process is terminated through appExit(255), so
// in that mode the function only ever returns true.
func checkCerificates(dnsname string, certificate string, fullchain string, ca string, key string, days int, fail bool) bool {
	// reject is the single failure path: quiet false, or report and exit.
	reject := func(format string, args ...interface{}) bool {
		if fail {
			fmt.Printf(format, args...)
			appExit(255)
		}
		return false
	}

	// Verifying against a moment `days` days ahead makes certificates that
	// are about to expire count as invalid already.
	deadline := time.Now().Local().Add(time.Duration(days) * 24 * time.Hour)

	certPEM, err := ioutil.ReadFile(certificate)
	if err != nil {
		return reject("[ERR] %v\n", err)
	}

	certFullchainPEM, err := ioutil.ReadFile(fullchain)
	if err != nil {
		return reject("[ERR] %v\n", err)
	}

	rootPEM, err := ioutil.ReadFile(ca)
	if err != nil {
		return reject("[ERR] %v\n", err)
	}

	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(rootPEM) {
		return reject("[ERR] failed to parse root certificate\n")
	}

	block, _ := pem.Decode(certPEM)
	if block == nil {
		return reject("[ERR] failed to parse certificate PEM\n")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return reject("[ERR] failed to parse certificate %v\n", err)
	}

	fullchainBlock, _ := pem.Decode(certFullchainPEM)
	if fullchainBlock == nil {
		return reject("[ERR] failed to parse certificate PEM\n")
	}
	fullchainCert, err := x509.ParseCertificate(fullchainBlock.Bytes)
	if err != nil {
		return reject("[ERR] failed to parse certificate %v\n", err)
	}

	opts := x509.VerifyOptions{
		Roots:         roots,
		DNSName:       dnsname,
		CurrentTime:   deadline,
		Intermediates: x509.NewCertPool(),
	}

	if _, err := cert.Verify(opts); err != nil {
		return reject("[ERR] failed to verify certificate %v\n", err)
	}
	// Report the error returned by this very call; the previous code printed
	// an unrelated nil error here (a panic) and did not stop on failure.
	if _, err := fullchainCert.Verify(opts); err != nil {
		return reject("[ERR] failed to verify full chain certificate %v\n", err)
	}
	return true
}

// GetRedisKey fetches redisurl over HTTP using basic authentication (user
// "redis", password redistoken) and returns the response body verbatim.
//
// Any failure — building the request, performing it, a status outside
// 200-299, or reading the body — is printed and terminates the process
// through appExit(255).
func GetRedisKey(redisurl string, redistoken string) string {
	// Bound the whole exchange so an unresponsive server cannot hang the run.
	client := &http.Client{Timeout: 30 * time.Second}
	req, err := http.NewRequest("GET", redisurl, nil)
	if err != nil {
		fmt.Printf("[ERR] Fail to read %v: %v\n", redisurl, err)
		appExit(255)
	}
	req.SetBasicAuth("redis", redistoken)

	resp, err := client.Do(req)
	if err != nil {
		// resp is nil on a transport error; it must not be touched.
		fmt.Printf("[ERR] Fail to fetch %v: %v\n", redisurl, err)
		appExit(255)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		fmt.Printf("[ERR] Fail to fetch %v\n", redisurl)
		appExit(255)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		fmt.Printf("[ERR] Fail to read %v: %v\n", redisurl, err)
		appExit(255)
	}
	// Plain conversion: running the payload through a format function would
	// corrupt any '%' it contains.
	return string(body)
}

// GetVaultKey fetches vaulturl over HTTP, authenticating with the
// X-vault-token header, and returns the string found at the JSON path
// "data.value" of the response.
//
// Any failure — building the request, performing it, a status outside
// 200-299, reading the body, or an absent/empty "data.value" — is printed and
// terminates the process through appExit(255).
func GetVaultKey(vaulturl string, vaulttoken string) string {
	// Bound the whole exchange so an unresponsive server cannot hang the run.
	vaultClient := &http.Client{Timeout: 30 * time.Second}
	req, err := http.NewRequest("GET", vaulturl, nil)
	if err != nil {
		fmt.Printf("[ERR] Fail to read %v: %v\n", vaulturl, err)
		appExit(255)
	}
	req.Header.Add("X-vault-token", vaulttoken)

	resp, err := vaultClient.Do(req)
	if err != nil {
		// resp is nil on a transport error; it must not be touched.
		fmt.Printf("[ERR] Fail to fetch %v: %v\n", vaulturl, err)
		appExit(255)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		fmt.Printf("[ERR] Fail to fetch %v\n", vaulturl)
		appExit(255)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		fmt.Printf("[ERR] Fail to read %v: %v\n", vaulturl, err)
		appExit(255)
	}

	// An empty result means the field is missing, null or blank; returning it
	// would let the caller install an empty private key.
	secret := gjson.Get(string(body), "data.value").String()
	if secret == "" {
		fmt.Printf("[ERR] Fail to read %v: no data.value in response\n", vaulturl)
		appExit(255)
	}
	return secret
}

// WriteToFile writes content followed by a newline to destination, creating
// the parent directories (mode 0755) when needed.
//
// filemode is applied only when the file is created; an existing file keeps
// its mode but has its previous content discarded. Any failure is printed and
// terminates the process through appExit(255).
func WriteToFile(content string, destination string, filemode os.FileMode) {
	baseDir := filepath.Dir(destination)
	// MkdirAll is a no-op for an existing directory, so no prior Stat is needed.
	if err := os.MkdirAll(baseDir, 0755); err != nil {
		fmt.Printf("[ERR] %v cannot be created: %v\n", baseDir, err)
		appExit(255)
	}

	// O_TRUNC: without it a longer file left over from an interrupted run
	// would keep its tail after the new, shorter content.
	file, err := os.OpenFile(destination, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, filemode)
	if err != nil {
		fmt.Printf("[ERR] %v cannot be created\n", destination)
		appExit(255)
	}

	_, writeErr := fmt.Fprintf(file, "%v\n", content)
	// Always close; a close error matters only if the write itself succeeded.
	if closeErr := file.Close(); writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		fmt.Printf("[ERR] %v cannot be written: %v\n", destination, writeErr)
		appExit(255)
	}
}

// moveFile installs the staged file source at destination.
//
// The parent directory of destination is created with dirmode when missing.
// The file is moved with os.Rename; if that fails (for instance because the
// staging area and the destination are on different filesystems) it is copied
// instead. On non-Windows systems the installed file is then given owner uid
// 0, group groupid and permissions filemode. Any failure is printed and
// terminates the process through appExit(255).
func moveFile(source string, destination string, groupid int, filemode os.FileMode, dirmode os.FileMode) {
	baseDir := filepath.Dir(destination)
	// MkdirAll is a no-op (and leaves the mode alone) for an existing directory.
	if err := os.MkdirAll(baseDir, dirmode); err != nil {
		fmt.Printf("[ERR] Fail to install %v: %v\n", destination, err)
		appExit(255)
	}

	if err := os.Rename(source, destination); err != nil {
		if copyErr := copyIntoPlace(source, destination, filemode); copyErr != nil {
			fmt.Printf("[ERR] Fail to install %v: %v (copy fallback: %v)\n", destination, err, copyErr)
			appExit(255)
		}
	}

	if runtime.GOOS != "windows" {
		if err := os.Chown(destination, 0, groupid); err != nil {
			fmt.Printf("[ERR] Changing file owner to %v: %v\n", groupid, err)
			appExit(255)
		}
		// Enforce the requested mode: the staged file's mode was subject to
		// the process umask.
		if err := os.Chmod(destination, filemode); err != nil {
			fmt.Printf("[ERR] Changing file mode of %v: %v\n", destination, err)
			appExit(255)
		}
	}
	fmt.Printf("[INFO] installed: %v\n", destination)
}

// copyIntoPlace copies source to destination through a temporary file created
// next to destination, so that the final step is a same-directory rename and
// readers never observe a partially written file. source is removed on success.
func copyIntoPlace(source string, destination string, filemode os.FileMode) error {
	data, err := ioutil.ReadFile(source)
	if err != nil {
		return err
	}
	staged, err := ioutil.TempFile(filepath.Dir(destination), ".acme-downloader-")
	if err != nil {
		return err
	}
	stagedName := staged.Name()

	_, err = staged.Write(data)
	if closeErr := staged.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		err = os.Chmod(stagedName, filemode)
	}
	if err == nil {
		err = os.Rename(stagedName, destination)
	}
	if err != nil {
		// Best effort: do not leave the half-installed temporary file behind.
		_ = os.Remove(stagedName)
		return err
	}
	// Best effort: the staging directory is removed at exit anyway.
	_ = os.Remove(source)
	return nil
}

// ReadOSRelease loads configfile as an INI document and returns a map with a
// single entry, "ID", holding the value of the ID key from the unnamed
// section. When the file cannot be loaded, "ID" is set to "unknown".
func ReadOSRelease(configfile string) map[string]string {
	id := "unknown"
	if cfg, err := ini.Load(configfile); err == nil {
		id = cfg.Section("").Key("ID").String()
	}
	return map[string]string{"ID": id}
}

// main downloads a certificate, its full chain, the CA bundle and the private
// key, validates them in a staging directory and installs them.
//
// Flow:
//  1. Pick default install directories and the owning group from the ID in
//     /etc/os-release (placeholder paths for anything not recognised).
//  2. Parse the command line.
//  3. If the installed certificates are still valid for --days more days,
//     exit with status 0.
//  4. Otherwise fetch the PEM files from Redis and the key from Vault, write
//     them to the staging directory, verify them, and move them into place.
//
// Exit status: 0 when nothing had to be done (or for --build), 1 when new
// files were installed (the consuming application needs a reload), 255 on
// any error.
func main() {
	// Choose install locations and the owning group from the distribution ID.
	switch ReadOSRelease("/etc/os-release")["ID"] {
	case "centos", "rhel":
		CertBase = "/etc/pki/tls/certs"
		KeyBase = "/etc/pki/tls/private"
		GroupName = "root"
	case "ubuntu", "debian":
		CertBase = "/etc/ssl/certs"
		KeyBase = "/etc/ssl/private"
		GroupName = "ssl-cert"
	case "arch":
		CertBase = "/etc/ssl/certs"
		KeyBase = "/etc/ssl/private"
		GroupName = "root"
	default:
		// Covers "unknown" (os-release not readable) as well as any ID not
		// listed above, which would otherwise leave the bases and the group
		// name empty.
		if runtime.GOOS == "windows" {
			CertBase = "DRIVE:\\PATH\\TO\\CERTIFICATE"
			KeyBase = "DRIVE:\\PATH\\TO\\KEY"
			GroupName = "root"
		} else {
			CertBase = "/PATH/TO/CERTIFICATE"
			KeyBase = "/PATH/TO/PRIV/KEY"
			GroupName = "root"
		}
	}

	// Placeholder defaults shown in the usage text. They are built once with
	// filepath.Join so the text docopt hands back is exactly what the default
	// detection below compares against, whatever the path separator is.
	defaultCert := filepath.Join(CertBase, "<cert-name>.crt")
	defaultFullchain := filepath.Join(CertBase, "<cert-name>_fullchain.crt")
	defaultKey := filepath.Join(KeyBase, "<cert-name>.key")
	defaultCA := filepath.Join(CertBase, "COMODO_<type>.crt")

	usage := fmt.Sprintf(`ACME Downloader:
  - fetches and stores a given Certificate, Full Chain, CA and Private Key

Usage:
  acme-downloader --redis-token=REDISTOKEN --vault-token=VAULTTOKEN --cert-name=CERTNAME --team-name=TEAMNAME [--days=DAYS] [--type=TYPE] [--cert-destination=CERTDESTINATION] [--fullchain-destination=FULLCHAINDESTINATION] [--key-destination=KEYDESTINATION] [--ca-destination=CADESTINATION]
  acme-downloader -v | --version
  acme-downloader -b | --build
  acme-downloader -h | --help

Options:
  -h --help                                     Show this screen
  -v --version                                  Print version exit
  -b --build                                    Print version and build information and exit
  --redis-token=REDISTOKEN                      Redis access token
  --vault-token=VAULTTOKEN                      Vault access token
  --cert-name=CERTNAME                          Certificate name
  --team-name=TEAMNAME                          Team name: swd, it, ne, ti...
  --days=DAYS                                   Days before expiration [default: 30]
  --type=TYPE                                   Type, EV or OV [default: EV]
  --cert-destination=CERTDESTINATION            Cert Destination [default: %v]
  --fullchain-destination=FULLCHAINDESTINATION  Full Chain Destination[default: %v]
  --key-destination=KEYDESTINATION              Key Destination [default: %v]
  --ca-destination=CADESTINATION                CA Destination [default: %v]
`, defaultCert, defaultFullchain, defaultKey, defaultCA)

	// docopt exits by itself on bad user input, --help and --version; an
	// error returned here means the usage text itself could not be parsed.
	arguments, parseErr := docopt.Parse(usage, nil, true, appVersion, false)
	if parseErr != nil {
		fmt.Printf("[ERR] Fail to parse arguments: %v\n", parseErr)
		appExit(255)
	}

	if arguments["--build"] == true {
		fmt.Printf("acme-downloader version: %v, built on: %v\n", appVersion, buildTime)
		appExit(0)
	}

	if runtime.GOOS == "windows" {
		tmpCertificateDestination = "C:\\tmp\\acme-downloader\\cert\\amce_cert.pem"
		tmpFullchainDestination = "C:\\tmp\\acme-downloader\\cert\\amce_fullchain.pem"
		tmpCaDestination = "C:\\tmp\\acme-downloader\\cert\\amce_ca.pem"
		tmpKeyDestination = "C:\\tmp\\acme-downloader\\key\\amce_key.pem"
		GroupID = 0 // just a fake one
	} else {
		tmpCertificateDestination = "/tmp/acme-downloader/cert/amce_cert.pem"
		tmpFullchainDestination = "/tmp/acme-downloader/cert/amce_fullchain.pem"
		tmpCaDestination = "/tmp/acme-downloader/cert/amce_ca.pem"
		tmpKeyDestination = "/tmp/acme-downloader/key/amce_key.pem"
		group, groupErr := user.LookupGroup(GroupName)
		if groupErr != nil {
			fmt.Printf("[ERR] Fail looking up group %v: %v\n", GroupName, groupErr)
			appExit(255)
		}
		gid, gidErr := strconv.Atoi(group.Gid)
		if gidErr != nil {
			fmt.Printf("[ERR] Group %v has a non-numeric id %v\n", GroupName, group.Gid)
			appExit(255)
		}
		GroupID = gid
	}

	vaultToken := arguments["--vault-token"].(string)
	certName := arguments["--cert-name"].(string)
	certNameUnderscored := strings.Replace(certName, ".", "_", -1)
	teamName := arguments["--team-name"].(string)
	redisToken := arguments["--redis-token"].(string)
	Type = arguments["--type"].(string)
	days, daysErr := strconv.Atoi(arguments["--days"].(string))
	if daysErr != nil {
		fmt.Printf("[ERR] Days mut be an integer\n")
		appExit(255)
	}

	RedisBaseURL = "https://redis.geant.org/GET"
	VaultBaseURL = "https://vault.geant.org/v1"
	vaultURL := fmt.Sprintf("%v/%v/%v/vault_%v_key", VaultBaseURL, teamName, certName, certNameUnderscored)
	redisCertURL := fmt.Sprintf("%v/%v:%v:redis_%v_pem.txt", RedisBaseURL, teamName, certName, certNameUnderscored)
	redisCAURL := fmt.Sprintf("%v/%v:%v:redis_%v_chain_pem.txt", RedisBaseURL, teamName, certName, certNameUnderscored)
	redisFullChainURL := fmt.Sprintf("%v/%v:%v:redis_%v_fullchain_pem.txt", RedisBaseURL, teamName, certName, certNameUnderscored)

	// resolve returns the path given on the command line, or resolved when the
	// option still carries its placeholder default.
	resolve := func(option string, placeholder string, resolved string) string {
		if value, _ := arguments[option].(string); value != placeholder {
			return value
		}
		return resolved
	}
	certificateDestination = resolve("--cert-destination", defaultCert,
		filepath.Join(CertBase, fmt.Sprintf("%v.crt", certName)))
	fullchainDestination = resolve("--fullchain-destination", defaultFullchain,
		filepath.Join(CertBase, fmt.Sprintf("%v_fullchain.crt", certName)))
	caDestination = resolve("--ca-destination", defaultCA,
		filepath.Join(CertBase, fmt.Sprintf("COMODO_%v.crt", Type)))
	keyDestination = resolve("--key-destination", defaultKey,
		filepath.Join(KeyBase, fmt.Sprintf("%v.key", certName)))

	// check if there is a certificate installed and it is valid
	if checkCerificates(certName, certificateDestination, fullchainDestination, caDestination, keyDestination, days, false) {
		fmt.Printf("[INFO] the certificates are still valid\n")
		appExit(0)
	}

	certificate := GetRedisKey(redisCertURL, redisToken)
	ca := GetRedisKey(redisCAURL, redisToken)
	fullChain := GetRedisKey(redisFullChainURL, redisToken)
	privKey := GetVaultKey(vaultURL, vaultToken)

	// download and test certificates on a temporary location
	WriteToFile(certificate, tmpCertificateDestination, 0644)
	WriteToFile(fullChain, tmpFullchainDestination, 0644)
	WriteToFile(ca, tmpCaDestination, 0644)
	WriteToFile(privKey, tmpKeyDestination, 0640)

	// Never install files that did not pass validation.
	if !checkCerificates(certName, tmpCertificateDestination, tmpFullchainDestination, tmpCaDestination, tmpKeyDestination, days, true) {
		appExit(255)
	}

	// move certificates in place
	moveFile(tmpCertificateDestination, certificateDestination, GroupID, 0644, 0755)
	moveFile(tmpFullchainDestination, fullchainDestination, GroupID, 0644, 0755)
	moveFile(tmpCaDestination, caDestination, GroupID, 0644, 0755)
	moveFile(tmpKeyDestination, keyDestination, GroupID, 0640, 0750)

	// Exit 1 means application needs to be reloaded
	appExit(1)
}