Ensure that all migration requests are cancellable (#12669)

* Ensure that all migration requests are cancellable

Signed-off-by: Andrew Thornton <art27@cantab.net>

* Use WithContext as RequestWithContext is go 1.14

Signed-off-by: Andrew Thornton <art27@cantab.net>

Co-authored-by: techknowlogick <techknowlogick@gitea.io>
mj-v1.14.3
zeripath 4 years ago committed by GitHub
parent 84eac6ed6c
commit 714ab71ddc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -35,7 +35,7 @@ type Downloader interface {
// DownloaderFactory defines an interface to match a downloader implementation and create a downloader // DownloaderFactory defines an interface to match a downloader implementation and create a downloader
type DownloaderFactory interface { type DownloaderFactory interface {
New(opts MigrateOptions) (Downloader, error) New(ctx context.Context, opts MigrateOptions) (Downloader, error)
GitServiceType() structs.GitServiceType GitServiceType() structs.GitServiceType
} }
@ -46,14 +46,16 @@ var (
// RetryDownloader retry the downloads // RetryDownloader retry the downloads
type RetryDownloader struct { type RetryDownloader struct {
Downloader Downloader
ctx context.Context
RetryTimes int // the total execute times RetryTimes int // the total execute times
RetryDelay int // time to delay seconds RetryDelay int // time to delay seconds
} }
// NewRetryDownloader creates a retry downloader // NewRetryDownloader creates a retry downloader
func NewRetryDownloader(downloader Downloader, retryTimes, retryDelay int) *RetryDownloader { func NewRetryDownloader(ctx context.Context, downloader Downloader, retryTimes, retryDelay int) *RetryDownloader {
return &RetryDownloader{ return &RetryDownloader{
Downloader: downloader, Downloader: downloader,
ctx: ctx,
RetryTimes: retryTimes, RetryTimes: retryTimes,
RetryDelay: retryDelay, RetryDelay: retryDelay,
} }
@ -61,6 +63,7 @@ func NewRetryDownloader(downloader Downloader, retryTimes, retryDelay int) *Retr
// SetContext set context // SetContext set context
func (d *RetryDownloader) SetContext(ctx context.Context) { func (d *RetryDownloader) SetContext(ctx context.Context) {
d.ctx = ctx
d.Downloader.SetContext(ctx) d.Downloader.SetContext(ctx)
} }
@ -75,7 +78,11 @@ func (d *RetryDownloader) GetRepoInfo() (*Repository, error) {
if repo, err = d.Downloader.GetRepoInfo(); err == nil { if repo, err = d.Downloader.GetRepoInfo(); err == nil {
return repo, nil return repo, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -91,7 +98,11 @@ func (d *RetryDownloader) GetTopics() ([]string, error) {
if topics, err = d.Downloader.GetTopics(); err == nil { if topics, err = d.Downloader.GetTopics(); err == nil {
return topics, nil return topics, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -107,7 +118,11 @@ func (d *RetryDownloader) GetMilestones() ([]*Milestone, error) {
if milestones, err = d.Downloader.GetMilestones(); err == nil { if milestones, err = d.Downloader.GetMilestones(); err == nil {
return milestones, nil return milestones, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -123,7 +138,11 @@ func (d *RetryDownloader) GetReleases() ([]*Release, error) {
if releases, err = d.Downloader.GetReleases(); err == nil { if releases, err = d.Downloader.GetReleases(); err == nil {
return releases, nil return releases, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -139,7 +158,11 @@ func (d *RetryDownloader) GetLabels() ([]*Label, error) {
if labels, err = d.Downloader.GetLabels(); err == nil { if labels, err = d.Downloader.GetLabels(); err == nil {
return labels, nil return labels, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -156,7 +179,11 @@ func (d *RetryDownloader) GetIssues(page, perPage int) ([]*Issue, bool, error) {
if issues, isEnd, err = d.Downloader.GetIssues(page, perPage); err == nil { if issues, isEnd, err = d.Downloader.GetIssues(page, perPage); err == nil {
return issues, isEnd, nil return issues, isEnd, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, false, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, false, err return nil, false, err
} }
@ -172,7 +199,11 @@ func (d *RetryDownloader) GetComments(issueNumber int64) ([]*Comment, error) {
if comments, err = d.Downloader.GetComments(issueNumber); err == nil { if comments, err = d.Downloader.GetComments(issueNumber); err == nil {
return comments, nil return comments, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -188,7 +219,11 @@ func (d *RetryDownloader) GetPullRequests(page, perPage int) ([]*PullRequest, er
if prs, err = d.Downloader.GetPullRequests(page, perPage); err == nil { if prs, err = d.Downloader.GetPullRequests(page, perPage); err == nil {
return prs, nil return prs, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }
@ -204,7 +239,11 @@ func (d *RetryDownloader) GetReviews(pullRequestNumber int64) ([]*Review, error)
if reviews, err = d.Downloader.GetReviews(pullRequestNumber); err == nil { if reviews, err = d.Downloader.GetReviews(pullRequestNumber); err == nil {
return reviews, nil return reviews, nil
} }
time.Sleep(time.Second * time.Duration(d.RetryDelay)) select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case <-time.After(time.Second * time.Duration(d.RetryDelay)):
}
} }
return nil, err return nil, err
} }

@ -6,6 +6,7 @@
package migrations package migrations
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -26,7 +27,7 @@ func TestGiteaUploadRepo(t *testing.T) {
user := models.AssertExistsAndLoadBean(t, &models.User{ID: 1}).(*models.User) user := models.AssertExistsAndLoadBean(t, &models.User{ID: 1}).(*models.User)
var ( var (
downloader = NewGithubDownloaderV3("", "", "", "go-xorm", "builder") downloader = NewGithubDownloaderV3(context.Background(), "", "", "", "go-xorm", "builder")
repoName = "builder-" + time.Now().Format("2006-01-02-15-04-05") repoName = "builder-" + time.Now().Format("2006-01-02-15-04-05")
uploader = NewGiteaLocalUploader(graceful.GetManager().HammerContext(), user, user.Name, repoName) uploader = NewGiteaLocalUploader(graceful.GetManager().HammerContext(), user, user.Name, repoName)
) )

@ -41,7 +41,7 @@ type GithubDownloaderV3Factory struct {
} }
// New returns a Downloader related to this factory according MigrateOptions // New returns a Downloader related to this factory according MigrateOptions
func (f *GithubDownloaderV3Factory) New(opts base.MigrateOptions) (base.Downloader, error) { func (f *GithubDownloaderV3Factory) New(ctx context.Context, opts base.MigrateOptions) (base.Downloader, error) {
u, err := url.Parse(opts.CloneAddr) u, err := url.Parse(opts.CloneAddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -53,7 +53,7 @@ func (f *GithubDownloaderV3Factory) New(opts base.MigrateOptions) (base.Download
log.Trace("Create github downloader: %s/%s", oldOwner, oldName) log.Trace("Create github downloader: %s/%s", oldOwner, oldName)
return NewGithubDownloaderV3(opts.AuthUsername, opts.AuthPassword, opts.AuthToken, oldOwner, oldName), nil return NewGithubDownloaderV3(ctx, opts.AuthUsername, opts.AuthPassword, opts.AuthToken, oldOwner, oldName), nil
} }
// GitServiceType returns the type of git service // GitServiceType returns the type of git service
@ -74,11 +74,11 @@ type GithubDownloaderV3 struct {
} }
// NewGithubDownloaderV3 creates a github Downloader via github v3 API // NewGithubDownloaderV3 creates a github Downloader via github v3 API
func NewGithubDownloaderV3(userName, password, token, repoOwner, repoName string) *GithubDownloaderV3 { func NewGithubDownloaderV3(ctx context.Context, userName, password, token, repoOwner, repoName string) *GithubDownloaderV3 {
var downloader = GithubDownloaderV3{ var downloader = GithubDownloaderV3{
userName: userName, userName: userName,
password: password, password: password,
ctx: context.Background(), ctx: ctx,
repoOwner: repoOwner, repoOwner: repoOwner,
repoName: repoName, repoName: repoName,
} }

@ -6,6 +6,7 @@
package migrations package migrations
import ( import (
"context"
"os" "os"
"testing" "testing"
"time" "time"
@ -64,7 +65,7 @@ func assertLabelEqual(t *testing.T, name, color, description string, label *base
func TestGitHubDownloadRepo(t *testing.T) { func TestGitHubDownloadRepo(t *testing.T) {
GithubLimitRateRemaining = 3 //Wait at 3 remaining since we could have 3 CI in // GithubLimitRateRemaining = 3 //Wait at 3 remaining since we could have 3 CI in //
downloader := NewGithubDownloaderV3("", "", os.Getenv("GITHUB_READ_TOKEN"), "go-gitea", "test_repo") downloader := NewGithubDownloaderV3(context.Background(), "", "", os.Getenv("GITHUB_READ_TOKEN"), "go-gitea", "test_repo")
err := downloader.RefreshRate() err := downloader.RefreshRate()
assert.NoError(t, err) assert.NoError(t, err)

@ -35,7 +35,7 @@ type GitlabDownloaderFactory struct {
} }
// New returns a Downloader related to this factory according MigrateOptions // New returns a Downloader related to this factory according MigrateOptions
func (f *GitlabDownloaderFactory) New(opts base.MigrateOptions) (base.Downloader, error) { func (f *GitlabDownloaderFactory) New(ctx context.Context, opts base.MigrateOptions) (base.Downloader, error) {
u, err := url.Parse(opts.CloneAddr) u, err := url.Parse(opts.CloneAddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -47,7 +47,7 @@ func (f *GitlabDownloaderFactory) New(opts base.MigrateOptions) (base.Downloader
log.Trace("Create gitlab downloader. BaseURL: %s RepoName: %s", baseURL, repoNameSpace) log.Trace("Create gitlab downloader. BaseURL: %s RepoName: %s", baseURL, repoNameSpace)
return NewGitlabDownloader(baseURL, repoNameSpace, opts.AuthUsername, opts.AuthPassword, opts.AuthToken), nil return NewGitlabDownloader(ctx, baseURL, repoNameSpace, opts.AuthUsername, opts.AuthPassword, opts.AuthToken), nil
} }
// GitServiceType returns the type of git service // GitServiceType returns the type of git service
@ -73,7 +73,7 @@ type GitlabDownloader struct {
// NewGitlabDownloader creates a gitlab Downloader via gitlab API // NewGitlabDownloader creates a gitlab Downloader via gitlab API
// Use either a username/password, personal token entered into the username field, or anonymous/public access // Use either a username/password, personal token entered into the username field, or anonymous/public access
// Note: Public access only allows very basic access // Note: Public access only allows very basic access
func NewGitlabDownloader(baseURL, repoPath, username, password, token string) *GitlabDownloader { func NewGitlabDownloader(ctx context.Context, baseURL, repoPath, username, password, token string) *GitlabDownloader {
var gitlabClient *gitlab.Client var gitlabClient *gitlab.Client
var err error var err error
if token != "" { if token != "" {
@ -88,7 +88,7 @@ func NewGitlabDownloader(baseURL, repoPath, username, password, token string) *G
} }
// Grab and store project/repo ID here, due to issues using the URL escaped path // Grab and store project/repo ID here, due to issues using the URL escaped path
gr, _, err := gitlabClient.Projects.GetProject(repoPath, nil, nil) gr, _, err := gitlabClient.Projects.GetProject(repoPath, nil, nil, gitlab.WithContext(ctx))
if err != nil { if err != nil {
log.Trace("Error retrieving project: %v", err) log.Trace("Error retrieving project: %v", err)
return nil return nil
@ -100,7 +100,7 @@ func NewGitlabDownloader(baseURL, repoPath, username, password, token string) *G
} }
return &GitlabDownloader{ return &GitlabDownloader{
ctx: context.Background(), ctx: ctx,
client: gitlabClient, client: gitlabClient,
repoID: gr.ID, repoID: gr.ID,
repoName: gr.Name, repoName: gr.Name,
@ -118,7 +118,7 @@ func (g *GitlabDownloader) GetRepoInfo() (*base.Repository, error) {
return nil, errors.New("error: GitlabDownloader is nil") return nil, errors.New("error: GitlabDownloader is nil")
} }
gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil) gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,7 +158,7 @@ func (g *GitlabDownloader) GetTopics() ([]string, error) {
return nil, errors.New("error: GitlabDownloader is nil") return nil, errors.New("error: GitlabDownloader is nil")
} }
gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil) gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -179,7 +179,7 @@ func (g *GitlabDownloader) GetMilestones() ([]*base.Milestone, error) {
ListOptions: gitlab.ListOptions{ ListOptions: gitlab.ListOptions{
Page: i, Page: i,
PerPage: perPage, PerPage: perPage,
}}, nil) }}, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -237,7 +237,7 @@ func (g *GitlabDownloader) GetLabels() ([]*base.Label, error) {
ls, _, err := g.client.Labels.ListLabels(g.repoID, &gitlab.ListLabelsOptions{ ls, _, err := g.client.Labels.ListLabels(g.repoID, &gitlab.ListLabelsOptions{
Page: i, Page: i,
PerPage: perPage, PerPage: perPage,
}, nil) }, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -288,7 +288,7 @@ func (g *GitlabDownloader) GetReleases() ([]*base.Release, error) {
ls, _, err := g.client.Releases.ListReleases(g.repoID, &gitlab.ListReleasesOptions{ ls, _, err := g.client.Releases.ListReleases(g.repoID, &gitlab.ListReleasesOptions{
Page: i, Page: i,
PerPage: perPage, PerPage: perPage,
}, nil) }, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -305,11 +305,18 @@ func (g *GitlabDownloader) GetReleases() ([]*base.Release, error) {
// GetAsset returns an asset // GetAsset returns an asset
func (g *GitlabDownloader) GetAsset(tag string, id int64) (io.ReadCloser, error) { func (g *GitlabDownloader) GetAsset(tag string, id int64) (io.ReadCloser, error) {
link, _, err := g.client.ReleaseLinks.GetReleaseLink(g.repoID, tag, int(id)) link, _, err := g.client.ReleaseLinks.GetReleaseLink(g.repoID, tag, int(id), gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err := http.Get(link.URL)
req, err := http.NewRequest("GET", link.URL, nil)
if err != nil {
return nil, err
}
req = req.WithContext(g.ctx)
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -336,7 +343,7 @@ func (g *GitlabDownloader) GetIssues(page, perPage int) ([]*base.Issue, bool, er
var allIssues = make([]*base.Issue, 0, perPage) var allIssues = make([]*base.Issue, 0, perPage)
issues, _, err := g.client.Issues.ListProjectIssues(g.repoID, opt, nil) issues, _, err := g.client.Issues.ListProjectIssues(g.repoID, opt, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, false, fmt.Errorf("error while listing issues: %v", err) return nil, false, fmt.Errorf("error while listing issues: %v", err)
} }
@ -393,14 +400,14 @@ func (g *GitlabDownloader) GetComments(issueNumber int64) ([]*base.Comment, erro
comments, resp, err = g.client.Discussions.ListIssueDiscussions(g.repoID, int(realIssueNumber), &gitlab.ListIssueDiscussionsOptions{ comments, resp, err = g.client.Discussions.ListIssueDiscussions(g.repoID, int(realIssueNumber), &gitlab.ListIssueDiscussionsOptions{
Page: page, Page: page,
PerPage: 100, PerPage: 100,
}, nil) }, nil, gitlab.WithContext(g.ctx))
} else { } else {
// If this is a PR, we need to figure out the Gitlab/original PR ID to be passed below // If this is a PR, we need to figure out the Gitlab/original PR ID to be passed below
realIssueNumber = issueNumber - g.issueCount realIssueNumber = issueNumber - g.issueCount
comments, resp, err = g.client.Discussions.ListMergeRequestDiscussions(g.repoID, int(realIssueNumber), &gitlab.ListMergeRequestDiscussionsOptions{ comments, resp, err = g.client.Discussions.ListMergeRequestDiscussions(g.repoID, int(realIssueNumber), &gitlab.ListMergeRequestDiscussionsOptions{
Page: page, Page: page,
PerPage: 100, PerPage: 100,
}, nil) }, nil, gitlab.WithContext(g.ctx))
} }
if err != nil { if err != nil {
@ -455,7 +462,7 @@ func (g *GitlabDownloader) GetPullRequests(page, perPage int) ([]*base.PullReque
var allPRs = make([]*base.PullRequest, 0, perPage) var allPRs = make([]*base.PullRequest, 0, perPage)
prs, _, err := g.client.MergeRequests.ListProjectMergeRequests(g.repoID, opt, nil) prs, _, err := g.client.MergeRequests.ListProjectMergeRequests(g.repoID, opt, nil, gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, fmt.Errorf("error while listing merge requests: %v", err) return nil, fmt.Errorf("error while listing merge requests: %v", err)
} }
@ -536,7 +543,7 @@ func (g *GitlabDownloader) GetPullRequests(page, perPage int) ([]*base.PullReque
// GetReviews returns pull requests review // GetReviews returns pull requests review
func (g *GitlabDownloader) GetReviews(pullRequestNumber int64) ([]*base.Review, error) { func (g *GitlabDownloader) GetReviews(pullRequestNumber int64) ([]*base.Review, error) {
state, _, err := g.client.MergeRequestApprovals.GetApprovalState(g.repoID, int(pullRequestNumber)) state, _, err := g.client.MergeRequestApprovals.GetApprovalState(g.repoID, int(pullRequestNumber), gitlab.WithContext(g.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -5,6 +5,7 @@
package migrations package migrations
import ( import (
"context"
"net/http" "net/http"
"os" "os"
"testing" "testing"
@ -27,7 +28,7 @@ func TestGitlabDownloadRepo(t *testing.T) {
t.Skipf("Can't access test repo, skipping %s", t.Name()) t.Skipf("Can't access test repo, skipping %s", t.Name())
} }
downloader := NewGitlabDownloader("https://gitlab.com", "gitea/test_repo", "", "", gitlabPersonalAccessToken) downloader := NewGitlabDownloader(context.Background(), "https://gitlab.com", "gitea/test_repo", "", "", gitlabPersonalAccessToken)
if downloader == nil { if downloader == nil {
t.Fatal("NewGitlabDownloader is nil") t.Fatal("NewGitlabDownloader is nil")
} }

@ -37,7 +37,7 @@ func MigrateRepository(ctx context.Context, doer *models.User, ownerName string,
for _, factory := range factories { for _, factory := range factories {
if factory.GitServiceType() == opts.GitServiceType { if factory.GitServiceType() == opts.GitServiceType {
downloader, err = factory.New(opts) downloader, err = factory.New(ctx, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -60,11 +60,9 @@ func MigrateRepository(ctx context.Context, doer *models.User, ownerName string,
uploader.gitServiceType = opts.GitServiceType uploader.gitServiceType = opts.GitServiceType
if setting.Migrations.MaxAttempts > 1 { if setting.Migrations.MaxAttempts > 1 {
downloader = base.NewRetryDownloader(downloader, setting.Migrations.MaxAttempts, setting.Migrations.RetryBackoff) downloader = base.NewRetryDownloader(ctx, downloader, setting.Migrations.MaxAttempts, setting.Migrations.RetryBackoff)
} }
downloader.SetContext(ctx)
if err := migrateRepository(downloader, uploader, opts); err != nil { if err := migrateRepository(downloader, uploader, opts); err != nil {
if err1 := uploader.Rollback(); err1 != nil { if err1 := uploader.Rollback(); err1 != nil {
log.Error("rollback failed: %v", err1) log.Error("rollback failed: %v", err1)

Loading…
Cancel
Save