diff --git a/hosts/hosts.go b/hosts/hosts.go index d3960bb..b056de2 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -2,6 +2,7 @@ package hosts import ( "bufio" + "bytes" "fmt" "io" "net" @@ -11,6 +12,15 @@ import ( "sync" ) +var lineBreaker string + +func init() { + if runtime.GOOS == "windows" { + lineBreaker = "\r\n" + } else { + lineBreaker = "\n" + } +} func SystemHostpath() string { if runtime.GOOS == "windows" { return `C:\Windows\System32\drivers\etc\hosts` @@ -432,6 +442,47 @@ func (h *Host) addHosts(comment string, ip string, hosts ...string) error { return nil } +func (h *Host) InsertNodes(node *HostNode, before bool, comment string, ip string, hosts ...string) error { + h.Lock() + defer h.Unlock() + if h.hostData == nil { + return fmt.Errorf("hosts data not initialized") + } + if _, ok := h.fulldata[node.uid]; !ok { + return fmt.Errorf("node not exists") + } + hostNode := HostNode{ + uid: h.idx + 1, + IP: ip, + Host: hosts, + Valid: true, + Comment: comment, + } + if ip == "" && len(hosts) == 0 && comment != "" { + hostNode.OnlyComment = true + } + if before { + hostNode.nextuid = node.uid + hostNode.lastuid = node.lastuid + if node.lastuid != 0 { + h.fulldata[node.lastuid].nextuid = h.idx + } else { + h.nextUid = h.idx + } + node.lastuid = h.idx + } else { + hostNode.lastuid = node.uid + hostNode.nextuid = node.nextuid + if node.nextuid != 0 { + h.fulldata[node.nextuid].lastuid = h.idx + } else { + h.lastUid = h.idx + } + node.nextuid = h.idx + } + return nil +} + func inArray(arr []string, v string) bool { for _, vv := range arr { if v == vv { @@ -440,3 +491,77 @@ func inArray(arr []string, v string) bool { } return false } + +func (h *Host) SaveAs(path string) error { + h.Lock() + defer h.Unlock() + return h.save(path) +} + +func (h *Host) Save() error { + h.Lock() + defer h.Unlock() + return h.save(h.hostPath) +} + +func (h *Host) save(path string) error { + h.Lock() + defer h.Unlock() + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return fmt.Errorf("open hosts file %s error: %s", h.hostPath, err) + } + defer f.Close() + for _, v := range h.fulldata { + if v.OnlyComment { + if _, err := f.WriteString(v.Comment + "\n"); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + continue + } + if _, err := f.WriteString(v.IP + " "); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + if _, err := f.WriteString(strings.Join(v.Host, " ")); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + if len(v.Comment) > 0 { + if _, err := f.WriteString(" " + v.Comment); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + } + if _, err := f.WriteString(lineBreaker); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + } + return nil +} + +func (h *Host) Build() ([]byte, error) { + h.Lock() + defer h.Unlock() + var f bytes.Buffer + for _, v := range h.fulldata { + if v.OnlyComment { + if _, err := f.WriteString(v.Comment + "\n"); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + continue + } + if _, err := f.WriteString(v.IP + " "); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + if _, err := f.WriteString(strings.Join(v.Host, " ")); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + if len(v.Comment) > 0 { + if _, err := f.WriteString(" " + v.Comment); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + } + if _, err := f.WriteString(lineBreaker); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + } + return f.Bytes(), nil +}