diff --git a/proto_package.go b/proto_package.go index 385f229..a2362de 100644 --- a/proto_package.go +++ b/proto_package.go @@ -220,7 +220,7 @@ func (p *ProtoPackage) build(ctx context.Context) (time.Duration, error) { log.Errorf("error getting PGID: %v", err) } - var peakMem int + var peakMem int64 done := make(chan bool) go pollMemoryUsage(pgid, 1*time.Second, done, &peakMem) diff --git a/utils.go b/utils.go index 464ca19..62f86ac 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "context" "errors" "fmt" @@ -768,111 +767,123 @@ func downloadSRCINFO(pkg, tag string) (*srcinfo.Srcinfo, error) { return nSrcInfo, nil } -// getProcessMemory reads both RSS and Swap memory from /proc//status -func getProcessMemory(pid int) (int, int, error) { - statusPath := fmt.Sprintf("/proc/%d/status", pid) - file, err := os.Open(statusPath) - if err != nil { - return 0, 0, err - } - defer func(file *os.File) { - _ = file.Close() - }(file) +func getDescendantPIDs(rootPID int) ([]int, error) { + pidToPpid := map[int]int{} + var descendants []int - var rss, swap int - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 2 { - continue - } - - switch fields[0] { - case "VmRSS:": - rss, err = strconv.Atoi(fields[1]) - if err != nil { - return 0, 0, fmt.Errorf("failed to parse rss: %v", err) - } - case "VmSwap:": - swap, err = strconv.Atoi(fields[1]) - if err != nil { - return 0, 0, fmt.Errorf("failed to parse swap: %v", err) - } - } - } - return rss, swap, nil -} - -func getChildProcesses(pid int) ([]int, error) { - var children []int - taskPath := fmt.Sprintf("/proc/%d/task", pid) - - taskDirs, err := os.ReadDir(taskPath) + procEntries, err := os.ReadDir("/proc") if err != nil { return nil, err } - for _, task := range taskDirs { - childFile := fmt.Sprintf("/proc/%d/task/%s/children", pid, task.Name()) - file, err := os.Open(childFile) + for _, entry := range procEntries { + if !entry.IsDir() || entry.Name()[0] < '0' || entry.Name()[0] > '9' { + continue + } + pidStr := entry.Name() + pid, err := strconv.Atoi(pidStr) if err != nil { continue } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - childPIDs := strings.Fields(scanner.Text()) - for _, childStr := range childPIDs { - childPID, err := strconv.Atoi(childStr) - if err == nil { - children = append(children, childPID) - subChildren, _ := getChildProcesses(childPID) - children = append(children, subChildren...) + statusPath := filepath.Join("/proc", pidStr, "status") + data, err := os.ReadFile(statusPath) + if err != nil { + continue + } + + for _, line := range strings.Split(string(data), "\n") { + if strings.HasPrefix(line, "PPid:") { + fields := strings.Fields(line) + if len(fields) == 2 { + ppid, _ := strconv.Atoi(fields[1]) + pidToPpid[pid] = ppid } } } - _ = file.Close() } - return children, nil + + var walk func(int) + walk = func(current int) { + for pid, ppid := range pidToPpid { + if ppid == current { + descendants = append(descendants, pid) + walk(pid) + } + } + } + walk(rootPID) + return descendants, nil } -func getProcessTreeMemory(rootPID int) (int, int, error) { - pids := []int{rootPID} +type MemStats struct { + RSS int64 + Swap int64 +} - childPIDs, err := getChildProcesses(rootPID) - if err == nil { - pids = append(pids, childPIDs...) +func getMemoryStats(pid int) (MemStats, error) { + statusPath := fmt.Sprintf("/proc/%d/status", pid) + data, err := os.ReadFile(statusPath) + if err != nil { + return MemStats{}, err } - totalRSS, totalSwap := 0, 0 - for _, pid := range pids { - rss, swap, err := getProcessMemory(pid) - if err == nil { - totalRSS += rss - totalSwap += swap + stats := MemStats{} + for _, line := range strings.Split(string(data), "\n") { + if strings.HasPrefix(line, "VmRSS:") { + fields := strings.Fields(line) + if len(fields) >= 2 { + kb, _ := strconv.ParseInt(fields[1], 10, 64) + stats.RSS = kb * 1024 + } + } + if strings.HasPrefix(line, "VmSwap:") { + fields := strings.Fields(line) + if len(fields) >= 2 { + kb, _ := strconv.ParseInt(fields[1], 10, 64) + stats.Swap = kb * 1024 + } } } - return totalRSS, totalSwap, nil + return stats, nil } -func pollMemoryUsage(pgid int, interval time.Duration, done chan bool, peakMem *int) { +func pollMemoryUsage(pid int, interval time.Duration, done chan bool, peakMem *int64) { for { select { case <-done: return default: - rss, swap, err := getProcessTreeMemory(pgid) - if err == nil { - totalMemory := rss + swap + totalRSS := int64(0) + totalSwap := int64(0) - if totalMemory > *peakMem { - peakMem = &totalMemory - } + rootStats, err := getMemoryStats(pid) + if err == nil { + totalRSS += rootStats.RSS + totalSwap += rootStats.Swap } else { - log.Warningf("failed to get process tree memory: %v", err) + log.Errorf("failed to get memory stats for root process: %v", err) } + + descendants, err := getDescendantPIDs(pid) + if err != nil { + log.Errorf("failed to get descendants: %v", err) + } + + for _, dpid := range descendants { + stats, err := getMemoryStats(dpid) + if err == nil { + totalRSS += stats.RSS + totalSwap += stats.Swap + } + } + + totalMemory := totalRSS + totalSwap + if totalMemory > *peakMem { + peakMem = &totalMemory + } + time.Sleep(interval) } }