// 校验授权文件

package licence

import (
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"

	"git.apinb.com/bsm-sdk/core/crypto/aes"
	"git.apinb.com/bsm-sdk/core/utils"
	"github.com/shirou/gopsutil/cpu"
)

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// CpuInfo describes one CPU entry as collected by getCpuInfo from gopsutil.
//
// Values of this type are JSON-encoded (inside MachineInfo) and hashed by
// GetMachineCode, so the field order and json tags are load-bearing: changing
// either changes the machine code of every host and invalidates issued licences.
type CpuInfo struct {
	Cpu        int32  `json:"cpu"`
	Cores      int32  `json:"cores"`
	ModelName  string `json:"modelName"`
	VendorId   string `json:"vendorId"`
	PhysicalId string `json:"physicalId"`
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// MachineInfo is the hardware fingerprint that GetMachineCode JSON-encodes and
// hashes: the MAC addresses from getMacAddrs and the CPU list from getCpuInfo.
// Its JSON encoding (field order and tags included) determines the machine code,
// so do not reorder or rename these fields.
type MachineInfo struct {
	MacAddrs []string   `json:"macAddrs"`
	Cpus     []*CpuInfo `json:"cpus"`
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// Licence is the decrypted, JSON-decoded content of a licence file.
// BuildLicence produces it, DecodeLicence parses it, and VerifyLicence checks it.
type Licence struct {
	CompanyName  string   `json:"licence_to"`    // Licensed company name
	CreateDate   int      `json:"create"`        // Start date as a yyyymmdd integer (inclusive)
	ExpireDate   int      `json:"expire"`        // Expiry date as a yyyymmdd integer (inclusive)
	MachineCodes []string `json:"machine_codes"` // Machine codes (see GetMachineCode) allowed to run
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// des_key and des_iv hold the base64-encoded key and IV handed to the aes
// package by EncryptStr/DecryptStr; init derives both from signKey.
var des_key, des_iv string

// Check_Licence_File says whether the deployment licence file should be checked.
// It defaults to true. It is not read in this file; presumably callers consult
// it before starting the licence check (confirm at call sites).
var Check_Licence_File = true

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// signKey is the shared secret of the licensing scheme. init derives the
// encryption key and IV from it, and GetLicenceCode mixes it into every
// activation code. init slices signKey[:16], so it must be at least 16 bytes.
//
// NOTE(review): a secret compiled into the binary can be extracted by anyone
// who has the binary, so licence files are tamper-resistant, not tamper-proof.
const signKey = "8E853B589944FF7A56BEF02AAA51D6F4"

// LICENCE_KEY is exported for callers and is not referenced in this file.
// It is presumably the name of a config/environment key; confirm at call sites.
const LICENCE_KEY = "TRAIN_LICENCE_KEY"

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// init derives the base64-encoded key and IV used by EncryptStr/DecryptStr
// from signKey.
func init() {
	raw := []byte(signKey)
	des_key = base64.StdEncoding.EncodeToString(raw)
	// The IV is the first 16 bytes of the key itself.
	// NOTE(review): a constant IV derived from the key makes encryption
	// deterministic (identical plaintexts give identical ciphertexts, depending on
	// the aes package's mode, which is not visible here). Changing it would
	// invalidate every issued licence.
	des_iv = base64.StdEncoding.EncodeToString(raw[:16])
}

// WatchCheckLicence re-validates the licence immediately and then once per hour,
// forever. If a check fails, it logs the licence path and terminates the whole
// process with exit code 99. It never returns, so run it in its own goroutine
// if the caller has other work to do.
func WatchCheckLicence(licPath, licName string) {
	const recheckEvery = time.Hour
	for {
		if valid := CheckLicence(licPath, licName); !valid {
			log.Println("授权文件失效,请重新部署授权文件:", licPath)
			os.Exit(99)
		}
		time.Sleep(recheckEvery)
	}
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// CheckLicence loads the licence file from licPath, decrypts and decodes it, and
// verifies it with VerifyLicence(licName).
//
// It returns false when the file cannot be loaded, cannot be decoded, or fails
// verification. The underlying error is not reported.
func CheckLicence(licPath, licName string) bool {
	raw, err := LoadLicenceFromFile(licPath)
	if err != nil {
		return false
	}

	lic := new(Licence)
	if DecodeLicence(raw, lic) != nil {
		return false
	}
	return lic.VerifyLicence(licName)
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// VerifyLicence reports whether the licence is valid on this machine today.
//
// If the licence's CompanyName equals licName, every other check is skipped.
// The original comment describes this as a development-environment bypass for
// the licensed company. An empty licName never triggers the bypass, so a
// zero-value Licence (e.g. decoded from a blank or "{}" payload) cannot pass
// without the date and machine-code checks.
//
// Otherwise the licence is valid when both of these hold:
//   - today's local date, as a yyyymmdd integer, lies within
//     [CreateDate, ExpireDate] (both ends inclusive);
//   - this machine's code (GetMachineCode) appears in MachineCodes.
func (l *Licence) VerifyLicence(licName string) bool {
	// Development bypass, which requires a non-empty name to match.
	if licName != "" && l.CompanyName == licName {
		return true
	}

	today := StrToInt(time.Now().Format("20060102"))
	// Reject if the machine date is before the start date or after the expiry date.
	if today < l.CreateDate || today > l.ExpireDate {
		return false
	}

	// Reject if this machine's code is not in the licensed list.
	machineCode := GetMachineCode()
	if machineCode == "" || !l.ValidMachineCode(machineCode) {
		return false
	}

	return true
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// ValidMachineCode reports whether code appears (exact, case-sensitive match)
// in the licence's MachineCodes list.
func (l *Licence) ValidMachineCode(code string) bool {
	for i := range l.MachineCodes {
		if l.MachineCodes[i] == code {
			return true
		}
	}
	return false
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// LoadLicenceFromFile reads the file "licence.key" inside the directory licPath
// and returns its contents verbatim (no whitespace trimming).
//
// It returns the error "授权文件不存在" ("licence file does not exist") when
// utils.PathExists reports the path as missing. Open and read errors are returned
// unchanged.
//
// NOTE(review): path.Join always uses '/'; filepath.Join would be the
// OS-native choice for filesystem paths.
func LoadLicenceFromFile(licPath string) (string, error) {
	keyFile := path.Join(licPath, "licence.key")
	if !utils.PathExists(keyFile) {
		return "", errors.New("授权文件不存在")
	}

	f, err := os.Open(keyFile)
	if err != nil {
		return "", err
	}
	defer f.Close()

	data, err := io.ReadAll(f)
	if err != nil {
		return "", err
	}
	return string(data), nil
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// DecodeLicence decrypts the licence file content and unmarshals the resulting
// JSON into licence.
//
// It returns the error "授权文件无效" ("invalid licence file") when content is
// empty or when DecryptStr yields an empty string. An empty result presumably
// means decryption failed; the aes package's failure contract is not visible
// here. It returns the json.Unmarshal error when the decrypted text is not a
// valid Licence document.
func DecodeLicence(content string, licence *Licence) error {
	if len(content) == 0 {
		return errors.New("授权文件无效")
	}

	txt := DecryptStr(content)
	if len(txt) == 0 {
		// Treat undecryptable content as invalid. Reporting success here would leave
		// *licence zero-valued while the caller believes decoding worked.
		return errors.New("授权文件无效")
	}

	return json.Unmarshal([]byte(txt), licence)
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// DecryptStr decrypts txt with the package key and IV via aes.Decrypt.
// Empty input returns "" without calling aes. What aes.Decrypt returns on
// malformed input is not visible here (presumably "").
func DecryptStr(txt string) string {
	if txt == "" {
		return ""
	}
	return aes.Decrypt(des_key, des_iv, txt)
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// EncryptStr encrypts txt with the package key and IV via aes.Encrypt.
// Empty input returns "" without calling aes.
func EncryptStr(txt string) string {
	if txt == "" {
		return ""
	}
	return aes.Encrypt(des_key, des_iv, txt)
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// BuildLicence produces licence file content. It JSON-encodes a Licence built
// from the arguments and encrypts it with EncryptStr.
//
// Create_date and expire_date are yyyymmdd integers (e.g. 20240131), matching
// how VerifyLicence compares them with the current date. An error is returned
// only if JSON encoding fails.
func BuildLicence(company_name string, Create_date int, expire_date int, machine_codes []string) (string, error) {
	payload, err := json.Marshal(Licence{
		CompanyName:  company_name,
		CreateDate:   Create_date,
		ExpireDate:   expire_date,
		MachineCodes: machine_codes,
	})
	if err != nil {
		return "", err
	}
	return EncryptStr(string(payload)), nil
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// 生成激活码
func GetLicenceCode(machinceCode string) string {
	return hash256(machinceCode + "&" + signKey)
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// getCpuInfo collects per-CPU details from gopsutil and returns them in a
// deterministic order: more cores first, then ascending CPU index. The result
// feeds the machine code, so the ordering must stay stable.
func getCpuInfo() []*CpuInfo {
	// Best effort: the cpu.Info error is ignored and whatever slice it returned
	// is used. The result is never nil, so it JSON-encodes as [] rather than null.
	stats, _ := cpu.Info()

	infos := make([]*CpuInfo, 0, len(stats))
	for i := range stats {
		s := &stats[i]
		infos = append(infos, &CpuInfo{
			Cpu:        s.CPU,
			Cores:      s.Cores,
			ModelName:  s.ModelName,
			VendorId:   s.VendorID,
			PhysicalId: s.PhysicalID,
		})
	}

	sort.SliceStable(infos, func(a, b int) bool {
		x, y := infos[a], infos[b]
		if x.Cores == y.Cores {
			return x.Cpu < y.Cpu
		}
		return x.Cores > y.Cores
	})
	return infos
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// getMacAddrs returns the hardware addresses of network interfaces that are up,
// are not loopback, and carry at least one global-unicast IP address. The list
// is sorted in descending string order. It is nil when there are no matches or
// when interfaces cannot be listed.
//
// NOTE(review): one entry is appended per qualifying IP address, not per
// interface. A NIC with several global addresses (e.g. IPv4 plus rotating IPv6
// privacy addresses) therefore appears several times, and the count can change
// at runtime. Interfaces without a hardware address (some tunnels) contribute
// "". Either effect alters the machine code. Deduplicating would be more robust,
// but it would change the codes of already-licensed hosts.
func getMacAddrs() []string {
	var macs []string

	ifaces, err := net.Interfaces()
	if err != nil {
		return macs
	}

	for _, iface := range ifaces {
		isUp := iface.Flags&net.FlagUp != 0
		isLoopback := iface.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			continue
		}

		// Errors from Addrs are ignored.
		addrs, _ := iface.Addrs()
		for _, a := range addrs {
			if ipNet, isIPNet := a.(*net.IPNet); isIPNet && ipNet.IP.IsGlobalUnicast() {
				macs = append(macs, iface.HardwareAddr.String())
			}
		}
	}

	sort.Sort(sort.Reverse(sort.StringSlice(macs)))
	return macs
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// GetMachineCode returns this host's activation code. It SHA-256-hashes the
// JSON encoding of the host fingerprint (MAC addresses and CPU details), then
// passes that hash through GetLicenceCode. The result is upper-case hex.
func GetMachineCode() string {
	fingerprint := MachineInfo{
		MacAddrs: getMacAddrs(),
		Cpus:     getCpuInfo(),
	}
	// MachineInfo holds only strings, ints and slices, so Marshal cannot fail.
	encoded, _ := json.Marshal(fingerprint)
	return GetLicenceCode(hash256(string(encoded)))
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// hash256 returns the SHA-256 digest of d as upper-case hexadecimal
// (64 characters). This is a one-way hash, not encryption.
func hash256(d string) string {
	digest := sha256.Sum256([]byte(d))
	return strings.ToUpper(fmt.Sprintf("%x", digest[:]))
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// StrToInt parses valstr as a base-10 int. It returns 0 on any parse failure,
// including out-of-range values.
func StrToInt(valstr string) int {
	if n, err := strconv.Atoi(valstr); err == nil {
		return n
	}
	return 0
}