diff --git a/hosts/hosts.go b/hosts/hosts.go index 8df73ff..d3960bb 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -1,213 +1,442 @@ package hosts import ( - "bytes" - "io/ioutil" + "bufio" + "fmt" + "io" + "net" + "os" "runtime" + "strings" "sync" - - "b612.me/staros/sysconf" ) -var hostsPath string -var hostsCfg *sysconf.SysConf -var haveError error -var lock sync.Mutex -var reverse bool = false - -func init() { +func SystemHostpath() string { if runtime.GOOS == "windows" { - hostsPath = `C:\Windows\System32\drivers\etc\hosts` - } else { - hostsPath = `/etc/hosts` + return `C:\Windows\System32\drivers\etc\hosts` } + return `/etc/hosts` } -func SetHostsPath(path string) { - hostsPath = path +type HostNode struct { + uid uint64 + nextuid uint64 + lastuid uint64 + IP string + Host []string + Comment string + Original string + OnlyComment bool + Valid bool } - -func GetHostsPath() string { - return hostsPath +type Host struct { + idx uint64 + hostPath string + nextUid uint64 + lastUid uint64 + fulldata map[uint64]*HostNode + hostData map[string][]*HostNode + ipData map[string][]*HostNode + sync.RWMutex } -func Parse() error { - hostsCfg = new(sysconf.SysConf) - hostsCfg.EqualFlag = ` ` - hostsCfg.HaveSegMent = false - hostsCfg.CommentCR = true - hostsCfg.CommentFlag = []string{"#"} - data, haveError := ioutil.ReadFile(hostsPath) - if haveError != nil { - return haveError +func NewHosts() *Host { + return &Host{ + hostData: make(map[string][]*HostNode), + ipData: make(map[string][]*HostNode), + fulldata: make(map[uint64]*HostNode), } - data = bytes.ReplaceAll(data, []byte(` `), []byte(" ")) - haveError = hostsCfg.Parse(data) - return haveError } -func GetIpByDomain(domainName string) []string { - if haveError != nil { - return []string{} +func (h *Host) Parse(hostPath string) error { + h.Lock() + defer h.Unlock() + h.hostPath = hostPath + h.fulldata = make(map[uint64]*HostNode) + h.hostData = make(map[string][]*HostNode) + h.ipData = make(map[string][]*HostNode) + return h.parse() +} + +func (h *Host) parse() error { + f, err := os.OpenFile(h.hostPath, os.O_RDONLY, 0666) + if err != nil { + return fmt.Errorf("open hosts file %s error: %s", h.hostPath, err) } - lock.Lock() - defer lock.Unlock() - if !reverse { - hostsCfg.Reverse() - reverse = !reverse + defer f.Close() + buf := bufio.NewReader(f) + for { + line, err := buf.ReadString('\n') + if err == io.EOF { + if h.idx-1 >= 0 { + h.fulldata[h.idx].nextuid = 0 + h.lastUid = h.idx + } + break + } + if err != nil { + return fmt.Errorf("read hosts file error: %s", err) + } + h.idx++ + line = strings.TrimSpace(line) + data, _ := h.parseLine(line) + data.uid = h.idx + data.lastuid = h.idx - 1 + data.nextuid = h.idx + 1 + h.fulldata[data.uid] = &data + if h.nextUid == 0 { + h.nextUid = h.idx + } + if data.Valid { + for _, v := range data.Host { + h.hostData[v] = append(h.hostData[v], &data) + } + h.ipData[data.IP] = append(h.ipData[data.IP], &data) + } } - return hostsCfg.Data[0].GetAll(domainName) + + return nil } -func GetDomainByIp(IpAddr string) []string { - if haveError != nil { - return []string{} +func (h *Host) parseLine(data string) (HostNode, error) { + var res = HostNode{ + Original: data, + } + if len(data) == 0 { + return res, fmt.Errorf("empty line") } - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Reverse() - reverse = !reverse + if data[0] == '#' { + res.Comment = data + res.OnlyComment = true + return res, nil + } + var dataArr []string + cache := "" + for k, v := range data { + if v == '#' { + if len(cache) > 0 { + dataArr = append(dataArr, cache) + cache = "" + } + dataArr = append(dataArr, data[k:]) + break + } + if v == ' ' || v == '\t' { + if len(cache) > 0 { + dataArr = append(dataArr, cache) + cache = "" + } + continue + } + cache += string(v) } - return hostsCfg.Data[0].GetAll(IpAddr) + if len(cache) > 0 { + dataArr = append(dataArr, cache) + } + if len(dataArr) < 2 { + return res, fmt.Errorf("invalid line") + } + if strings.HasPrefix(dataArr[1], "#") { + return res, fmt.Errorf("invalid line") + } + if net.ParseIP(dataArr[0]) == nil { + return res, fmt.Errorf("invalid ip address") + } + for k, v := range dataArr { + switch k { + case 0: + res.IP = v + default: + if strings.HasPrefix(v, "#") { + res.Comment = v + } else { + res.Host = append(res.Host, v) + } + } + } + res.Valid = true + return res, nil } -func AddHosts(ip, host string) { - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Data[0].AddValue(host, ip, "") - } else { - hostsCfg.Data[0].AddValue(ip, host, "") +func (h *Host) List() []*HostNode { + h.RLock() + defer h.RUnlock() + var res []*HostNode + nextUid := h.nextUid + for { + if nextUid == 0 { + break + } + res = append(res, h.fulldata[nextUid]) + nextUid = h.fulldata[nextUid].nextuid } + return res } -func RemoveHost(ip, host string) { - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Data[0].DeleteValue(host, ip) - } else { - hostsCfg.Data[0].DeleteValue(ip, host) +func (h *Host) ListByHost(host string) []*HostNode { + h.RLock() + defer h.RUnlock() + if h.hostData == nil { + return nil } + return h.hostData[host] } -func RemoveHostbyIp(ip string) { - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) ListIPsByHost(host string) []string { + h.RLock() + defer h.RUnlock() + if h.hostData == nil { + return nil + } + var res []string + for _, v := range h.hostData[host] { + res = append(res, v.IP) } - hostsCfg.Data[0].Delete(ip) + return res } -func RemoveHostbyHost(host string) { - lock.Lock() - defer lock.Unlock() - if !reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) ListFirstIPByHost(host string) string { + h.RLock() + defer h.RUnlock() + if h.hostData == nil { + return "" + } + for _, v := range h.hostData[host] { + return v.IP } - hostsCfg.Data[0].Delete(host) + return "" } -// SetHost 设定唯一的Host对应值,一个host对应一个ip,其余的删除 -func SetHost(ip, host string) { - lock.Lock() - defer lock.Unlock() - if !reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) ListByIP(ip string) []*HostNode { + h.RLock() + defer h.RUnlock() + if h.ipData == nil { + return nil } - hostsCfg.Data[0].Set(host, ip, "") + return h.ipData[ip] +} + +func (h *Host) ListHostsByIP(ip string) []string { + h.RLock() + return h.listHostsByIP(ip) } -func SetHostbyIp(ip, host string) { - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) listHostsByIP(ip string) []string { + if h.ipData == nil { + return nil + } + var res []string + for _, v := range h.ipData[ip] { + res = append(res, v.Host...) } - hostsCfg.Data[0].Set(ip, host, "") + return res } -func Build() string { - if reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) ListFirstHostByIP(ip string) string { + h.RLock() + defer h.RUnlock() + if h.ipData == nil { + return "" } - return string(hostsCfg.Build()) + for _, v := range h.ipData[ip] { + if len(v.Host) > 0 { + return v.Host[0] + } + } + return "" } -func Write() error { - lock.Lock() - defer lock.Unlock() - data := []byte(Build()) - return ioutil.WriteFile(hostsPath, data, 0644) +func (h *Host) AddHosts(ip string, hosts ...string) error { + return h.addHosts("", ip, hosts...) } -func GetHostList() []string { - var res []string - lock.Lock() - defer lock.Unlock() - if !reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) AddHostsComment(comment string, ip string, hosts ...string) error { + return h.addHosts(comment, ip, hosts...) +} + +func (h *Host) RemoveIPHosts(ip string, hosts ...string) error { + h.Lock() + defer h.Unlock() + if h.ipData == nil { + return fmt.Errorf("hosts data not initialized") } - for _, v := range hostsCfg.Data[0].NodeData { - if v != nil { - res = append(res, v.Key) + ipInfo := h.ipData[ip] + if len(ipInfo) == 0 { + return nil + } +cntfor: + for _, v := range ipInfo { + for _, host := range hosts { + if len(v.Host) == 1 && v.Host[0] == host { + delete(h.ipData, ip) + if v.lastuid != 0 { + h.fulldata[v.lastuid].nextuid = v.nextuid + } else { + h.nextUid = v.nextuid + } + if v.nextuid != 0 { + h.fulldata[v.nextuid].lastuid = v.lastuid + } else { + h.lastUid = v.lastuid + } + var newHostData []*HostNode + for _, vv := range h.hostData[v.Host[0]] { + if vv.uid != v.uid { + newHostData = append(newHostData, vv) + } + } + h.hostData[host] = newHostData + delete(h.fulldata, v.uid) + v = nil + continue cntfor + } + if len(v.Host) > 1 { + var newHosts []string + for _, vv := range v.Host { + if vv != host { + newHosts = append(newHosts, vv) + } + } + v.Host = newHosts + var newHostData []*HostNode + for _, vv := range h.hostData[host] { + if vv.uid != v.uid { + newHostData = append(newHostData, vv) + } + } + h.hostData[host] = newHostData + } } } - return res + return nil } -func GetIpList() []string { - var res []string - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) RemoveIPs(ips ...string) error { + h.Lock() + defer h.Unlock() + if h.ipData == nil { + return fmt.Errorf("hosts data not initialized") } - for _, v := range hostsCfg.Data[0].NodeData { - if v != nil { - res = append(res, v.Key) + for _, ip := range ips { + ipInfo := h.ipData[ip] + if len(ipInfo) == 0 { + continue + } + for _, v := range ipInfo { + delete(h.ipData, ip) + delete(h.fulldata, v.uid) + if v.lastuid != 0 { + h.fulldata[v.lastuid].nextuid = v.nextuid + } else { + h.nextUid = v.nextuid + } + if v.nextuid != 0 { + h.fulldata[v.nextuid].lastuid = v.lastuid + } else { + h.lastUid = v.lastuid + } + for _, host := range v.Host { + var newHostData []*HostNode + for _, vv := range h.hostData[host] { + if vv.uid != v.uid { + newHostData = append(newHostData, vv) + } + } + h.hostData[host] = newHostData + } } } - return res + return nil } -func GetAllListbyHost() map[string][]string { - res := make(map[string][]string) - lock.Lock() - defer lock.Unlock() - if !reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) RemoveHosts(hosts ...string) error { + h.Lock() + defer h.Unlock() + if h.hostData == nil { + return fmt.Errorf("hosts data not initialized") } - for _, v := range hostsCfg.Data[0].NodeData { - if v != nil { - res[v.Key] = v.Value + for _, host := range hosts { + hostInfo := h.hostData[host] + if len(hostInfo) == 0 { + continue + } + delete(h.hostData, host) + for _, v := range hostInfo { + var newHosts []string + for _, vv := range v.Host { + if vv != host { + newHosts = append(newHosts, vv) + } + } + v.Host = newHosts + if len(v.Host) == 0 { + delete(h.ipData, v.IP) + if v.lastuid != 0 { + h.fulldata[v.lastuid].nextuid = v.nextuid + } else { + h.nextUid = v.nextuid + } + if v.nextuid != 0 { + h.fulldata[v.nextuid].lastuid = v.lastuid + } else { + h.lastUid = v.lastuid + } + delete(h.fulldata, v.uid) + } } } - return res + return nil } -func GetAllListbyIp() map[string][]string { - res := make(map[string][]string) - lock.Lock() - defer lock.Unlock() - if reverse { - hostsCfg.Reverse() - reverse = !reverse +func (h *Host) SetIPHosts(ip string, hosts ...string) error { + err := h.RemoveIPs(ip) + if err != nil { + return err } - for _, v := range hostsCfg.Data[0].NodeData { - if v != nil { - res[v.Key] = v.Value + return h.AddHosts(ip, hosts...) +} + +func (h *Host) addHosts(comment string, ip string, hosts ...string) error { + h.Lock() + defer h.Unlock() + if h.hostData == nil { + return fmt.Errorf("hosts data not initialized") + } + ipInfo := h.listHostsByIP(ip) + var needAddHosts []string + for _, v := range hosts { + if !inArray(ipInfo, v) { + needAddHosts = append(needAddHosts, v) } } - return res + if len(needAddHosts) == 0 { + return nil + } + hostNode := HostNode{ + uid: h.idx + 1, + nextuid: 0, + lastuid: h.lastUid, + IP: ip, + Host: needAddHosts, + Valid: true, + Comment: comment, + } + h.idx++ + h.fulldata[h.lastUid].nextuid = h.idx + h.lastUid = h.idx + h.fulldata[h.idx] = &hostNode + h.ipData[ip] = append(h.ipData[ip], &hostNode) + for _, v := range needAddHosts { + h.hostData[v] = append(h.hostData[v], &hostNode) + } + return nil +} + +func inArray(arr []string, v string) bool { + for _, vv := range arr { + if v == vv { + return true + } + } + return false } diff --git a/hosts/hosts_test.go b/hosts/hosts_test.go index 46a3eab..80f1e2b 100644 --- a/hosts/hosts_test.go +++ b/hosts/hosts_test.go @@ -6,7 +6,32 @@ import ( ) func Test_Hosts(t *testing.T) { - //RemoveHostbyIp("192.168.222.33") - Parse() - fmt.Println(GetAllListbyIp()) + var h = NewHosts() + err := h.Parse("./test_hosts.txt") + if err != nil { + t.Error(err) + } + for _, v := range h.List() { + fmt.Printf("%+v\n", v) + } + fmt.Println(h.nextUid, h.lastUid) + fmt.Println("") + err = h.AddHosts("122.23.12.123", "b612.me", "ok.b612.me") + if err != nil { + t.Error(err) + } + for _, v := range h.List() { + fmt.Printf("%+v\n", v) + } + fmt.Println(h.nextUid, h.lastUid) + fmt.Println("") + err = h.RemoveIPHosts("11.22.33.44", "remove.b612.me", "test.dns.set.b612.me") + if err != nil { + t.Error(err) + } + for _, v := range h.List() { + fmt.Printf("%+v\n", v) + } + fmt.Println(h.nextUid, h.lastUid) + fmt.Println("") } diff --git a/hosts/test_hosts.txt b/hosts/test_hosts.txt new file mode 100644 index 0000000..718a539 --- /dev/null +++ b/hosts/test_hosts.txt @@ -0,0 +1,22 @@ +#hosts This file describes a number of hostname-to-address +#mappings for the TCP/IP subsystem. It is mostly +#used at boot time, when no name servers are running. +#On small systems, this file can be used instead of a +#"named" name server. +#Syntax: +#IP-Address Full-Qualifie +#special IPv6 addre +127.0.0.1 localhost +127.0.0.1 b612 +#special IPv6 addresses +::1 localhost ipv6-localhost ipv6-loopback +fe00::0 ipv6-localnet +ff00::0 ipv6-mcastprefix +ff02::1 ipv6-allnodes +ff02::2 ipv6-allrouters +ff02::3 ipv6-allhosts +1.2.3.4 ssh.b612.me +4.5.6.7 dns.b612.me +8.9.10.11 release-ftpd +11.22.33.44 test.dns.set.b612.me remove.b612.me +4.5.6.7 game.b612.me