core/utils/net.go

391 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package utils 提供通用工具函数
// 包括数据类型转换、时间处理、网络工具等
//
// 网络工具模块
//
// 本模块提供了完整的网络相关工具函数,包括:
//
// IP地址相关
// - IsPublicIP: 判断IP是否为公网IP识别私有网段
// - GetLocationIP: 获取本机第一个有效IPv4地址
// - LocalIPv4s: 获取本机所有IPv4地址列表
// - GetOutBoundIP: 获取外网IP地址
//
// HTTP请求相关带超时保护
// - HttpGet: 发送HTTP GET请求
// - HttpPost: 发送HTTP POST请求
// - HttpPostJSON: 发送HTTP POST JSON请求
// - HttpRequest: 执行自定义HTTP请求
// - DownloadFile: 下载文件(支持进度回调)
//
// 性能特点:
// - 所有HTTP请求都有超时保护默认30秒
// - 支持自定义超时时间
// - 使用Context进行超时控制
// - 完善的错误处理
// - 并发安全
//
// 使用示例:
//
// // 获取本机IP
// localIP := utils.GetLocationIP()
//
// // HTTP GET请求默认30秒超时
// body, _ := utils.HttpGet("https://api.example.com/data")
//
// // 自定义超时
// body, _ := utils.HttpGet("https://api.example.com/data", 5*time.Second)
//
// // POST JSON数据
// data := map[string]any{"key": "value"}
// body, _ := utils.HttpPostJSON(url, headers, data)
//
// // 下载文件(带进度)
// progress := func(total, downloaded int64) {
// fmt.Printf("进度: %.2f%%\n", float64(downloaded)/float64(total)*100)
// }
// err := utils.DownloadFile(url, "file.zip", progress)
package utils
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
)
const (
// DefaultHTTPTimeout HTTP请求默认超时时间
DefaultHTTPTimeout = 30 * time.Second
// DefaultDownloadTimeout 文件下载默认超时时间
DefaultDownloadTimeout = 5 * time.Minute
// DefaultBufferSize 默认缓冲区大小
DefaultBufferSize = 32 * 1024
)
// IsPublicIP 判断是否为公网IP
// ipString: IP地址字符串
// 返回: 是否为公网IP
func IsPublicIP(ipString string) bool {
ip := net.ParseIP(ipString)
if ip == nil {
return false
}
if ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() {
return false
}
if ip4 := ip.To4(); ip4 != nil {
// 检查私有IP地址段
switch {
case ip4[0] == 10: // 10.0.0.0/8
return false
case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: // 172.16.0.0/12
return false
case ip4[0] == 192 && ip4[1] == 168: // 192.168.0.0/16
return false
default:
return true
}
}
return false
}
// GetLocationIP 获取本机IP地址
// 返回: 本机IP地址如果找不到则返回 "127.0.0.1"
func GetLocationIP() string {
localIP := "127.0.0.1"
// 获取所有网络接口
interfaces, err := net.Interfaces()
if err != nil {
return localIP
}
for _, iface := range interfaces {
// 跳过回环接口
if iface.Flags&net.FlagLoopback != 0 {
continue
}
// 获取接口关联的地址
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
// 检查地址是否为 IPNet 类型
ipnet, ok := addr.(*net.IPNet)
if !ok || ipnet.IP.IsLoopback() {
continue
}
// 获取 IPv4 地址
ip := ipnet.IP.To4()
if ip == nil {
continue
}
ipStr := ip.String()
// 跳过链路本地地址 169.254.x.x 和虚拟网络地址 26.26.x.x
if strings.HasPrefix(ipStr, "169.254") || strings.HasPrefix(ipStr, "26.26") {
continue
}
// 返回找到的第一个有效 IP 地址
return ipStr
}
}
return localIP
}
// LocalIPv4s 获取本机所有IPv4地址
// 返回: IPv4地址列表和错误信息
func LocalIPv4s() ([]string, error) {
var ips []string
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok || ipnet.IP.IsLoopback() {
continue
}
if ipv4 := ipnet.IP.To4(); ipv4 != nil {
ipStr := ipv4.String()
// 跳过链路本地地址
if !strings.HasPrefix(ipStr, "169.254") {
ips = append(ips, ipStr)
}
}
}
return ips, nil
}
// GetOutBoundIP 获取外网IP地址
// 返回: 外网IP地址字符串和错误信息
func GetOutBoundIP() (string, error) {
body, err := HttpGet("http://ip.dhcp.cn/?ip")
if err != nil {
return "", err
}
return string(body), nil
}
// getTimeoutDuration 获取超时时间,如果未指定则使用默认值
func getTimeoutDuration(timeout []time.Duration, defaultTimeout time.Duration) time.Duration {
if len(timeout) > 0 {
return timeout[0]
}
return defaultTimeout
}
// createHTTPClient 创建带超时的HTTP客户端
func createHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
}
}
// HttpGet 发送HTTP GET请求
// url: 请求地址
// timeout: 超时时间(可选默认30秒),可以传入多个,只使用第一个
// 返回: 响应体和错误信息
func HttpGet(url string, timeout ...time.Duration) ([]byte, error) {
timeoutDuration := getTimeoutDuration(timeout, DefaultHTTPTimeout)
ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
client := createHTTPClient(timeoutDuration)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
// HttpPostJSON 发送HTTP POST JSON请求
// url: 请求地址
// header: 请求头
// data: 请求数据(将被序列化为JSON)
// 返回: 响应体和错误信息
func HttpPostJSON(url string, header map[string]string, data map[string]any) ([]byte, error) {
jsonBytes, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("marshal json failed: %w", err)
}
return HttpPost(url, header, jsonBytes)
}
// HttpPost 发送HTTP POST请求
// url: 请求地址
// header: 请求头
// data: 请求体数据
// timeout: 超时时间(可选默认30秒),可以传入多个,只使用第一个
// 返回: 响应体和错误信息
func HttpPost(url string, header map[string]string, data []byte, timeout ...time.Duration) ([]byte, error) {
timeoutDuration := getTimeoutDuration(timeout, DefaultHTTPTimeout)
ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
defer cancel()
request, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
// 设置默认请求头
request.Header.Set("Content-Type", "application/json;charset=UTF-8")
request.Header.Set("Request-Id", ULID())
// 设置自定义请求头
for key, val := range header {
request.Header.Set(key, val)
}
client := createHTTPClient(timeoutDuration)
resp, err := client.Do(request)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("http status %d: %s", resp.StatusCode, string(respBytes))
}
return respBytes, nil
}
// HttpRequest 执行HTTP请求
// r: HTTP请求对象
// timeout: 超时时间(可选默认30秒),可以传入多个,只使用第一个
// 返回: 响应体和错误信息
func HttpRequest(r *http.Request, timeout ...time.Duration) ([]byte, error) {
timeoutDuration := getTimeoutDuration(timeout, DefaultHTTPTimeout)
// 如果请求还没有设置context添加一个带超时的context
if r.Context() == context.Background() || r.Context() == nil {
ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
defer cancel()
r = r.WithContext(ctx)
}
client := createHTTPClient(timeoutDuration)
resp, err := client.Do(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
// DownloadFile 下载文件
// url: 下载地址
// saveTo: 保存路径
// fb: 进度回调函数
// timeout: 超时时间(可选默认5分钟),可以传入多个,只使用第一个
func DownloadFile(url, saveTo string, fb func(length, downLen int64), timeout ...time.Duration) error {
timeoutDuration := getTimeoutDuration(timeout, DefaultDownloadTimeout)
ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("create request error: %w", err)
}
client := createHTTPClient(timeoutDuration)
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("download %s error: %w", url, err)
}
defer resp.Body.Close()
if resp.Body == nil {
return fmt.Errorf("response body is nil for %s", url)
}
// 读取服务器返回的文件大小
fsize, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil {
// 如果无法获取文件大小,设置为-1表示未知
fsize = -1
}
// 创建文件
file, err := os.Create(saveTo)
if err != nil {
return fmt.Errorf("create file %s error: %w", saveTo, err)
}
defer file.Close()
// 设置文件权限
if err := file.Chmod(0644); err != nil {
return fmt.Errorf("chmod file %s error: %w", saveTo, err)
}
// 使用缓冲区读取并写入文件,同时调用进度回调
buf := make([]byte, DefaultBufferSize)
var written int64
for {
nr, readErr := resp.Body.Read(buf)
if nr > 0 {
nw, writeErr := file.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
// 调用进度回调
if fb != nil {
fb(fsize, written)
}
}
if writeErr != nil {
return fmt.Errorf("write file error: %w", writeErr)
}
if nr != nw {
return fmt.Errorf("write file error: %w", io.ErrShortWrite)
}
}
if readErr != nil {
if readErr == io.EOF {
break
}
return fmt.Errorf("read response error: %w", readErr)
}
}
return nil
}