diff --git a/proto_package.go b/proto_package.go index 4ccd794..385f229 100644 --- a/proto_package.go +++ b/proto_package.go @@ -225,7 +225,6 @@ func (p *ProtoPackage) build(ctx context.Context) (time.Duration, error) { go pollMemoryUsage(pgid, 1*time.Second, done, &peakMem) err = cmd.Wait() - close(done) Rusage, ok := cmd.ProcessState.SysUsage().(*syscall.Rusage) diff --git a/utils.go b/utils.go index afffd2f..464ca19 100644 --- a/utils.go +++ b/utils.go @@ -768,6 +768,7 @@ func downloadSRCINFO(pkg, tag string) (*srcinfo.Srcinfo, error) { return nSrcInfo, nil } +// getProcessMemory reads both RSS and Swap memory from /proc//status func getProcessMemory(pid int) (int, int, error) { statusPath := fmt.Sprintf("/proc/%d/status", pid) file, err := os.Open(statusPath) @@ -803,58 +804,55 @@ func getProcessMemory(pid int) (int, int, error) { return rss, swap, nil } -func getProcessTreeMemory(pgid int) (int, int, error) { - procDir, err := os.Open("/proc") - if err != nil { - return 0, 0, err - } - defer func(procDir *os.File) { - _ = procDir.Close() - }(procDir) +func getChildProcesses(pid int) ([]int, error) { + var children []int + taskPath := fmt.Sprintf("/proc/%d/task", pid) - totalRSS, totalSwap := 0, 0 - entries, err := procDir.Readdir(-1) + taskDirs, err := os.ReadDir(taskPath) if err != nil { - return 0, 0, fmt.Errorf("failed to read /proc: %v", err) + return nil, err } - for _, entry := range entries { - if !entry.IsDir() { - continue - } - - pid, err := strconv.Atoi(entry.Name()) - if err != nil { - continue - } - - statPath := fmt.Sprintf("/proc/%d/stat", pid) - file, err := os.Open(statPath) + for _, task := range taskDirs { + childFile := fmt.Sprintf("/proc/%d/task/%s/children", pid, task.Name()) + file, err := os.Open(childFile) if err != nil { continue } scanner := bufio.NewScanner(file) - if scanner.Scan() { - fields := strings.Fields(scanner.Text()) - if len(fields) >= 5 { - processPGID, err := strconv.Atoi(fields[4]) - if err != nil { - return 0, 0, fmt.Errorf("failed to parse process PG ID: %v", err) - } - if processPGID == pgid { - rss, swap, err := getProcessMemory(pid) - if err == nil { - totalRSS += rss - totalSwap += swap - } else { - return 0, 0, fmt.Errorf("failed to get process memory: %v", err) - } + for scanner.Scan() { + childPIDs := strings.Fields(scanner.Text()) + for _, childStr := range childPIDs { + childPID, err := strconv.Atoi(childStr) + if err == nil { + children = append(children, childPID) + subChildren, _ := getChildProcesses(childPID) + children = append(children, subChildren...) } } } _ = file.Close() } + return children, nil +} + +func getProcessTreeMemory(rootPID int) (int, int, error) { + pids := []int{rootPID} + + childPIDs, err := getChildProcesses(rootPID) + if err == nil { + pids = append(pids, childPIDs...) + } + + totalRSS, totalSwap := 0, 0 + for _, pid := range pids { + rss, swap, err := getProcessMemory(pid) + if err == nil { + totalRSS += rss + totalSwap += swap + } + } return totalRSS, totalSwap, nil }