// Package hosts reads, edits and writes hosts files (e.g. /etc/hosts) while
// preserving line order, comments, blank lines and lines it cannot parse.
package hosts

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"strings"
	"sync"
)

// lineBreaker terminates every line written by Build and Save: CRLF on
// Windows, LF elsewhere.
var lineBreaker = platformLineBreaker()

func platformLineBreaker() string {
	if runtime.GOOS == "windows" {
		return "\r\n"
	}
	return "\n"
}

// SystemHostpath returns the conventional hosts file location for the
// current operating system.
func SystemHostpath() string {
	if runtime.GOOS == "windows" {
		return `C:\Windows\System32\drivers\etc\hosts`
	}
	return `/etc/hosts`
}

// HostNode is one line of a hosts file. Nodes are chained through
// uid/lastuid/nextuid into a doubly linked list that preserves line order;
// uid 0 means "no node".
type HostNode struct {
	uid     uint64
	nextuid uint64
	lastuid uint64

	// IP and Host hold the address and names of an entry line (Valid).
	IP   string
	Host []string
	// Comment is the trailing "#..." text of an entry, or the whole line
	// when OnlyComment is set.
	Comment string
	// Original is the whitespace-trimmed text as read from the file; it is
	// written back verbatim for lines that are neither entries nor comments.
	Original    string
	OnlyComment bool
	Valid       bool
}

// Host holds a parsed hosts file. All exported methods take the embedded
// lock, so a Host may be shared between goroutines.
type Host struct {
	idx      uint64 // last uid handed out; uids start at 1
	hostPath string
	nextUid  uint64 // uid of the first line, 0 when empty
	lastUid  uint64 // uid of the last line, 0 when empty
	fulldata map[uint64]*HostNode
	hostData map[string][]*HostNode
	ipData   map[string][]*HostNode
	sync.RWMutex
}

// NewHosts returns an empty Host, ready for Parse or for adding entries.
func NewHosts() *Host {
	return &Host{
		hostData: make(map[string][]*HostNode),
		ipData:   make(map[string][]*HostNode),
		fulldata: make(map[uint64]*HostNode),
	}
}

// Parse discards any previously loaded data and loads the hosts file at
// hostPath. The path is remembered for Save.
func (h *Host) Parse(hostPath string) error {
	h.Lock()
	defer h.Unlock()
	h.hostPath = hostPath
	h.reset()
	return h.parse()
}

// reset empties the line list and both indexes. Caller holds the write lock.
func (h *Host) reset() {
	h.idx, h.nextUid, h.lastUid = 0, 0, 0
	h.fulldata = make(map[uint64]*HostNode)
	h.hostData = make(map[string][]*HostNode)
	h.ipData = make(map[string][]*HostNode)
}

// parse reads h.hostPath and appends one node per line, in file order.
// Caller holds the write lock and has reset the state.
func (h *Host) parse() error {
	f, err := os.Open(h.hostPath)
	if err != nil {
		return fmt.Errorf("open hosts file %s error: %w", h.hostPath, err)
	}
	defer f.Close()

	reader := bufio.NewReader(f)
	for {
		line, err := reader.ReadString('\n')
		if err != nil && err != io.EOF {
			return fmt.Errorf("read hosts file error: %w", err)
		}
		// At EOF, line holds a final line without a trailing newline (or is
		// empty when the file ended with one); keep it in the former case.
		if err == nil || line != "" {
			// The error only means "not a valid entry"; the node is kept
			// anyway (as a comment or verbatim line) so Save reproduces it.
			node, _ := h.parseLine(strings.TrimSpace(line))
			h.appendNode(&node)
		}
		if err == io.EOF {
			return nil
		}
	}
}

// parseLine classifies one trimmed line. Lines starting with '#' become
// comment-only nodes; "IP host... [#comment]" with a parseable IP and at
// least one host becomes a Valid entry. For anything else an error is
// returned, but the node still carries Original so the line can be written
// back unchanged.
func (h *Host) parseLine(data string) (HostNode, error) {
	res := HostNode{Original: data}
	if data == "" {
		return res, fmt.Errorf("empty line")
	}
	if data[0] == '#' {
		res.Comment = data
		res.OnlyComment = true
		return res, nil
	}

	body, comment := data, ""
	if i := strings.IndexByte(data, '#'); i >= 0 {
		body, comment = data[:i], data[i:]
	}
	fields := strings.FieldsFunc(body, func(r rune) bool { return r == ' ' || r == '\t' })
	if len(fields) < 2 {
		return res, fmt.Errorf("invalid line")
	}
	if net.ParseIP(fields[0]) == nil {
		return res, fmt.Errorf("invalid ip address")
	}
	res.IP = fields[0]
	res.Host = fields[1:]
	res.Comment = comment
	res.Valid = true
	return res, nil
}

// List returns every line of the file, in order.
func (h *Host) List() []*HostNode {
	h.RLock()
	defer h.RUnlock()
	var res []*HostNode
	h.walk(func(n *HostNode) { res = append(res, n) })
	return res
}

// ListByHost returns the entry lines that map host, in file order.
func (h *Host) ListByHost(host string) []*HostNode {
	h.RLock()
	defer h.RUnlock()
	return cloneNodes(h.hostData[host])
}

// ListIPsByHost returns the IP of every entry line that maps host.
func (h *Host) ListIPsByHost(host string) []string {
	h.RLock()
	defer h.RUnlock()
	var res []string
	for _, n := range h.hostData[host] {
		res = append(res, n.IP)
	}
	return res
}

// ListFirstIPByHost returns the IP of the first entry mapping host, or "".
func (h *Host) ListFirstIPByHost(host string) string {
	h.RLock()
	defer h.RUnlock()
	if nodes := h.hostData[host]; len(nodes) > 0 {
		return nodes[0].IP
	}
	return ""
}

// ListByIP returns the entry lines for ip, in file order.
func (h *Host) ListByIP(ip string) []*HostNode {
	h.RLock()
	defer h.RUnlock()
	return cloneNodes(h.ipData[ip])
}

// ListHostsByIP returns every host name mapped to ip, across all its lines.
func (h *Host) ListHostsByIP(ip string) []string {
	h.RLock()
	defer h.RUnlock()
	return h.listHostsByIP(ip)
}

// listHostsByIP is ListHostsByIP without locking; caller holds a lock.
func (h *Host) listHostsByIP(ip string) []string {
	var res []string
	for _, n := range h.ipData[ip] {
		res = append(res, n.Host...)
	}
	return res
}

// ListFirstHostByIP returns the first host name mapped to ip, or "".
func (h *Host) ListFirstHostByIP(ip string) string {
	h.RLock()
	defer h.RUnlock()
	for _, n := range h.ipData[ip] {
		if len(n.Host) > 0 {
			return n.Host[0]
		}
	}
	return ""
}

// AddHosts appends a line mapping ip to those hosts not already mapped to it.
func (h *Host) AddHosts(ip string, hosts ...string) error {
	return h.addHosts("", ip, hosts...)
}

// AddHostsComment is AddHosts with a trailing comment on the new line; a
// leading "#" is added to the comment if missing.
func (h *Host) AddHostsComment(comment string, ip string, hosts ...string) error {
	return h.addHosts(comment, ip, hosts...)
}

// RemoveIPHosts removes the given host names from every line of ip. A line
// left with no host names is removed entirely.
func (h *Host) RemoveIPHosts(ip string, hosts ...string) error {
	h.Lock()
	defer h.Unlock()
	if h.ipData == nil {
		return fmt.Errorf("hosts data not initialized")
	}
	// Ranging evaluates the slice once; the unindex helpers assign freshly
	// allocated slices, so removals below do not disturb this iteration.
	for _, n := range h.ipData[ip] {
		remaining := withoutHosts(n.Host, hosts)
		if len(remaining) == len(n.Host) {
			continue // none of the requested hosts are on this line
		}
		if len(remaining) == 0 {
			h.removeNode(n)
			continue
		}
		for _, host := range n.Host {
			if inArray(hosts, host) {
				h.unindexHost(host, n)
			}
		}
		n.Host = remaining
	}
	return nil
}

// RemoveIPs removes every line whose address is one of ips.
func (h *Host) RemoveIPs(ips ...string) error {
	h.Lock()
	defer h.Unlock()
	if h.ipData == nil {
		return fmt.Errorf("hosts data not initialized")
	}
	for _, ip := range ips {
		for _, n := range h.ipData[ip] {
			h.removeNode(n)
		}
	}
	return nil
}

// RemoveHosts removes the given host names from every line. Lines left with
// no host names are removed entirely.
func (h *Host) RemoveHosts(hosts ...string) error {
	h.Lock()
	defer h.Unlock()
	if h.hostData == nil {
		return fmt.Errorf("hosts data not initialized")
	}
	for _, host := range hosts {
		nodes := h.hostData[host]
		delete(h.hostData, host)
		for _, n := range nodes {
			n.Host = withoutHosts(n.Host, []string{host})
			if len(n.Host) == 0 {
				h.removeNode(n)
			}
		}
	}
	return nil
}

// SetIPHosts replaces all lines for ip with a single line mapping it to
// hosts. The two steps take the lock separately, so the update is not atomic
// with respect to other goroutines.
func (h *Host) SetIPHosts(ip string, hosts ...string) error {
	if err := h.RemoveIPs(ip); err != nil {
		return err
	}
	return h.AddHosts(ip, hosts...)
}

// addHosts appends a line for ip holding the hosts (deduplicated, non-empty)
// not already mapped to ip. Nothing is added when no such host remains.
func (h *Host) addHosts(comment string, ip string, hosts ...string) error {
	h.Lock()
	defer h.Unlock()
	if h.hostData == nil {
		return fmt.Errorf("hosts data not initialized")
	}
	existing := h.listHostsByIP(ip)
	var needAddHosts []string
	for _, host := range hosts {
		if host == "" || inArray(existing, host) || inArray(needAddHosts, host) {
			continue
		}
		needAddHosts = append(needAddHosts, host)
	}
	if len(needAddHosts) == 0 {
		return nil
	}
	// Same rule parseLine applies: anything else would not read back as an entry.
	if net.ParseIP(ip) == nil {
		return fmt.Errorf("invalid ip address")
	}
	h.appendNode(&HostNode{
		IP:      ip,
		Host:    needAddHosts,
		Comment: normalizeComment(comment),
		Valid:   true,
	})
	return nil
}

// InsertNodes inserts one new line directly before or after node, which
// must be a line of h (e.g. obtained from List).
//
// With ip == "" and no hosts the line is a comment (or a blank line when
// comment is empty too); otherwise ip must parse and at least one host is
// required.
func (h *Host) InsertNodes(node *HostNode, before bool, comment string, ip string, hosts ...string) error {
	h.Lock()
	defer h.Unlock()
	if h.hostData == nil {
		return fmt.Errorf("hosts data not initialized")
	}
	if node == nil {
		return fmt.Errorf("node not exists")
	}
	// Link through the stored node, not the caller's pointer, in case the
	// caller holds a copy.
	anchor, ok := h.fulldata[node.uid]
	if !ok {
		return fmt.Errorf("node not exists")
	}

	n := &HostNode{Comment: normalizeComment(comment)}
	if ip == "" && len(hosts) == 0 {
		n.OnlyComment = n.Comment != ""
	} else {
		if net.ParseIP(ip) == nil {
			return fmt.Errorf("invalid ip address")
		}
		if len(hosts) == 0 {
			return fmt.Errorf("invalid line")
		}
		n.IP = ip
		n.Host = append([]string(nil), hosts...) // detach from caller's slice
		n.Valid = true
	}

	h.idx++
	n.uid = h.idx
	if before {
		n.lastuid = anchor.lastuid
		n.nextuid = anchor.uid
		if anchor.lastuid != 0 {
			h.fulldata[anchor.lastuid].nextuid = n.uid
		} else {
			h.nextUid = n.uid
		}
		anchor.lastuid = n.uid
	} else {
		n.lastuid = anchor.uid
		n.nextuid = anchor.nextuid
		if anchor.nextuid != 0 {
			h.fulldata[anchor.nextuid].lastuid = n.uid
		} else {
			h.lastUid = n.uid
		}
		anchor.nextuid = n.uid
	}
	h.fulldata[n.uid] = n
	h.indexNode(n)
	return nil
}

// appendNode gives n a fresh uid, links it after the current last line and
// indexes it. Caller holds the write lock.
func (h *Host) appendNode(n *HostNode) {
	h.idx++
	n.uid = h.idx
	n.lastuid = h.lastUid
	n.nextuid = 0
	if h.lastUid != 0 {
		h.fulldata[h.lastUid].nextuid = n.uid
	} else {
		h.nextUid = n.uid // list was empty
	}
	h.lastUid = n.uid
	h.fulldata[n.uid] = n
	h.indexNode(n)
}

// indexNode registers a Valid node in the IP and host lookup tables.
func (h *Host) indexNode(n *HostNode) {
	if !n.Valid {
		return
	}
	h.ipData[n.IP] = append(h.ipData[n.IP], n)
	for _, host := range n.Host {
		// A host repeated on one line is indexed only once.
		if nodes := h.hostData[host]; len(nodes) > 0 && nodes[len(nodes)-1] == n {
			continue
		}
		h.hostData[host] = append(h.hostData[host], n)
	}
}

// unlinkNode detaches n from the line list and from fulldata. It is a
// no-op if n is not (or no longer) part of h.
func (h *Host) unlinkNode(n *HostNode) {
	if h.fulldata[n.uid] != n {
		return
	}
	if n.lastuid != 0 {
		h.fulldata[n.lastuid].nextuid = n.nextuid
	} else {
		h.nextUid = n.nextuid
	}
	if n.nextuid != 0 {
		h.fulldata[n.nextuid].lastuid = n.lastuid
	} else {
		h.lastUid = n.lastuid
	}
	delete(h.fulldata, n.uid)
}

// removeNode deletes n from the line list and from both indexes.
func (h *Host) removeNode(n *HostNode) {
	h.unlinkNode(n)
	for _, host := range n.Host {
		h.unindexHost(host, n)
	}
	if n.Valid {
		h.unindexIP(n)
	}
}

// unindexIP drops n from the IP index, deleting the key once empty.
func (h *Host) unindexIP(n *HostNode) {
	if rest := withoutNode(h.ipData[n.IP], n); len(rest) > 0 {
		h.ipData[n.IP] = rest
	} else {
		delete(h.ipData, n.IP)
	}
}

// unindexHost drops n from the index of host, deleting the key once empty.
func (h *Host) unindexHost(host string, n *HostNode) {
	if rest := withoutNode(h.hostData[host], n); len(rest) > 0 {
		h.hostData[host] = rest
	} else {
		delete(h.hostData, host)
	}
}

// walk calls fn for every line in file order. Caller holds a lock.
func (h *Host) walk(fn func(*HostNode)) {
	for uid := h.nextUid; uid != 0; {
		n, ok := h.fulldata[uid]
		if !ok {
			return // broken link: stop rather than dereference nil
		}
		fn(n)
		uid = n.nextuid
	}
}

// withoutNode returns a new slice holding nodes minus n.
func withoutNode(nodes []*HostNode, n *HostNode) []*HostNode {
	var rest []*HostNode
	for _, v := range nodes {
		if v != n {
			rest = append(rest, v)
		}
	}
	return rest
}

// withoutHosts returns a new slice holding hosts minus every name in drop.
func withoutHosts(hosts, drop []string) []string {
	var rest []string
	for _, host := range hosts {
		if !inArray(drop, host) {
			rest = append(rest, host)
		}
	}
	return rest
}

// cloneNodes copies an index slice so callers cannot alias internal state.
func cloneNodes(nodes []*HostNode) []*HostNode {
	if len(nodes) == 0 {
		return nil
	}
	return append([]*HostNode(nil), nodes...)
}

// normalizeComment trims comment and prefixes "# " when it lacks a leading
// '#', so it is written as a comment rather than as extra host names.
func normalizeComment(comment string) string {
	c := strings.TrimSpace(comment)
	if c == "" || c[0] == '#' {
		return c
	}
	return "# " + c
}

func inArray(arr []string, v string) bool {
	for _, vv := range arr {
		if v == vv {
			return true
		}
	}
	return false
}

// SaveAs writes the current contents to path (created or truncated).
func (h *Host) SaveAs(path string) error {
	h.Lock()
	defer h.Unlock()
	return h.save(path)
}

// Save writes the current contents back to the path given to Parse.
func (h *Host) Save() error {
	h.Lock()
	defer h.Unlock()
	return h.save(h.hostPath)
}

// save renders and writes the file. The caller holds the lock; save must
// not lock again because sync.RWMutex is not reentrant.
func (h *Host) save(path string) error {
	data := h.build()
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		return fmt.Errorf("open hosts file %s error: %w", path, err)
	}
	if _, err := f.Write(data); err != nil {
		f.Close()
		return fmt.Errorf("write hosts file error: %w", err)
	}
	// A failed Close can mean the data never reached the disk.
	if err := f.Close(); err != nil {
		return fmt.Errorf("close hosts file %s error: %w", path, err)
	}
	return nil
}

// Build renders the hosts file, in line order, and returns its contents.
// The error is always nil; it is kept for API compatibility.
func (h *Host) Build() ([]byte, error) {
	h.RLock()
	defer h.RUnlock()
	return h.build(), nil
}

// build renders every line in list order, each ended by lineBreaker.
// Caller holds a lock.
func (h *Host) build() []byte {
	var buf bytes.Buffer
	h.walk(func(n *HostNode) {
		switch {
		case n.OnlyComment:
			buf.WriteString(n.Comment)
		case n.Valid:
			buf.WriteString(n.IP)
			buf.WriteByte(' ')
			buf.WriteString(strings.Join(n.Host, " "))
			if n.Comment != "" {
				buf.WriteByte(' ')
				buf.WriteString(n.Comment)
			}
		default:
			// Blank or unparseable line: reproduce it as read.
			buf.WriteString(n.Original)
		}
		buf.WriteString(lineBreaker)
	})
	return buf.Bytes()
}