diff --git a/deployment/clouddeploy/gke-workers/base/gitter.yaml b/deployment/clouddeploy/gke-workers/base/gitter.yaml index b89b5077649..d4ae1fbc83a 100644 --- a/deployment/clouddeploy/gke-workers/base/gitter.yaml +++ b/deployment/clouddeploy/gke-workers/base/gitter.yaml @@ -27,8 +27,8 @@ spec: imagePullPolicy: Always args: - "--port=8888" - - "--work_dir=/work/gitter" - - "--fetch_timeout=1h" + - "--work-dir=/work/gitter" + - "--fetch-timeout=1h" volumeMounts: - mountPath: /work name: disk-data diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml index 67aaeef676f..3a74ef22331 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml @@ -9,8 +9,8 @@ spec: - name: gitter args: - "--port=8888" - - "--work_dir=/work/gitter" - - "--fetch_timeout=1h" + - "--work-dir=/work/gitter" + - "--fetch-timeout=1h" env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb-test diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml index 840f7eccc09..fc1558ba8b4 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml @@ -9,8 +9,8 @@ spec: - name: gitter args: - "--port=8888" - - "--work_dir=/work/gitter" - - "--fetch_timeout=1h" + - "--work-dir=/work/gitter" + - "--fetch-timeout=1h" env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb diff --git a/go/cmd/gitter/gitter.go b/go/cmd/gitter/gitter.go index 79eba120170..0569fbfda17 100644 --- a/go/cmd/gitter/gitter.go +++ b/go/cmd/gitter/gitter.go @@ -6,7 +6,7 @@ import ( "context" "crypto/sha256" "encoding/hex" - "errors" + "encoding/json" "flag" "fmt" "io" @@ -22,22 +22,37 @@ import ( "syscall" "time" + "runtime/pprof" + 
"github.com/google/osv.dev/go/logger" "golang.org/x/sync/singleflight" ) -const getGitEndpoint = "/getgit" +// API Endpoints +var endpointHandlers = map[string]http.HandlerFunc{ + "GET /git": gitHandler, + "POST /cache": cacheHandler, + "POST /affected-commits": affectedCommitsHandler, +} + const defaultGitterWorkDir = "/work/gitter" -const persistanceFileName = "last-fetch.json" +const persistenceFileName = "last-fetch.json" const gitStoreFileName = "git-store" var ( - g singleflight.Group - persistancePath = path.Join(defaultGitterWorkDir, persistanceFileName) + gFetch singleflight.Group + gArchive singleflight.Group + gLoad singleflight.Group + persistencePath = path.Join(defaultGitterWorkDir, persistenceFileName) gitStorePath = path.Join(defaultGitterWorkDir, gitStoreFileName) fetchTimeout time.Duration ) +type Event struct { + EventType string `json:"eventType"` // TODO: enum this + Hash string `json:"hash"` +} + const shutdownTimeout = 10 * time.Second // runCmd executes a command with context cancellation handled by sending SIGINT. @@ -71,6 +86,23 @@ func runCmd(ctx context.Context, dir string, env []string, name string, args ... return nil } +// prepareCmd prepares the command with context cancellation handled by sending SIGINT. +func prepareCmd(ctx context.Context, dir string, env []string, name string, args ...string) *exec.Cmd { + cmd := exec.CommandContext(ctx, name, args...) + if dir != "" { + cmd.Dir = dir + } + if len(env) > 0 { + cmd.Env = append(os.Environ(), env...) 
+ } + // Use SIGINT instead of SIGKILL for graceful shutdown of subprocesses + cmd.Cancel = func() error { + return cmd.Process.Signal(syscall.SIGINT) + } + + return cmd +} + func isLocalRequest(r *http.Request) bool { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -103,10 +135,12 @@ func isAuthError(err error) bool { (strings.Contains(strings.ToLower(errString), "repository") && strings.Contains(strings.ToLower(errString), "not found")) } -func fetchBlob(ctx context.Context, url string, forceUpdate bool) ([]byte, error) { +func fetchRepo(ctx context.Context, url string, forceUpdate bool) error { + logger.Info("Starting fetch repo", slog.String("url", url)) + start := time.Now() + repoDirName := getRepoDirName(url) repoPath := path.Join(gitStorePath, repoDirName) - archivePath := repoPath + ".zst" lastFetchMu.Lock() accessTime, ok := lastFetch[url] @@ -119,7 +153,7 @@ func fetchBlob(ctx context.Context, url string, forceUpdate bool) ([]byte, error // Clone err := runCmd(ctx, "", []string{"GIT_TERMINAL_PROMPT=0"}, "git", "clone", "--", url, repoPath) if err != nil { - return nil, fmt.Errorf("git clone failed: %w", err) + return fmt.Errorf("git clone failed: %w", err) } } else { // Fetch/Pull - implementing simple git pull for now, might need reset --hard if we want exact mirrors @@ -127,14 +161,45 @@ func fetchBlob(ctx context.Context, url string, forceUpdate bool) ([]byte, error // Ideally safely: git fetch origin && git reset --hard origin/HEAD err := runCmd(ctx, repoPath, nil, "git", "fetch", "origin") if err != nil { - return nil, fmt.Errorf("git fetch failed: %w", err) + return fmt.Errorf("git fetch failed: %w", err) } err = runCmd(ctx, repoPath, nil, "git", "reset", "--hard", "origin/HEAD") if err != nil { - return nil, fmt.Errorf("git reset failed: %w", err) + return fmt.Errorf("git reset failed: %w", err) } } + updateLastFetch(url) + } + + // Double check if the git directory exist + _, err := os.Stat(path.Join(repoPath, ".git")) + if err 
!= nil { + if os.IsNotExist(err) { + deleteLastFetch(url) + } + + return fmt.Errorf("failed to read file: %w", err) + } + + logger.Info("Fetch completed", slog.Duration("duration", time.Since(start))) + + return nil +} + +func archiveRepo(ctx context.Context, url string) ([]byte, error) { + repoDirName := getRepoDirName(url) + repoPath := path.Join(gitStorePath, repoDirName) + archivePath := repoPath + ".zst" + + lastFetchMu.Lock() + accessTime := lastFetch[url] + lastFetchMu.Unlock() + + // Check if archive needs update + // We update if archive does not exist OR if it is older than the last fetch + stats, err := os.Stat(archivePath) + if os.IsNotExist(err) || (err == nil && stats.ModTime().Before(accessTime)) { logger.Info("Archiving git blob", slog.String("url", url)) // Archive // tar --zstd -cf -C "/" . @@ -143,8 +208,6 @@ func fetchBlob(ctx context.Context, url string, forceUpdate bool) ([]byte, error if err != nil { return nil, fmt.Errorf("tar zstd failed: %w", err) } - - updateLastFetch(url) } // If the context is cancelled, still do the fetching stuff, just don't bother returning the result @@ -155,10 +218,6 @@ func fetchBlob(ctx context.Context, url string, forceUpdate bool) ([]byte, error fileData, err := os.ReadFile(archivePath) if err != nil { - if errors.Is(err, os.ErrNotExist) { - deleteLastFetch(url) - } - return nil, fmt.Errorf("failed to read file: %w", err) } @@ -166,12 +225,26 @@ func fetchBlob(ctx context.Context, url string, forceUpdate bool) ([]byte, error } func main() { + cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`") + port := flag.Int("port", 8888, "Listen port") - workDir := flag.String("work_dir", defaultGitterWorkDir, "Work directory") - flag.DurationVar(&fetchTimeout, "fetch_timeout", time.Hour, "Fetch timeout duration") + workDir := flag.String("work-dir", defaultGitterWorkDir, "Work directory") + flag.DurationVar(&fetchTimeout, "fetch-timeout", time.Hour, "Fetch timeout duration") flag.Parse() - 
persistancePath = path.Join(*workDir, persistanceFileName) + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + logger.Error("could not create CPU profile", slog.Any("error", err)) + } + defer f.Close() + if err := pprof.StartCPUProfile(f); err != nil { + logger.Error("could not start CPU profile", slog.Any("error", err)) + } + defer pprof.StopCPUProfile() + } + + persistencePath = path.Join(*workDir, persistenceFileName) gitStorePath = path.Join(*workDir, gitStoreFileName) if err := os.MkdirAll(gitStorePath, 0755); err != nil { @@ -179,13 +252,15 @@ func main() { os.Exit(1) } - loadMap() + loadLastFetchMap() // Create a context that listens for the interrupt signal from the OS. ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - http.HandleFunc(getGitEndpoint, gitHandler) + for endpoint, handler := range endpointHandlers { + http.HandleFunc(endpoint, handler) + } logger.Info("Gitter starting and listening", slog.Int("port", *port)) @@ -230,7 +305,7 @@ func main() { logger.Error("Server forced to shutdown", slog.Any("error", err)) } - saveMap() + saveLastFetchMap() logger.Info("Server exiting") } @@ -252,6 +327,7 @@ func gitHandler(w http.ResponseWriter, r *http.Request) { } } + // Fetch repo first // Keep the key as the url regardless of forceUpdate. // Occasionally this could be problematic if an existing unforce updated // query is already inplace, no force update will happen. @@ -259,14 +335,13 @@ func gitHandler(w http.ResponseWriter, r *http.Request) { // the repo once, and always with force update. // This is a tradeoff for simplicity to avoid having to setup locks per repo. 
//nolint:contextcheck // I can't change singleflight's interface - fileData, err, _ := g.Do(url, func() (any, error) { - return fetchBlob(r.Context(), url, forceUpdate) - }) - - if err != nil { - logger.Error("Error fetching/archiving blob", slog.String("url", url), slog.Any("error", err)) + if _, err, _ := gFetch.Do(url, func() (any, error) { + return nil, fetchRepo(r.Context(), url, forceUpdate) + }); err != nil { + logger.Error("Error fetching blob", slog.String("url", url), slog.Any("error", err)) if isAuthError(err) { http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusForbidden) + return } http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusInternalServerError) @@ -274,10 +349,23 @@ func gitHandler(w http.ResponseWriter, r *http.Request) { return } + // Archive repo + //nolint:contextcheck // I can't change singleflight's interface + fileDataAny, err, _ := gArchive.Do(url, func() (any, error) { + return archiveRepo(r.Context(), url) + }) + if err != nil { + logger.Error("Error archiving blob", slog.String("url", url), slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Error archiving blob: %v", err), http.StatusInternalServerError) + + return + } + fileData := fileDataAny.([]byte) + w.Header().Set("Content-Type", "application/zstd") w.Header().Set("Content-Disposition", "attachment; filename=\"git-blob.zst\"") w.WriteHeader(http.StatusOK) - if _, err := io.Copy(w, bytes.NewReader(fileData.([]byte))); err != nil { + if _, err := io.Copy(w, bytes.NewReader(fileData)); err != nil { logger.Error("Error copying file", slog.String("url", url), slog.Any("error", err)) http.Error(w, "Error copying file", http.StatusInternalServerError) @@ -286,3 +374,165 @@ func gitHandler(w http.ResponseWriter, r *http.Request) { logger.Info("Request completed successfully", slog.String("url", url)) } + +func cacheHandler(w http.ResponseWriter, r *http.Request) { + start := time.Now() + // POST requets body processing + var body struct { + URL 
string `json:"url"` + ForceUpdate bool `json:"force_update"` + } + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + http.Error(w, fmt.Sprintf("Error decoding JSON: %v", err), http.StatusBadRequest) + + return + } + defer r.Body.Close() + + url := body.URL + logger.Info("Received request: /cache", slog.String("url", url)) + + // Fetch repo if it's not fresh + //nolint:contextcheck // I can't change singleflight's interface + if _, err, _ := gFetch.Do(url, func() (any, error) { + return nil, fetchRepo(r.Context(), url, body.ForceUpdate) + }); err != nil { + logger.Error("Error fetching blob", slog.String("url", url), slog.Any("error", err)) + if isAuthError(err) { + http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusForbidden) + + return + } + http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusInternalServerError) + + return + } + + repoDirName := getRepoDirName(url) + repoPath := path.Join(gitStorePath, repoDirName) + + //nolint:contextcheck // I can't change singleflight's interface + _, err, _ = gLoad.Do(repoPath, func() (any, error) { + return LoadRepository(r.Context(), repoPath) + }) + if err != nil { + logger.Error("Failed to load repository", slog.String("url", url), slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Failed to load repository: %v", err), http.StatusInternalServerError) + + return + } + + w.WriteHeader(http.StatusOK) + logger.Info("Request completed successfully: /cache", slog.String("url", url), slog.Duration("duration", time.Since(start))) +} + +func affectedCommitsHandler(w http.ResponseWriter, r *http.Request) { + start := time.Now() + // POST requets body processing + var body struct { + URL string `json:"url"` + Events []Event `json:"events"` + DetectCherrypicks bool `json:"detect_cherrypicks"` + ForceUpdate bool `json:"force_update"` + } + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + http.Error(w, fmt.Sprintf("Error decoding JSON: %v", err), 
http.StatusBadRequest) + + return + } + defer r.Body.Close() + + url := body.URL + introduced := []SHA1{} + fixed := []SHA1{} + lastAffected := []SHA1{} + limit := []SHA1{} + cherrypick := body.DetectCherrypicks + + for _, event := range body.Events { + hash, err := hex.DecodeString(event.Hash) + if err != nil { + logger.Error("Error parsing hash", slog.String("hash", event.Hash), slog.Any("error", err)) + continue + } + + switch event.EventType { + case "introduced": + introduced = append(introduced, SHA1(hash)) + case "fixed": + fixed = append(fixed, SHA1(hash)) + case "last_affected": + lastAffected = append(lastAffected, SHA1(hash)) + case "limit": + limit = append(limit, SHA1(hash)) + default: + logger.Error("Invalid event type", slog.String("event_type", event.EventType)) + continue + } + } + logger.Info("Received request: /affected-commits", slog.String("url", url), slog.Any("introduced", introduced), slog.Any("fixed", fixed), slog.Any("last_affected", lastAffected), slog.Any("limit", limit), slog.Bool("cherrypick", cherrypick)) + + // Limit and fixed/last_affected shouldn't exist in the same request as it doesn't make sense + if (len(fixed) > 0 || len(lastAffected) > 0) && len(limit) > 0 { + http.Error(w, "Limit and fixed/last_affected shouldn't exist in the same request", http.StatusBadRequest) + + return + } + + // Fetch repo if it's not fresh + //nolint:contextcheck // I can't change singleflight's interface + if _, err, _ := gFetch.Do(url, func() (any, error) { + return nil, fetchRepo(r.Context(), url, body.ForceUpdate) + }); err != nil { + logger.Error("Error fetching blob", slog.String("url", url), slog.Any("error", err)) + if isAuthError(err) { + http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusForbidden) + + return + } + http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusInternalServerError) + + return + } + + repoDirName := getRepoDirName(url) + repoPath := path.Join(gitStorePath, repoDirName) + + 
//nolint:contextcheck // I can't change singleflight's interface + repoAny, err, _ := gLoad.Do(repoPath, func() (any, error) { + return LoadRepository(r.Context(), repoPath) + }) + if err != nil { + logger.Error("Failed to load repository", slog.String("url", url), slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Failed to load repository: %v", err), http.StatusInternalServerError) + + return + } + repo := repoAny.(*Repository) + + var affectedCommits []*Commit + if len(limit) > 0 { + affectedCommits = repo.Between(introduced, limit) + } else { + affectedCommits = repo.Affected(introduced, fixed, lastAffected, cherrypick) + } + + if err != nil { + logger.Error("Error processing affected commits", slog.String("url", url), slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Error processing affected commits: %v", err), http.StatusInternalServerError) + + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(affectedCommits); err != nil { + logger.Error("Error encoding affected commits", slog.String("url", url), slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Error encoding affected commits: %v", err), http.StatusInternalServerError) + + return + } + logger.Info("Request completed successfully: /affected-commits", slog.String("url", url), slog.Duration("duration", time.Since(start))) +} diff --git a/go/cmd/gitter/gitter_test.go b/go/cmd/gitter/gitter_test.go index 38cb7603209..30d77dc6868 100644 --- a/go/cmd/gitter/gitter_test.go +++ b/go/cmd/gitter/gitter_test.go @@ -1,6 +1,8 @@ package main import ( + "bytes" + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -62,7 +64,7 @@ func TestGitHandler_InvalidURL(t *testing.T) { } for _, tt := range tests { - req, err := http.NewRequest(http.MethodGet, "/getgit?url="+tt.url, nil) + req, err := http.NewRequest(http.MethodGet, "/git?url="+tt.url, nil) if err != nil { t.Fatal(err) } @@ -77,24 +79,77 @@ func 
TestGitHandler_InvalidURL(t *testing.T) { } } +// Override global variables for test +// Note: In a real app we might want to dependency inject these, +// but for this simple script we modify package globals. +func setupTest(t *testing.T) { + t.Helper() + tmpDir := t.TempDir() + + gitStorePath = tmpDir + persistencePath = tmpDir + "/last-fetch.json" // Use simple path join for test + fetchTimeout = time.Minute + + // Reset lastFetch map + lastFetchMu.Lock() + lastFetch = make(map[string]time.Time) + lastFetchMu.Unlock() + + // Stop any existing timer + if saveTimer != nil { + saveTimer.Stop() + saveTimer = nil + } +} + func TestGitHandler_Integration(t *testing.T) { if testing.Short() { t.Skip("skipping integration test in short mode") } - // Setup valid workdir - tmpDir := t.TempDir() + setupTest(t) - // Override global variables for test - // Note: In a real app we might want to dependency inject these, - // but for this simple script we modify package globals. - gitStorePath = tmpDir - fetchTimeout = time.Minute - // Ensure lastFetch map is initialized - if lastFetch == nil { - loadMap() + tests := []struct { + name string + url string + expectedCode int + }{ + { + name: "Valid public repo", + url: "https://github.com/google/oss-fuzz-vulns.git", // Small repo + expectedCode: http.StatusOK, + }, + { + name: "Non-existent repo", + url: "https://github.com/google/this-repo-does-not-exist-12345.git", + expectedCode: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/git?url="+tt.url, nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + gitHandler(rr, req) + + if status := rr.Code; status != tt.expectedCode { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tt.expectedCode) + } + }) + } +} + +func TestCacheHandler(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") } + setupTest(t) + 
tests := []struct { name string url string @@ -114,12 +169,88 @@ func TestGitHandler_Integration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, "/getgit?url="+tt.url, nil) + body, _ := json.Marshal(map[string]string{"url": tt.url}) + req, err := http.NewRequest(http.MethodPost, "/cache", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() - gitHandler(rr, req) + cacheHandler(rr, req) + + if status := rr.Code; status != tt.expectedCode { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tt.expectedCode) + } + }) + } +} + +func TestAffectedCommitsHandler(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + setupTest(t) + + tests := []struct { + name string + url string + introduced []string + fixed []string + lastAffected []string + limit []string + expectedCode int + }{ + { + name: "Valid range in public repo", + url: "https://github.com/google/oss-fuzz-vulns.git", + introduced: []string{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + fixed: []string{"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}, + expectedCode: http.StatusOK, + }, + { + name: "Invalid mixed limit and fixed", + url: "https://github.com/google/oss-fuzz-vulns.git", + fixed: []string{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + limit: []string{"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}, + expectedCode: http.StatusBadRequest, + }, + { + name: "Non-existent repo", + url: "https://github.com/google/this-repo-does-not-exist-12345.git", + introduced: []string{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + expectedCode: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var events []Event + for _, h := range tt.introduced { + events = append(events, Event{EventType: "introduced", Hash: h}) + } + for _, h := range tt.fixed { + events = append(events, Event{EventType: "fixed", Hash: 
h}) + } + for _, h := range tt.lastAffected { + events = append(events, Event{EventType: "last_affected", Hash: h}) + } + for _, h := range tt.limit { + events = append(events, Event{EventType: "limit", Hash: h}) + } + + reqBody := map[string]any{ + "url": tt.url, + "events": events, + } + + body, _ := json.Marshal(reqBody) + req, err := http.NewRequest(http.MethodPost, "/affected-commits", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + affectedCommitsHandler(rr, req) if status := rr.Code; status != tt.expectedCode { t.Errorf("handler returned wrong status code: got %v want %v", diff --git a/go/cmd/gitter/pb/repository/repository.pb.go b/go/cmd/gitter/pb/repository/repository.pb.go new file mode 100644 index 00000000000..36f88572a04 --- /dev/null +++ b/go/cmd/gitter/pb/repository/repository.pb.go @@ -0,0 +1,190 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: repository/repository.proto + +package repository + +import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type CommitDetail struct { + state protoimpl.MessageState `protogen:"open.v1"` + Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` + PatchId []byte `protobuf:"bytes,2,opt,name=patch_id,json=patchId,proto3" json:"patch_id,omitempty"` + Tags []string `protobuf:"bytes,3,rep,name=tags,proto3" json:"tags,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CommitDetail) Reset() { + *x = CommitDetail{} + mi := &file_repository_repository_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CommitDetail) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CommitDetail) ProtoMessage() {} + +func (x *CommitDetail) ProtoReflect() protoreflect.Message { + mi := &file_repository_repository_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CommitDetail.ProtoReflect.Descriptor instead. 
+func (*CommitDetail) Descriptor() ([]byte, []int) { + return file_repository_repository_proto_rawDescGZIP(), []int{0} +} + +func (x *CommitDetail) GetHash() []byte { + if x != nil { + return x.Hash + } + return nil +} + +func (x *CommitDetail) GetPatchId() []byte { + if x != nil { + return x.PatchId + } + return nil +} + +func (x *CommitDetail) GetTags() []string { + if x != nil { + return x.Tags + } + return nil +} + +// RepositoryCache is the minimally saved details for a repository +type RepositoryCache struct { + state protoimpl.MessageState `protogen:"open.v1"` + Commits []*CommitDetail `protobuf:"bytes,1,rep,name=commits,proto3" json:"commits,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RepositoryCache) Reset() { + *x = RepositoryCache{} + mi := &file_repository_repository_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RepositoryCache) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RepositoryCache) ProtoMessage() {} + +func (x *RepositoryCache) ProtoReflect() protoreflect.Message { + mi := &file_repository_repository_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RepositoryCache.ProtoReflect.Descriptor instead. 
+func (*RepositoryCache) Descriptor() ([]byte, []int) { + return file_repository_repository_proto_rawDescGZIP(), []int{1} +} + +func (x *RepositoryCache) GetCommits() []*CommitDetail { + if x != nil { + return x.Commits + } + return nil +} + +var File_repository_repository_proto protoreflect.FileDescriptor + +const file_repository_repository_proto_rawDesc = "" + + "\n" + + "\x1brepository/repository.proto\x12\x06gitter\"Q\n" + + "\fCommitDetail\x12\x12\n" + + "\x04hash\x18\x01 \x01(\fR\x04hash\x12\x19\n" + + "\bpatch_id\x18\x02 \x01(\fR\apatchId\x12\x12\n" + + "\x04tags\x18\x03 \x03(\tR\x04tags\"A\n" + + "\x0fRepositoryCache\x12.\n" + + "\acommits\x18\x01 \x03(\v2\x14.gitter.CommitDetailR\acommitsB\x0eZ\f./repositoryb\x06proto3" + +var ( + file_repository_repository_proto_rawDescOnce sync.Once + file_repository_repository_proto_rawDescData []byte +) + +func file_repository_repository_proto_rawDescGZIP() []byte { + file_repository_repository_proto_rawDescOnce.Do(func() { + file_repository_repository_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_repository_repository_proto_rawDesc), len(file_repository_repository_proto_rawDesc))) + }) + return file_repository_repository_proto_rawDescData +} + +var file_repository_repository_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_repository_repository_proto_goTypes = []any{ + (*CommitDetail)(nil), // 0: gitter.CommitDetail + (*RepositoryCache)(nil), // 1: gitter.RepositoryCache +} +var file_repository_repository_proto_depIdxs = []int32{ + 0, // 0: gitter.RepositoryCache.commits:type_name -> gitter.CommitDetail + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_repository_repository_proto_init() } +func file_repository_repository_proto_init() { + if 
File_repository_repository_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_repository_repository_proto_rawDesc), len(file_repository_repository_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_repository_repository_proto_goTypes, + DependencyIndexes: file_repository_repository_proto_depIdxs, + MessageInfos: file_repository_repository_proto_msgTypes, + }.Build() + File_repository_repository_proto = out.File + file_repository_repository_proto_goTypes = nil + file_repository_repository_proto_depIdxs = nil +} diff --git a/go/cmd/gitter/pb/repository/repository.proto b/go/cmd/gitter/pb/repository/repository.proto new file mode 100644 index 00000000000..bbf8f4bf481 --- /dev/null +++ b/go/cmd/gitter/pb/repository/repository.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package gitter; + +option go_package = "./repository"; + +message CommitDetail { + bytes hash = 1; + bytes patch_id = 2; + repeated string tags = 3; +} + +// RepositoryCache is the minimally saved details for a repository +message RepositoryCache { + repeated CommitDetail commits = 1; +} diff --git a/go/cmd/gitter/persistance.go b/go/cmd/gitter/persistence.go similarity index 56% rename from go/cmd/gitter/persistance.go rename to go/cmd/gitter/persistence.go index bbdcc840fe4..e5954a985b4 100644 --- a/go/cmd/gitter/persistance.go +++ b/go/cmd/gitter/persistence.go @@ -7,7 +7,9 @@ import ( "sync" "time" + pb "github.com/google/osv.dev/go/cmd/gitter/pb/repository" "github.com/google/osv.dev/go/logger" + "google.golang.org/protobuf/proto" ) var ( @@ -43,32 +45,32 @@ func debounceSaveMap() { saveTimer.Stop() } saveTimer = time.AfterFunc(3*time.Second, func() { - saveMap() + saveLastFetchMap() }) } -func saveMap() { +func saveLastFetchMap() { lastFetchMu.Lock() defer lastFetchMu.Unlock() - 
logger.Info("Saving lastFetch map", slog.String("path", persistancePath)) + logger.Info("Saving lastFetch map", slog.String("path", persistencePath)) data, err := json.Marshal(lastFetch) if err != nil { - logger.Error("Error marshaling lastFetch map", slog.String("path", persistancePath), slog.Any("error", err)) + logger.Error("Error marshaling lastFetch map", slog.String("path", persistencePath), slog.Any("error", err)) return } - if err := os.WriteFile(persistancePath, data, 0600); err != nil { - logger.Error("Error writing lastFetch map", slog.String("path", persistancePath), slog.Any("error", err)) + if err := os.WriteFile(persistencePath, data, 0600); err != nil { + logger.Error("Error writing lastFetch map", slog.String("path", persistencePath), slog.Any("error", err)) } } -func loadMap() { - data, err := os.ReadFile(persistancePath) +func loadLastFetchMap() { + data, err := os.ReadFile(persistencePath) if err != nil { if !os.IsNotExist(err) { - logger.Error("Error reading lastFetch map", slog.String("path", persistancePath), slog.Any("error", err)) + logger.Error("Error reading lastFetch map", slog.String("path", persistencePath), slog.Any("error", err)) } return @@ -78,8 +80,42 @@ func loadMap() { defer lastFetchMu.Unlock() if err := json.Unmarshal(data, &lastFetch); err != nil { - logger.Error("Error unmarshaling lastFetch map", slog.String("path", persistancePath), slog.Any("error", err)) + logger.Error("Error unmarshaling lastFetch map", slog.String("path", persistencePath), slog.Any("error", err)) } logger.Info("Loaded lastFetch map", slog.Int("entry_count", len(lastFetch))) } + +func saveRepositoryCache(cachePath string, repo *Repository) error { + logger.Info("Saving repository cache", slog.String("path", cachePath)) + + cache := &pb.RepositoryCache{} + for _, commit := range repo.commitDetails { + cache.Commits = append(cache.Commits, &pb.CommitDetail{ + Hash: commit.Hash[:], + PatchId: commit.PatchID[:], + Tags: commit.Tags, + }) + } + + data, err 
:= proto.Marshal(cache) + if err != nil { + return err + } + + return os.WriteFile(cachePath, data, 0600) +} + +func loadRepositoryCache(cachePath string) (*pb.RepositoryCache, error) { + data, err := os.ReadFile(cachePath) + if err != nil { + return nil, err + } + + cache := &pb.RepositoryCache{} + if err := proto.Unmarshal(data, cache); err != nil { + return nil, err + } + + return cache, nil +} diff --git a/go/cmd/gitter/repository.go b/go/cmd/gitter/repository.go new file mode 100644 index 00000000000..b034ff6db81 --- /dev/null +++ b/go/cmd/gitter/repository.go @@ -0,0 +1,548 @@ +package main + +import ( + "bufio" + "context" + "encoding/hex" + "errors" + "fmt" + "log/slog" + "maps" + "os" + "runtime" + "slices" + "strings" + "sync" + "time" + + pb "github.com/google/osv.dev/go/cmd/gitter/pb/repository" + "github.com/google/osv.dev/go/logger" + "golang.org/x/sync/errgroup" +) + +type SHA1 [20]byte + +type Commit struct { + Hash SHA1 `json:"hash"` + PatchID SHA1 `json:"patch_id"` + Parents []SHA1 `json:"parents"` + Tags []string `json:"tags"` +} + +// Repository holds the commit graph and other details for a git repository. +type Repository struct { + repoMu sync.Mutex + // Path to the .git directory within gitter's working dir + repoPath string + // Adjacency list: Parent -> []Children + commitGraph map[SHA1][]SHA1 + // Actual commit details + commitDetails map[SHA1]*Commit + // Store tags to commit because it's useful for CVE conversion + tagToCommit map[string]SHA1 + // For cherry-pick detection: PatchID -> []commit hash + patchIDToCommits map[SHA1][]SHA1 +} + +// %H commit hash; %P parent hashes; %D:refs (tab delimited) +const gitLogFormat = "%H%x09%P%x09%D" + +// NewRepository initializes a new Repository struct. 
+func NewRepository(repoPath string) *Repository {
+	return &Repository{
+		repoPath:         repoPath,
+		commitGraph:      make(map[SHA1][]SHA1),
+		commitDetails:    make(map[SHA1]*Commit),
+		tagToCommit:      make(map[string]SHA1),
+		patchIDToCommits: make(map[SHA1][]SHA1),
+	}
+}
+
+// LoadRepository loads a repo from disk into memory.
+//
+// It reads the optional protobuf cache stored at repoPath + ".pb", rebuilds the
+// commit graph from `git log`, calculates patch IDs for every commit that was
+// not found in the cache and finally writes the cache back.
+//
+// A missing cache file is not an error; a cache that exists but cannot be read
+// or decoded is. Failing to write the cache back is only logged.
+func LoadRepository(ctx context.Context, repoPath string) (*Repository, error) {
+	repo := NewRepository(repoPath)
+	cachePath := repoPath + ".pb"
+
+	// Load cache pb file of the repo if it exists
+	cache, err := loadRepositoryCache(cachePath)
+	switch {
+	case err == nil:
+		logger.Info("Loaded repository cache", slog.Int("commits", len(cache.GetCommits())))
+	case errors.Is(err, os.ErrNotExist):
+		// It's fine if cache doesn't exist, log it just in case
+		logger.Info("No repository cache found")
+	default:
+		return nil, fmt.Errorf("failed to load repository cache: %w", err)
+	}
+
+	// Commit graph is built from scratch every time
+	newCommits, err := repo.buildCommitGraph(ctx, cache)
+	if err != nil {
+		return nil, fmt.Errorf("failed to build commit graph: %w", err)
+	}
+
+	if len(newCommits) > 0 {
+		if err := repo.calculatePatchIDs(ctx, newCommits); err != nil {
+			return nil, fmt.Errorf("failed to calculate patch id for commits: %w", err)
+		}
+	}
+
+	// Save cache; the in-memory repository is still usable if this fails
+	if err := saveRepositoryCache(cachePath, repo); err != nil {
+		logger.Error("Failed to save repository cache", slog.Any("err", err))
+	}
+
+	return repo, nil
+}
+
+// buildCommitGraph builds the commit graph and associated commit details from scratch.
+//
+// It runs `git log --all` with gitLogFormat, parses one commit per line and
+// fills commitGraph (parent -> children), commitDetails and tagToCommit.
+// Patch IDs found in cache are reused and indexed in patchIDToCommits.
+//
+// It returns the hashes of the commits that had no cached patch ID, in git log
+// order. Lines that cannot be parsed are logged and skipped; failing to run git
+// or to read its output is returned as an error.
+func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryCache) ([]SHA1, error) {
+	// Upper bound for a single git log line. A line holds one hash, its parents
+	// and every ref decoration pointing at the commit.
+	const maxLogLineSize = 64 * 1024 * 1024
+
+	logger.Info("Starting graph construction", slog.String("repo", r.repoPath))
+	start := time.Now()
+
+	// Build cache map (commit hash -> patch ID)
+	cachedPatchIDs := make(map[SHA1]SHA1)
+	if cache != nil {
+		for _, c := range cache.GetCommits() {
+			h, pid := c.GetHash(), c.GetPatchId()
+			if len(h) == 20 && len(pid) == 20 {
+				cachedPatchIDs[SHA1(h)] = SHA1(pid)
+			}
+		}
+	}
+	var newCommits []SHA1
+
+	// Temp outFile for git log output
+	tmpFile, err := os.CreateTemp(r.repoPath, "git-log.out")
+	if err != nil {
+		return nil, fmt.Errorf("failed to create temp file: %w", err)
+	}
+	tmpPath := tmpFile.Name()
+	defer os.Remove(tmpPath)
+	// Only the name is needed: the shell redirect below opens the file itself,
+	// so release our descriptor right away instead of leaking it.
+	if err := tmpFile.Close(); err != nil {
+		return nil, fmt.Errorf("failed to close temp file: %w", err)
+	}
+
+	// Run git log via bash because redirecting to file is faster than using pipe
+	// NOTE(review): tmpPath is interpolated unquoted; this assumes repoPath never contains shell metacharacters - confirm
+	err = runCmd(ctx, r.repoPath, nil, "bash", "-c", "git log --all --full-history --sparse --topo-order --format="+gitLogFormat+" > "+tmpPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to run git log: %w", err)
+	}
+
+	// Read git log output
+	file, err := os.Open(tmpPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open git-log.out: %w", err)
+	}
+	defer file.Close()
+
+	scanner := bufio.NewScanner(file)
+	scanner.Buffer(make([]byte, 0, 1024*1024), maxLogLineSize)
+	for scanner.Scan() {
+		line := scanner.Text()
+		// Fields are tab delimited: hash, parents, refs (see gitLogFormat)
+		commitInfo := strings.Split(line, "\x09")
+		if len(commitInfo) > 3 {
+			// More fields than the format can produce
+			logger.Error("Invalid commit info", slog.String("line", line))
+			continue
+		}
+
+		childHash, err := parseSHA1(commitInfo[0])
+		if err != nil {
+			logger.Error("Failed to decode hash", slog.String("child", commitInfo[0]), slog.Any("err", err))
+			continue
+		}
+
+		parentHashes := []SHA1{}
+		if len(commitInfo) >= 2 {
+			// parent hashes are separated by spaces
+			for _, parent := range strings.Fields(commitInfo[1]) {
+				hash, err := parseSHA1(parent)
+				if err != nil {
+					logger.Error("Failed to decode hash", slog.String("parent", parent), slog.Any("err", err))
+					continue
+				}
+				parentHashes = append(parentHashes, hash)
+			}
+		}
+
+		tags := []string{}
+		if len(commitInfo) == 3 {
+			// refs are separated by commas; only tags are kept, other refs such as HEAD are ignored
+			for _, ref := range strings.Split(commitInfo[2], ", ") {
+				if tag, ok := strings.CutPrefix(ref, "tag: "); ok {
+					tags = append(tags, tag)
+				}
+			}
+		}
+
+		// Add commit to graph (parent -> []child)
+		for _, parentHash := range parentHashes {
+			r.commitGraph[parentHash] = append(r.commitGraph[parentHash], childHash)
+		}
+
+		commit := &Commit{
+			Hash:    childHash,
+			Tags:    tags,
+			Parents: parentHashes,
+		}
+
+		if patchID, ok := cachedPatchIDs[childHash]; ok {
+			// Assign saved patch ID to commit details and map if found
+			commit.PatchID = patchID
+			// Also populate patchIDToCommits map
+			r.patchIDToCommits[patchID] = append(r.patchIDToCommits[patchID], childHash)
+		} else {
+			// Add to slice for patch ID to be generated later
+			newCommits = append(newCommits, childHash)
+		}
+
+		r.commitDetails[childHash] = commit
+
+		// Also populate the tag-to-commit map
+		for _, tag := range tags {
+			r.tagToCommit[tag] = childHash
+		}
+	}
+	// A read error must not be mistaken for the end of the log, otherwise a
+	// truncated graph would be returned (and cached) as if it were complete.
+	if err := scanner.Err(); err != nil {
+		return nil, fmt.Errorf("failed to read git log output: %w", err)
+	}
+
+	logger.Info("Commit graph completed", slog.Int("commits", len(r.commitDetails)), slog.Int("nodes", len(r.commitGraph)), slog.Int("new_commits", len(newCommits)), slog.Duration("duration", time.Since(start)))
+
+	return newCommits, nil
+}
+
+// parseSHA1 decodes a hex string into a SHA1.
+// Input that does not decode to exactly 20 bytes is rejected, so the
+// slice-to-array conversion can neither panic nor silently truncate.
+func parseSHA1(s string) (SHA1, error) {
+	raw, err := hex.DecodeString(s)
+	if err != nil {
+		return SHA1{}, err
+	}
+	if len(raw) != len(SHA1{}) {
+		return SHA1{}, fmt.Errorf("unexpected hash length: got %d bytes, want %d", len(raw), len(SHA1{}))
+	}
+
+	return SHA1(raw), nil
+}
+
+// calculatePatchIDs calculates patch IDs only for the specific commits provided.
+func (r *Repository) calculatePatchIDs(ctx context.Context, commits []SHA1) error {
+	// Nothing to calculate; this also keeps the chunk size maths below away
+	// from a division by zero.
+	if len(commits) == 0 {
+		return nil
+	}
+
+	logger.Info("Starting patch ID calculation", slog.String("repo", r.repoPath))
+	start := time.Now()
+
+	// Number of workers: one per CPU, but never more than there are commits
+	workers := min(runtime.NumCPU(), len(commits))
+
+	// Round up so the remainder of the division is spread over the chunks
+	// instead of being left without a worker.
+	chunkSize := (len(commits) + workers - 1) / workers
+
+	errg, ctx := errgroup.WithContext(ctx)
+
+	for lo := 0; lo < len(commits); lo += chunkSize {
+		chunk := commits[lo:min(lo+chunkSize, len(commits))]
+
+		errg.Go(func() error {
+			return r.calculatePatchIDsWorker(ctx, chunk)
+		})
+	}
+
+	if err := errg.Wait(); err != nil {
+		return fmt.Errorf("failed to calculate patch IDs: %w", err)
+	}
+
+	logger.Info("Patch ID calculation completed", slog.Int("commits", len(commits)), slog.Duration("duration", time.Since(start)))
+
+	return nil
+}
+
+// calculatePatchIDsWorker calculates the patch IDs of one chunk of commits.
+//
+// The commit hashes are written to the stdin of `git show --stdin`, whose
+// output is piped into `git patch-id --stable`. Every "<patch id> <commit>"
+// line printed by git patch-id is stored through updatePatchID. Commits without
+// a diff produce no line and keep a zero patch ID.
+func (r *Repository) calculatePatchIDsWorker(ctx context.Context, chunk []SHA1) error {
+	// Prepare git commands
+	// TODO: Replace with plumbing cmd `git diff-tree`, might be slightly faster
+	cmdShow := prepareCmd(ctx, r.repoPath, nil, "git", "show", "--stdin", "--patch", "--first-parent", "--no-color")
+	cmdPatchID := prepareCmd(ctx, r.repoPath, nil, "git", "patch-id", "--stable")
+
+	// Pipe the git show with git patch-id
+	in, err := cmdShow.StdinPipe()
+	if err != nil {
+		return fmt.Errorf("git show stdin pipe error: %w", err)
+	}
+
+	rPipe, wPipe, err := os.Pipe()
+	if err != nil {
+		return fmt.Errorf("inter-process pipe error: %w", err)
+	}
+	defer rPipe.Close()
+	// wPipe should be closed in the goroutine where we wait for git show
+	// But keeping this defer as a failsafe
+	defer wPipe.Close()
+
+	cmdShow.Stdout = wPipe
+	cmdPatchID.Stdin = rPipe
+
+	out, err := cmdPatchID.StdoutPipe()
+	if err != nil {
+		return fmt.Errorf("git patch-id stdout pipe error: %w", err)
+	}
+
+	// Start the processes
+	if err := cmdShow.Start(); err != nil {
+		return fmt.Errorf("failed to start git show: %w", err)
+	}
+	if err := cmdPatchID.Start(); err != nil {
+		// git show is already running but will never get a consumer: stop and
+		// reap it instead of leaving it behind.
+		_ = cmdShow.Process.Kill()
+		_ = cmdShow.Wait()
+
+		return fmt.Errorf("failed to start git patch-id: %w", err)
+	}
+
+	// Channel to capture errors from git show
+	showErrChan := make(chan error, 1)
+
+	// Write hashes to git show stdin
+	go func() {
+		defer in.Close()
+		for _, hash := range chunk {
+			fmt.Fprintf(in, "%s\n", hex.EncodeToString(hash[:]))
+		}
+	}()
+
+	// Wait for git show to finish
+	go func() {
+		err := cmdShow.Wait()
+		showErrChan <- err
+		wPipe.Close() // close pipe to send EOF to git patch-id
+	}()
+
+	// Read results from stdout of git patch-id
+	scanner := bufio.NewScanner(out)
+	for scanner.Scan() {
+		line := scanner.Text()
+
+		// The whole output of git patch-id will be empty if there is no diff (e.g. empty commit), it is safe to just continue
+		if line == "" {
+			continue
+		}
+
+		patchInfo := strings.Fields(line)
+		// --first-parent flag in git show should have prevented git patch-id from returning multiple lines of patch IDs
+		// return error if this still happens
+		if len(patchInfo) != 2 {
+			return fmt.Errorf("invalid patch ID format: %s", line)
+		}
+
+		patchIDBytes, err := hex.DecodeString(patchInfo[0])
+		if err != nil {
+			return fmt.Errorf("failed to decode patch ID: %w", err)
+		}
+		hashBytes, err := hex.DecodeString(patchInfo[1])
+		if err != nil {
+			return fmt.Errorf("failed to decode commit hash: %w", err)
+		}
+		// Converting a slice of the wrong length to SHA1 would panic (too
+		// short) or silently truncate (too long), so check first.
+		if len(patchIDBytes) != len(SHA1{}) || len(hashBytes) != len(SHA1{}) {
+			return fmt.Errorf("invalid patch ID format: %s", line)
+		}
+
+		r.updatePatchID(SHA1(hashBytes), SHA1(patchIDBytes))
+	}
+	// Scan also stops on read errors; do not treat those as a clean end of output
+	if err := scanner.Err(); err != nil {
+		return fmt.Errorf("failed to read git patch-id output: %w", err)
+	}
+
+	// Wait for git patch-id to finish
+	if err := cmdPatchID.Wait(); err != nil {
+		return fmt.Errorf("failed to finish git patch-id: %w", err)
+	}
+
+	// Wait for git show to finish
+	if err := <-showErrChan; err != nil {
+		return fmt.Errorf("failed to finish git show: %w", err)
+	}
+
+	return nil
+}
+
+// updatePatchID records patchID for commitHash in both commitDetails and
+// patchIDToCommits. It takes repoMu, so it can be called from concurrent workers.
+// A hash that is not part of the commit graph is logged and ignored.
+func (r *Repository) updatePatchID(commitHash, patchID SHA1) {
+	r.repoMu.Lock()
+	defer r.repoMu.Unlock()
+
+	commit, ok := r.commitDetails[commitHash]
+	if !ok {
+		// Guard against dereferencing a nil *Commit
+		logger.Error("Patch ID reported for unknown commit", slog.String("commit", hex.EncodeToString(commitHash[:])))
+		return
+	}
+	// commit is a pointer, so this updates the stored entry in place
+	commit.PatchID = patchID
+
+	r.patchIDToCommits[patchID] = append(r.patchIDToCommits[patchID], commitHash)
+}
+
+// Affected returns a list of commits that are affected by the given introduced, fixed and last_affected events
+//
+// The graph is walked from every introduced commit towards its descendants and
+// the walk stops at commits that findSafeCommits reports as safe. When
+// cherrypick is true, introduced and fixed are first expanded with the commits
+// sharing their patch ID. Hashes that are not in commitDetails are never
+// returned. The order of the result is unspecified.
+func (r *Repository) Affected(introduced, fixed, lastAffected []SHA1, cherrypick bool) []*Commit {
+	r.repoMu.Lock()
+	defer r.repoMu.Unlock()
+
+	// Expands the introduced and fixed commits to include cherrypick equivalents
+	// lastAffected should not be expanded because it does not imply a "fix" commit that can be cherrypicked to other branches
+	if cherrypick {
+		introduced = r.expandByCherrypick(introduced)
+		fixed = r.expandByCherrypick(fixed)
+	}
+
+	safeCommits := r.findSafeCommits(introduced, fixed, lastAffected)
+
+	var affectedCommits []*Commit
+
+	stack := make([]SHA1, 0, len(introduced))
+	stack = append(stack, introduced...)
+
+	visited := make(map[SHA1]struct{})
+
+	for len(stack) > 0 {
+		curr := stack[len(stack)-1]
+		stack = stack[:len(stack)-1]
+
+		if _, ok := visited[curr]; ok {
+			continue
+		}
+		visited[curr] = struct{}{}
+
+		// If commit is in safe set, we can stop the traversal
+		if _, ok := safeCommits[curr]; ok {
+			continue
+		}
+
+		// Otherwise, add to affected commits. Hashes we have no details for
+		// (e.g. an unknown introduced hash) are skipped so that the result
+		// never contains nil entries.
+		if details, ok := r.commitDetails[curr]; ok {
+			affectedCommits = append(affectedCommits, details)
+		}
+
+		// Add children to DFS stack
+		if children, ok := r.commitGraph[curr]; ok {
+			stack = append(stack, children...)
+		}
+	}
+
+	return affectedCommits
+}
+
+// findSafeCommits returns a set of commits that are non-vulnerable
+// Traversing from fixed and children of last affected to the next introduced (if exist)
+func (r *Repository) findSafeCommits(introduced, fixed, lastAffected []SHA1) map[SHA1]struct{} {
+	introducedMap := make(map[SHA1]struct{}, len(introduced))
+	for _, commit := range introduced {
+		introducedMap[commit] = struct{}{}
+	}
+
+	safeSet := make(map[SHA1]struct{})
+	stack := make([]SHA1, 0, len(fixed)+len(lastAffected))
+	stack = append(stack, fixed...)
+
+	// All children of last affected commits are root for traversal
+	for _, commit := range lastAffected {
+		for _, child := range r.commitGraph[commit] {
+			// Except if child is an introduced commit
+			if _, ok := introducedMap[child]; ok {
+				continue
+			}
+			stack = append(stack, child)
+		}
+	}
+
+	// DFS until we hit an "introduced" commit
+	for len(stack) > 0 {
+		curr := stack[len(stack)-1]
+		stack = stack[:len(stack)-1]
+
+		if _, ok := safeSet[curr]; ok {
+			continue
+		}
+		safeSet[curr] = struct{}{}
+
+		for _, child := range r.commitGraph[curr] {
+			// vuln re-introduced at a later commit, subsequent commits are no longer safe
+			if _, ok := introducedMap[child]; ok {
+				continue
+			}
+			stack = append(stack, child)
+		}
+	}
+
+	return safeSet
+}
+
+// expandByCherrypick expands a slice of commits by adding commits that have the same Patch ID (cherrypicked commits).
+// It returns a new, de-duplicated list containing the original commits PLUS any other commits that share the same Patch ID.
+// The order of the returned list is unspecified because it is collected from a map.
+func (r *Repository) expandByCherrypick(commits []SHA1) []SHA1 {
+	unique := make(map[SHA1]struct{}, len(commits)) // avoid duplication
+	var zeroPatchID SHA1
+
+	for _, hash := range commits {
+		unique[hash] = struct{}{}
+
+		// Find patch ID from commit details; a zero patch ID means none was calculated
+		details, ok := r.commitDetails[hash]
+		if !ok || details.PatchID == zeroPatchID {
+			continue
+		}
+
+		// Find equivalent commits
+		// TODO: I think this logic will always add the current commit one more time, which isn't a problem because we're using map but still suboptimal
+		for _, eq := range r.patchIDToCommits[details.PatchID] {
+			unique[eq] = struct{}{}
+		}
+	}
+
+	return slices.Collect(maps.Keys(unique))
+}
+
+// Between walks and returns the commits that are strictly between introduced (inclusive) and limit (exclusive)
+//
+// Starting from the first parent of every limit commit, the walk follows first
+// parents only and stops after adding a commit listed in introduced, or when it
+// reaches a commit without details or without parents. Limit commits that are
+// unknown or have no parent contribute nothing.
+func (r *Repository) Between(introduced, limit []SHA1) []*Commit {
+	r.repoMu.Lock()
+	defer r.repoMu.Unlock()
+
+	var affectedCommits []*Commit
+
+	introMap := make(map[SHA1]struct{}, len(introduced))
+	for _, commit := range introduced {
+		introMap[commit] = struct{}{}
+	}
+
+	// DFS to walk from limit(s) to introduced (follow first parent)
+	stack := make([]SHA1, 0, len(limit))
+	// Start from limits' parents
+	for _, commit := range limit {
+		details, ok := r.commitDetails[commit]
+		// A root commit has no parent to start from; indexing Parents[0] would panic
+		if !ok || len(details.Parents) == 0 {
+			continue
+		}
+		stack = append(stack, details.Parents[0])
+	}
+
+	visited := make(map[SHA1]struct{})
+
+	for len(stack) > 0 {
+		curr := stack[len(stack)-1]
+		stack = stack[:len(stack)-1]
+
+		if _, ok := visited[curr]; ok {
+			continue
+		}
+		visited[curr] = struct{}{}
+
+		// Add current node to affected commits
+		details, ok := r.commitDetails[curr]
+		if !ok {
+			continue
+		}
+
+		affectedCommits = append(affectedCommits, details)
+
+		// If commit is in introduced, we can stop the traversal after adding it to affected
+		if _, ok := introMap[curr]; ok {
+			continue
+		}
+
+		// Add first parent to stack to only walk the linear branch
+		if len(details.Parents) > 0 {
+			stack = append(stack, details.Parents[0])
+		}
+	}
+
+	return affectedCommits
+}
diff --git a/go/cmd/gitter/repository_test.go b/go/cmd/gitter/repository_test.go
new file mode 100644
index 00000000000..845672f545a
--- /dev/null
+++ b/go/cmd/gitter/repository_test.go
@@ -0,0 +1,395 @@
+package main
+
+import (
+	"encoding/hex"
+	"fmt"
+	"sort"
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+// Helper to decode string into SHA1
+// It panics if s is not valid hex once padded.
+func decodeSHA1(s string) SHA1 {
+	var hash SHA1
+	// Pad with zeros because the test strings are shorter than 40 char
+	padded := fmt.Sprintf("%040s", s)
+	b, err := hex.DecodeString(padded)
+	if err != nil {
+		panic(err)
+	}
+	copy(hash[:], b)
+
+	return hash
+}
+
+// Helper to encode SHA1 into string (leading 0's removed)
+func encodeSHA1(hash SHA1) string {
+	// Remove padding zeros for cleaner results
+	str := hex.EncodeToString(hash[:])
+
+	return strings.TrimLeft(str, "0")
+}
+
+func TestExpandByCherrypick(t *testing.T) {
+	repo := NewRepository("/repo")
+
+	// Commit hashes
+	h1 := decodeSHA1("aaaa")
+	h2 := decodeSHA1("bbbb")
+	h3 := decodeSHA1("cccc")
+
+	// Patch ID
+	p1 := decodeSHA1("1111")
+
+	// Setup commit details
+	repo.commitDetails[h1] = &Commit{Hash: h1, PatchID: p1}
+	repo.commitDetails[h2] = &Commit{Hash: h2}
+	repo.commitDetails[h3] = &Commit{Hash: h3, PatchID: p1} // h3 has the same patch ID as h1, so it should be picked up as a cherry-pick
+
+	// Setup patch ID map
+	repo.patchIDToCommits[p1] = []SHA1{h1, h3}
+
+	tests := []struct {
+		name     string
+		input    []SHA1
+		expected []SHA1
+	}{
+		{
+			name:     "Expand single commit with cherry-pick",
+			input:    []SHA1{h1},
+			expected: []SHA1{h1, h3},
+		},
+		{
+			name:     "No expansion for commit without cherry-pick",
+			input:    []SHA1{h2},
+			expected: []SHA1{h2},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := repo.expandByCherrypick(tt.input)
+
+			// expandByCherrypick collects its result from a map, so the order is
+			// random: sort both sides before comparing to keep the test stable
+			sort.Slice(got, func(i, j int) bool {
+				return string(got[i][:]) < string(got[j][:])
+			})
+			sort.Slice(tt.expected, func(i, j int) bool {
+				return string(tt.expected[i][:]) < string(tt.expected[j][:])
+			})
+
+			if diff := cmp.Diff(tt.expected, got); diff != "" {
+				t.Errorf("expandByCherrypick() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// Testing cases with introduced and fixed only.
+func TestAffected_Introduced_Fixed(t *testing.T) {
+	repo := NewRepository("/repo")
+
+	// Graph: (Parent -> Child)
+	//              -> F -> G
+	//             /
+	// A -> B -> C -> D -> E
+	//       \      /
+	//        -> H ->
+
+	hA := decodeSHA1("aaaa")
+	hB := decodeSHA1("bbbb")
+	hC := decodeSHA1("cccc")
+	hD := decodeSHA1("dddd")
+	hE := decodeSHA1("eeee")
+	hF := decodeSHA1("ffff")
+	hG := decodeSHA1("abab")
+	hH := decodeSHA1("acac")
+
+	// Setup graph (Parent -> Children)
+	edges := map[SHA1][]SHA1{
+		hA: {hB},
+		hB: {hC, hH},
+		hC: {hD, hF},
+		hD: {hE},
+		hF: {hG},
+		hH: {hD},
+	}
+	for parent, children := range edges {
+		repo.commitGraph[parent] = children
+	}
+
+	// Setup details
+	for _, h := range []SHA1{hA, hB, hC, hD, hE, hF, hG, hH} {
+		repo.commitDetails[h] = &Commit{Hash: h}
+	}
+
+	cases := []struct {
+		name         string
+		introduced   []SHA1
+		fixed        []SHA1
+		lastAffected []SHA1
+		expected     []SHA1
+	}{
+		{
+			name:       "Linear: A introduced, B fixed",
+			introduced: []SHA1{hA},
+			fixed:      []SHA1{hB},
+			expected:   []SHA1{hA},
+		},
+		{
+			name:       "Branch propagation: A introduced, D fixed",
+			introduced: []SHA1{hA},
+			fixed:      []SHA1{hD},
+			expected:   []SHA1{hA, hB, hC, hF, hG, hH},
+		},
+		{
+			name:       "Diverged before introduce: C introduced, E fixed",
+			introduced: []SHA1{hC},
+			fixed:      []SHA1{hE},
+			expected:   []SHA1{hC, hD, hF, hG},
+		},
+		{
+			name:       "Two sets: (A,C) introduced, (B,D,G) fixed",
+			introduced: []SHA1{hA, hC},
+			fixed:      []SHA1{hB, hD, hG},
+			expected:   []SHA1{hA, hC, hF},
+		},
+		{
+			name:       "Merge fix: A introduced, H fixed",
+			introduced: []SHA1{hA},
+			fixed:      []SHA1{hH},
+			expected:   []SHA1{hA, hB, hC, hF, hG},
+		},
+		{
+			name:       "Everything affected if no fix",
+			introduced: []SHA1{hA},
+			expected:   []SHA1{hA, hB, hC, hD, hE, hF, hG, hH},
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Sort got and expected for comparison
+			got := sortedCommitHashes(repo.Affected(tc.introduced, tc.fixed, tc.lastAffected, false))
+			sortSHA1s(tc.expected)
+
+			if cmp.Diff(tc.expected, got) == "" {
+				return
+			}
+
+			// Turn them back into strings so it's easier to read
+			t.Errorf("TestAffected_Introduced_Fixed() mismatch\nGot: %v\nExpected: %v", shortHashes(got), shortHashes(tc.expected))
+		})
+	}
+}
+
+// Testing cases with introduced and last affected only.
+func TestAffected_Introduced_LastAffected(t *testing.T) {
+	repo := NewRepository("/repo")
+
+	// Graph: (Parent -> Child)
+	// A -> B -> C -> D -> E -> F
+	//       \      /
+	//        -> G -> H
+
+	hA := decodeSHA1("aaaa")
+	hB := decodeSHA1("bbbb")
+	hC := decodeSHA1("cccc")
+	hD := decodeSHA1("dddd")
+	hE := decodeSHA1("eeee")
+	hF := decodeSHA1("ffff")
+	hG := decodeSHA1("abab")
+	hH := decodeSHA1("acac")
+
+	// Setup graph (Parent -> Children)
+	edges := map[SHA1][]SHA1{
+		hA: {hB},
+		hB: {hC, hG},
+		hC: {hD},
+		hD: {hE},
+		hE: {hF},
+		hG: {hD, hH},
+	}
+	for parent, children := range edges {
+		repo.commitGraph[parent] = children
+	}
+
+	// Setup details
+	for _, h := range []SHA1{hA, hB, hC, hD, hE, hF, hG, hH} {
+		repo.commitDetails[h] = &Commit{Hash: h}
+	}
+
+	cases := []struct {
+		name         string
+		introduced   []SHA1
+		fixed        []SHA1
+		lastAffected []SHA1
+		expected     []SHA1
+	}{
+		{
+			name:         "Linear: E introduced, F lastAffected",
+			introduced:   []SHA1{hE},
+			lastAffected: []SHA1{hF},
+			expected:     []SHA1{hE, hF},
+		},
+		{
+			name:         "Branch propagation: A introduced, D lastAffected",
+			introduced:   []SHA1{hA},
+			lastAffected: []SHA1{hD},
+			expected:     []SHA1{hA, hB, hC, hD, hG, hH},
+		},
+		{
+			name:         "Diverged before introduce: C introduced, E lastAffected",
+			introduced:   []SHA1{hC},
+			lastAffected: []SHA1{hE},
+			expected:     []SHA1{hC, hD, hE},
+		},
+		{
+			name:         "Two sets: (C,E) introduced, (D,F) lastAffected",
+			introduced:   []SHA1{hC, hE},
+			lastAffected: []SHA1{hD, hF},
+			expected:     []SHA1{hC, hD, hE, hF},
+		},
+		{
+			name:       "Everything affected if no lastAffected",
+			introduced: []SHA1{hA},
+			expected:   []SHA1{hA, hB, hC, hD, hE, hF, hG, hH},
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Sort got and expected for comparison
+			got := sortedCommitHashes(repo.Affected(tc.introduced, tc.fixed, tc.lastAffected, false))
+			sortSHA1s(tc.expected)
+
+			if cmp.Diff(tc.expected, got) == "" {
+				return
+			}
+
+			// Turn them back into strings so it's easier to read
+			t.Errorf("TestAffected_Introduced_LastAffected() mismatch\nGot: %v\nExpected: %v", shortHashes(got), shortHashes(tc.expected))
+		})
+	}
+}
+
+// Testing the first-parent walk from limit commits back to introduced commits.
+func TestBetween(t *testing.T) {
+	repo := NewRepository("/repo")
+
+	// Graph: (Parent -> Child)
+	// A -> B -> C -> D -> E
+	//       \
+	//        -> F -> G -> H
+
+	hA := decodeSHA1("aaaa")
+	hB := decodeSHA1("bbbb")
+	hC := decodeSHA1("cccc")
+	hD := decodeSHA1("dddd")
+	hE := decodeSHA1("eeee")
+	hF := decodeSHA1("ffff")
+	hG := decodeSHA1("abab")
+	hH := decodeSHA1("acac")
+
+	// Setup graph (Parent -> Children)
+	edges := map[SHA1][]SHA1{
+		hA: {hB},
+		hB: {hC, hF},
+		hC: {hD},
+		hD: {hE},
+		hF: {hG},
+		hG: {hH},
+	}
+	for parent, children := range edges {
+		repo.commitGraph[parent] = children
+	}
+
+	// Setup details (A is the root and therefore has no parent)
+	repo.commitDetails[hA] = &Commit{Hash: hA}
+	firstParent := []struct{ child, parent SHA1 }{
+		{hB, hA},
+		{hC, hB},
+		{hD, hC},
+		{hE, hD},
+		{hF, hB},
+		{hG, hF},
+		{hH, hG},
+	}
+	for _, link := range firstParent {
+		repo.commitDetails[link.child] = &Commit{Hash: link.child, Parents: []SHA1{link.parent}}
+	}
+
+	cases := []struct {
+		name       string
+		introduced []SHA1
+		limit      []SHA1
+		expected   []SHA1
+	}{
+		{
+			name:       "One branch: A introduced, D limit",
+			introduced: []SHA1{hA},
+			limit:      []SHA1{hD},
+			expected:   []SHA1{hA, hB, hC},
+		},
+		{
+			name:       "Side branch: A introduced, G limit",
+			introduced: []SHA1{hA},
+			limit:      []SHA1{hG},
+			expected:   []SHA1{hA, hB, hF},
+		},
+		{
+			name:       "Two branches: A introduced, (D,G) limit",
+			introduced: []SHA1{hA},
+			limit:      []SHA1{hD, hG},
+			expected:   []SHA1{hA, hB, hC, hF},
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Sort got and expected for comparison
+			got := sortedCommitHashes(repo.Between(tc.introduced, tc.limit))
+			sortSHA1s(tc.expected)
+
+			if cmp.Diff(tc.expected, got) == "" {
+				return
+			}
+
+			// Turn them back into strings so it's easier to read
+			t.Errorf("TestBetween() mismatch\nGot: %v\nExpected: %v", shortHashes(got), shortHashes(tc.expected))
+		})
+	}
+}
+
+// sortedCommitHashes returns the hashes of commits in bytewise ascending order.
+// The result is nil when commits is empty.
+func sortedCommitHashes(commits []*Commit) []SHA1 {
+	var hashes []SHA1
+	for _, c := range commits {
+		hashes = append(hashes, c.Hash)
+	}
+	sortSHA1s(hashes)
+
+	return hashes
+}
+
+// sortSHA1s sorts hashes in place in bytewise ascending order.
+func sortSHA1s(hashes []SHA1) {
+	sort.Slice(hashes, func(i, j int) bool {
+		return string(hashes[i][:]) < string(hashes[j][:])
+	})
+}
+
+// shortHashes renders hashes with encodeSHA1 so failure output is easy to read.
+func shortHashes(hashes []SHA1) []string {
+	out := make([]string, len(hashes))
+	for i, h := range hashes {
+		out[i] = encodeSHA1(h)
+	}
+
+	return out
+}