// +build !windows

package staros

import (
	"errors"
	"io/ioutil"
	"os"
	"strconv"
	"strings"
	"time"
)

// NetUsage reports the cumulative receive and transmit byte counters of every
// network interface listed in /proc/net/dev.
//
// The first two lines of /proc/net/dev are column headers and are skipped.
// Each remaining line has the form "<name>: <counters...>", where the 1st
// counter after the colon is received bytes and the 9th is transmitted bytes.
// An error is returned if the file cannot be read or lists no interfaces.
func NetUsage() ([]NetAdapter, error) {
	data, err := ioutil.ReadFile("/proc/net/dev")
	if err != nil {
		return []NetAdapter{}, err
	}
	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
	if len(lines) < 3 {
		return []NetAdapter{}, errors.New("No Adaptor")
	}
	var res []NetAdapter
	for _, line := range lines[2:] {
		// Split on the colon rather than on whitespace: some kernels print the
		// name and a wide rx counter with no space between them
		// ("eth0:123456789 ..."). Interface names cannot contain ':'.
		colon := strings.IndexByte(line, ':')
		if colon < 0 {
			continue
		}
		name := strings.TrimSpace(line[:colon])
		fields := strings.Fields(line[colon+1:])
		if len(fields) < 9 {
			// Malformed line: skip it instead of panicking on fields[8].
			continue
		}
		// ParseUint with bitSize 64 keeps counters above 2^31 correct on
		// 32-bit platforms, where strconv.Atoi would overflow int. Parse
		// errors are deliberately ignored (value becomes 0), as before:
		// these are kernel-generated decimal counters.
		recvBytes, _ := strconv.ParseUint(fields[0], 10, 64)
		sendBytes, _ := strconv.ParseUint(fields[8], 10, 64)
		res = append(res, NetAdapter{name, recvBytes, sendBytes})
	}
	return res, nil
}

// NetUsageByname returns the first NetUsage entry whose Name equals name
// exactly. It fails with NetUsage's error if the counters cannot be read, or
// with a "Not Found" error if no interface has that name.
func NetUsageByname(name string) (NetAdapter, error) {
	adapters, err := NetUsage()
	if err != nil {
		return NetAdapter{}, err
	}
	for i := range adapters {
		if adapters[i].Name != name {
			continue
		}
		return adapters[i], nil
	}
	return NetAdapter{}, errors.New("Not Found")
}

// NetSpeeds samples NetUsage twice, duration apart, and returns the average
// receive and transmit rate in bytes per second of every interface present
// in the first sample.
//
// Samples are matched by interface name rather than by position, so an
// interface being added, removed or reordered between the two reads cannot
// mix up counters. If an interface from the first sample is missing from the
// second, the "NetWork Adaptor Num Not ok" error is returned. A counter that
// went backwards (e.g. the interface was reset) yields a rate of 0 instead of
// an unsigned-underflow value. duration must be positive.
func NetSpeeds(duration time.Duration) ([]NetSpeed, error) {
	if duration <= 0 {
		// Dividing by a zero duration would yield Inf/NaN rates.
		return []NetSpeed{}, errors.New("duration must be positive")
	}
	before, err := NetUsage()
	if err != nil {
		return []NetSpeed{}, err
	}
	time.Sleep(duration)
	after, err := NetUsage()
	if err != nil {
		return []NetSpeed{}, err
	}
	byName := make(map[string]NetAdapter, len(after))
	for _, a := range after {
		byName[a.Name] = a
	}
	secs := duration.Seconds()
	rate := func(prev, cur uint64) float64 {
		if cur < prev {
			return 0
		}
		return float64(cur-prev) / secs
	}
	var res []NetSpeed
	for _, b := range before {
		a, ok := byName[b.Name]
		if !ok {
			return []NetSpeed{}, errors.New("NetWork Adaptor Num Not ok")
		}
		res = append(res, NetSpeed{b.Name, rate(b.RecvBytes, a.RecvBytes), rate(b.SendBytes, a.SendBytes)})
	}
	return res, nil
}

// NetSpeedsByName measures all interfaces with NetSpeeds over duration and
// returns the first entry whose Name equals name. Errors from NetSpeeds are
// passed through; a "Not Found" error is returned when no entry matches.
func NetSpeedsByName(duration time.Duration, name string) (NetSpeed, error) {
	speeds, err := NetSpeeds(duration)
	if err != nil {
		return NetSpeed{}, err
	}
	idx := -1
	for i := range speeds {
		if speeds[i].Name == name {
			idx = i
			break
		}
	}
	if idx < 0 {
		return NetSpeed{}, errors.New("Not Found")
	}
	return speeds[idx], nil
}

// NetConnections returns the TCP, UDP and UNIX domain socket connections
// listed in /proc/net/{tcp,tcp6,udp,udp6,unix}.
//
// types selects which families are read. It is matched case-insensitively
// by substring: "tcp", "udp" and "unix" select their families (e.g.
// "tcp,unix"). An empty string, or one containing "all", selects every
// family. Each file is read at most once, so combinations such as "all,tcp"
// do not produce duplicate entries. A file that does not exist on this
// system (e.g. tcp6/udp6 when IPv6 is disabled) is skipped. Any other read
// or parse error stops the scan and is returned with the connections
// gathered so far.
//
// When analysePid is true, sockets are matched to owning processes via
// GetInodeMap. If you are not root (uid != 0), you need CAP_SYS_PTRACE plus
// CAP_DAC_OVERRIDE or CAP_DAC_READ_SEARCH for this to work.
func NetConnections(analysePid bool, types string) ([]NetConn, error) {
	var result []NetConn
	var inodeMap map[string]int64

	lower := strings.ToLower(types)
	all := lower == "" || strings.Contains(lower, "all")
	// Each family is appended exactly once, so "all" cannot duplicate files.
	var fileList []string
	if all || strings.Contains(lower, "tcp") {
		fileList = append(fileList, "/proc/net/tcp", "/proc/net/tcp6")
	}
	if all || strings.Contains(lower, "udp") {
		fileList = append(fileList, "/proc/net/udp", "/proc/net/udp6")
	}
	if all || strings.Contains(lower, "unix") {
		fileList = append(fileList, "/proc/net/unix")
	}

	if analysePid {
		var err error
		inodeMap, err = GetInodeMap()
		if err != nil {
			return result, err
		}
	}
	for _, file := range fileList {
		data, err := ioutil.ReadFile(file)
		if err != nil {
			if os.IsNotExist(err) {
				// The protocol is not available here (e.g. IPv6 disabled).
				continue
			}
			return result, err
		}
		tmpRes, err := analyseNetFiles(data, inodeMap, file[strings.LastIndex(file, "/")+1:])
		if err != nil {
			return result, err
		}
		result = append(result, tmpRes...)
	}
	return result, nil
}

func GetInodeMap() (map[string]int64, error) {
	res := make(map[string]int64)
	paths, err := ioutil.ReadDir("/proc")
	if err != nil {
		return nil, err
	}
	for _, v := range paths {
		if v.IsDir() && Exists("/proc/"+v.Name()+"/fd") {
			fds, err := ioutil.ReadDir("/proc/" + v.Name() + "/fd")
			if err != nil && Exists("/proc/"+v.Name()+"/fd") {
				return nil, err
			}
			for _, fd := range fds {
				socket, err := os.Readlink("/proc/" + v.Name() + "/fd/" + fd.Name())
				if err != nil {
					continue
				}
				if !strings.Contains(socket, "socket") {
					continue
				}
				start := strings.Index(socket, "[")
				if start < 0 {
					continue
				}
				pid, err := strconv.ParseInt(v.Name(), 10, 64)
				if err != nil {
					break
				}
				res[socket[start+1:len(socket)-1]] = pid
			}
		}
	}
	return res, err
}

// analyseNetFiles parses the contents of /proc/net/{tcp,tcp6,udp,udp6} into
// NetConn values. typed is the file's base name and is stored in each
// result's Typed field. "unix" data is delegated to analyseUnixFiles.
//
// The first (header) line is skipped. For each remaining line, the
// space-separated fields used are:
//
//	[1] local address, via parseHexIpPort
//	[2] remote address, via parseHexIpPort
//	[3] state (hex; tcp files only)
//	[4] tx_queue:rx_queue (hex)
//	[5] timer code:jiffies (hex)
//	[6] stored as RtoTimer (hex)
//	[7] uid (decimal)
//	[9] inode
//
// When inodeMap is non-empty, the inode is resolved to a PID (-1 if not
// found). Each PID's process is looked up once via FindProcessByPid, and a
// failed lookup is cached as a nil Process. A malformed line aborts parsing
// and returns the connections parsed so far together with the error.
func analyseNetFiles(data []byte, inodeMap map[string]int64, typed string) ([]NetConn, error) {
	if typed == "unix" {
		return analyseUnixFiles(data, inodeMap, typed)
	}
	var result []NetConn
	strdata := strings.TrimSpace(string(data))
	strdata = remainOne(strdata, "  ", " ")
	csvData := strings.Split(strdata, "\n")
	pidMap := make(map[int64]*Process)
	for line, lineData := range csvData {
		if line == 0 {
			continue
		}
		v := strings.Split(strings.TrimSpace(lineData), " ")
		// Fields up to the inode (index 9) are indexed below. Reject short
		// lines instead of panicking with an index out of range.
		if len(v) < 10 {
			return result, errors.New("not a valid net file")
		}
		var res NetConn
		ip, port, err := parseHexIpPort(v[1])
		if err != nil {
			return result, err
		}
		res.LocalAddr = ip
		res.LocalPort = port
		ip, port, err = parseHexIpPort(v[2])
		if err != nil {
			return result, err
		}
		res.RemoteAddr = ip
		res.RemotePort = port
		// Connection state is only meaningful for TCP.
		if strings.Contains(typed, "tcp") {
			state, err := strconv.ParseInt(strings.TrimSpace(v[3]), 16, 64)
			if err != nil {
				return result, err
			}
			res.Status = TCP_STATE[state]
		}
		txrx_queue := strings.Split(strings.TrimSpace(v[4]), ":")
		if len(txrx_queue) != 2 {
			return result, errors.New("not a valid net file")
		}
		tx_queue, err := strconv.ParseInt(txrx_queue[0], 16, 64)
		if err != nil {
			return result, err
		}
		res.TX_Queue = tx_queue
		rx_queue, err := strconv.ParseInt(txrx_queue[1], 16, 64)
		if err != nil {
			return result, err
		}
		res.RX_Queue = rx_queue
		timer := strings.Split(strings.TrimSpace(v[5]), ":")
		if len(timer) != 2 {
			return result, errors.New("not a valid net file")
		}
		switch timer[0] {
		case "00":
			res.TimerActive = "NO_TIMER"
		case "01":
			// Retransmission timer.
			res.TimerActive = "RETRANSMIT"
		case "02":
			// Connection timer, FIN_WAIT_2 timer or TCP keepalive timer.
			res.TimerActive = "KEEPALIVE"
		case "03":
			// TIME_WAIT timer.
			res.TimerActive = "TIME_WAIT"
		case "04":
			// Persist (zero-window probe) timer.
			res.TimerActive = "ZERO_WINDOW_PROBE"
		default:
			res.TimerActive = "UNKNOWN"
		}
		timerJif, err := strconv.ParseInt(timer[1], 16, 64)
		if err != nil {
			return result, err
		}
		res.TimerJiffies = timerJif
		timerCnt, err := strconv.ParseInt(strings.TrimSpace(v[6]), 16, 64)
		if err != nil {
			return result, err
		}
		res.RtoTimer = timerCnt
		res.Uid, err = strconv.ParseInt(v[7], 10, 64)
		if err != nil {
			return result, err
		}
		res.Inode = v[9]
		if len(inodeMap) > 0 {
			pid, ok := inodeMap[res.Inode]
			if !ok {
				res.Pid = -1
			} else {
				res.Pid = pid
				proc, cached := pidMap[pid]
				if !cached {
					if p, err := FindProcessByPid(pid); err == nil {
						proc = &p
					}
					// A nil entry records a failed lookup so it is not retried.
					pidMap[pid] = proc
				}
				res.Process = proc
			}
		}
		res.Typed = typed
		result = append(result, res)
	}
	return result, nil
}

// analyseUnixFiles parses the contents of /proc/net/unix into NetConn
// values, storing typed in each result's Typed field.
//
// The first (header) line is skipped. Field [6] of each line is stored as
// Inode. Any fields after it are joined with single spaces and stored as
// Socket (the path column, which may itself contain spaces).
//
// When inodeMap is non-empty, the inode is resolved to a PID (-1 if not
// found). For a known PID the process is looked up via FindProcessByPid;
// PIDs whose earlier lookup failed are retried. On success the process's
// RUID is stored as Uid. A line with fewer than 7 fields aborts parsing
// with an error, returning the connections parsed so far.
func analyseUnixFiles(data []byte, inodeMap map[string]int64, typed string) ([]NetConn, error) {
	var result []NetConn
	strdata := strings.TrimSpace(string(data))
	strdata = remainOne(strdata, "  ", " ")
	csvData := strings.Split(strdata, "\n")
	pidMap := make(map[int64]*Process)
	for line, lineData := range csvData {
		if line == 0 {
			continue
		}
		v := strings.Split(strings.TrimSpace(lineData), " ")
		// v[6] (inode) is required. Reject short lines instead of panicking.
		if len(v) < 7 {
			return result, errors.New("not a valid net file")
		}
		var res NetConn
		res.Inode = v[6]
		if len(v) > 7 {
			// The path is the last column and may contain spaces, so rejoin
			// every remaining field rather than requiring exactly 8 fields.
			// NOTE(review): if remainOne collapses runs of spaces, such runs
			// inside a path are still not reproduced exactly.
			res.Socket = strings.Join(v[7:], " ")
		}
		if len(inodeMap) > 0 {
			pid, ok := inodeMap[res.Inode]
			if !ok {
				res.Pid = -1
			} else {
				res.Pid = pid
				// A missing key and a cached failure both read as nil; either
				// way, (re)try the lookup.
				proc := pidMap[pid]
				if proc == nil {
					if p, err := FindProcessByPid(pid); err == nil {
						proc = &p
					}
					pidMap[pid] = proc
				}
				if proc != nil {
					res.Uid = int64(proc.RUID)
					res.Process = proc
				}
			}
		}
		res.Typed = typed
		result = append(result, res)
	}
	return result, nil
}