From 3183a465d71a13535e52589bb85b987176872fcd Mon Sep 17 00:00:00 2001 From: zeripath Date: Mon, 31 May 2021 07:18:11 +0100 Subject: [PATCH] Make modules/context.Context a context.Context (#16031) * Make modules/context.Context a context.Context Signed-off-by: Andrew Thornton * Simplify context calls Signed-off-by: Andrew Thornton * Set the base context for requests to the HammerContext Signed-off-by: Andrew Thornton Co-authored-by: Lunny Xiao --- modules/context/context.go | 22 +++++++++++++++++++++- modules/graceful/server_http.go | 3 +++ routers/admin/users.go | 4 ++-- routers/api/v1/admin/user.go | 4 ++-- routers/events/events.go | 2 +- routers/install.go | 4 ++-- routers/private/manager.go | 2 +- routers/private/restore_repo.go | 2 +- routers/repo/blame.go | 2 +- routers/repo/lfs.go | 2 +- routers/user/auth.go | 12 ++++++------ routers/user/auth_openid.go | 4 ++-- routers/user/setting/account.go | 2 +- services/archiver/archiver.go | 4 ++-- 14 files changed, 46 insertions(+), 23 deletions(-) diff --git a/modules/context/context.go b/modules/context/context.go index d812d7b58..d45e9ff87 100644 --- a/modules/context/context.go +++ b/modules/context/context.go @@ -509,7 +509,7 @@ func (ctx *Context) ParamsInt64(p string) int64 { // SetParams set params into routes func (ctx *Context) SetParams(k, v string) { - chiCtx := chi.RouteContext(ctx.Req.Context()) + chiCtx := chi.RouteContext(ctx) chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v)) } @@ -528,6 +528,26 @@ func (ctx *Context) Status(status int) { ctx.Resp.WriteHeader(status) } +// Deadline is part of the interface for context.Context and we pass this to the request context +func (ctx *Context) Deadline() (deadline time.Time, ok bool) { + return ctx.Req.Context().Deadline() +} + +// Done is part of the interface for context.Context and we pass this to the request context +func (ctx *Context) Done() <-chan struct{} { + return ctx.Req.Context().Done() +} + +// Err is part of the interface for context.Context and we pass this to the request context +func (ctx *Context) Err() error { + return ctx.Req.Context().Err() +} + +// Value is part of the interface for context.Context and we pass this to the request context +func (ctx *Context) Value(key interface{}) interface{} { + return ctx.Req.Context().Value(key) +} + // Handler represents a custom handler type Handler func(*Context) diff --git a/modules/graceful/server_http.go b/modules/graceful/server_http.go index b101a10d9..4471e379e 100644 --- a/modules/graceful/server_http.go +++ b/modules/graceful/server_http.go @@ -5,7 +5,9 @@ package graceful import ( + "context" "crypto/tls" + "net" "net/http" ) @@ -16,6 +18,7 @@ func newHTTPServer(network, address, name string, handler http.Handler) (*Server WriteTimeout: DefaultWriteTimeOut, MaxHeaderBytes: DefaultMaxHeaderBytes, Handler: handler, + BaseContext: func(net.Listener) context.Context { return GetManager().HammerContext() }, } server.OnShutdown = func() { httpServer.SetKeepAlivesEnabled(false) diff --git a/routers/admin/users.go b/routers/admin/users.go index 3b29eeefc..a71a11dd8 100644 --- a/routers/admin/users.go +++ b/routers/admin/users.go @@ -113,7 +113,7 @@ func NewUserPost(ctx *context.Context) { ctx.RenderWithErr(password.BuildComplexityError(ctx), tplUserNew, &form) return } - pwned, err := password.IsPwned(ctx.Req.Context(), form.Password) + pwned, err := password.IsPwned(ctx, form.Password) if pwned { ctx.Data["Err_Password"] = true errMsg := ctx.Tr("auth.password_pwned") @@ -256,7 +256,7 @@ func EditUserPost(ctx *context.Context) { ctx.RenderWithErr(password.BuildComplexityError(ctx), tplUserEdit, &form) return } - pwned, err := password.IsPwned(ctx.Req.Context(), form.Password) + pwned, err := password.IsPwned(ctx, form.Password) if pwned { ctx.Data["Err_Password"] = true errMsg := ctx.Tr("auth.password_pwned") diff --git a/routers/api/v1/admin/user.go b/routers/api/v1/admin/user.go index 2d4a3815f..4bbe7f77b 100644 --- a/routers/api/v1/admin/user.go +++ b/routers/api/v1/admin/user.go @@ -88,7 +88,7 @@ func CreateUser(ctx *context.APIContext) { ctx.Error(http.StatusBadRequest, "PasswordComplexity", err) return } - pwned, err := password.IsPwned(ctx.Req.Context(), form.Password) + pwned, err := password.IsPwned(ctx, form.Password) if pwned { if err != nil { log.Error(err.Error()) @@ -162,7 +162,7 @@ func EditUser(ctx *context.APIContext) { ctx.Error(http.StatusBadRequest, "PasswordComplexity", err) return } - pwned, err := password.IsPwned(ctx.Req.Context(), form.Password) + pwned, err := password.IsPwned(ctx, form.Password) if pwned { if err != nil { log.Error(err.Error()) diff --git a/routers/events/events.go b/routers/events/events.go index 2c1034038..b140bf660 100644 --- a/routers/events/events.go +++ b/routers/events/events.go @@ -42,7 +42,7 @@ func Events(ctx *context.Context) { } // Listen to connection close and un-register messageChan - notify := ctx.Req.Context().Done() + notify := ctx.Done() ctx.Resp.Flush() shutdownCtx := graceful.GetManager().ShutdownContext() diff --git a/routers/install.go b/routers/install.go index ef53422c4..30340e99c 100644 --- a/routers/install.go +++ b/routers/install.go @@ -400,7 +400,7 @@ func InstallPost(ctx *context.Context) { } // Re-read settings - PostInstallInit(ctx.Req.Context()) + PostInstallInit(ctx) // Create admin account if len(form.AdminName) > 0 { @@ -454,7 +454,7 @@ func InstallPost(ctx *context.Context) { // Now get the http.Server from this request and shut it down // NB: This is not our hammerable graceful shutdown this is http.Server.Shutdown - srv := ctx.Req.Context().Value(http.ServerContextKey).(*http.Server) + srv := ctx.Value(http.ServerContextKey).(*http.Server) go func() { if err := srv.Shutdown(graceful.GetManager().HammerContext()); err != nil { log.Error("Unable to shutdown the install server! Error: %v", err) diff --git a/routers/private/manager.go b/routers/private/manager.go index 192c4947e..1ccb18436 100644 --- a/routers/private/manager.go +++ b/routers/private/manager.go @@ -35,7 +35,7 @@ func FlushQueues(ctx *context.PrivateContext) { }) return } - err := queue.GetManager().FlushAll(ctx.Req.Context(), opts.Timeout) + err := queue.GetManager().FlushAll(ctx, opts.Timeout) if err != nil { ctx.JSON(http.StatusRequestTimeout, map[string]interface{}{ "err": fmt.Sprintf("%v", err), diff --git a/routers/private/restore_repo.go b/routers/private/restore_repo.go index c002de874..df787e1b3 100644 --- a/routers/private/restore_repo.go +++ b/routers/private/restore_repo.go @@ -36,7 +36,7 @@ func RestoreRepo(ctx *myCtx.PrivateContext) { } if err := migrations.RestoreRepository( - ctx.Req.Context(), + ctx, params.RepoDir, params.OwnerName, params.RepoName, diff --git a/routers/repo/blame.go b/routers/repo/blame.go index f5b228bdf..1a3e1dcb9 100644 --- a/routers/repo/blame.go +++ b/routers/repo/blame.go @@ -124,7 +124,7 @@ func RefBlame(ctx *context.Context) { return } - blameReader, err := git.CreateBlameReader(ctx.Req.Context(), models.RepoPath(userName, repoName), commitID, fileName) + blameReader, err := git.CreateBlameReader(ctx, models.RepoPath(userName, repoName), commitID, fileName) if err != nil { ctx.NotFound("CreateBlameReader", err) return diff --git a/routers/repo/lfs.go b/routers/repo/lfs.go index 3a7ce2e23..c17bd2f87 100644 --- a/routers/repo/lfs.go +++ b/routers/repo/lfs.go @@ -414,7 +414,7 @@ func LFSPointerFiles(ctx *context.Context) { err = func() error { pointerChan := make(chan lfs.PointerBlob) errChan := make(chan error, 1) - go lfs.SearchPointerBlobs(ctx.Req.Context(), ctx.Repo.GitRepo, pointerChan, errChan) + go lfs.SearchPointerBlobs(ctx, ctx.Repo.GitRepo, pointerChan, errChan) numPointers := 0 var numAssociated, numNoExist, numAssociatable int diff --git a/routers/user/auth.go b/routers/user/auth.go index 5f8b1a6b9..827b7cdef 100644 --- a/routers/user/auth.go +++ b/routers/user/auth.go @@ -1011,9 +1011,9 @@ func LinkAccountPostRegister(ctx *context.Context) { case setting.ImageCaptcha: valid = context.GetImageCaptcha().VerifyReq(ctx.Req) case setting.ReCaptcha: - valid, err = recaptcha.Verify(ctx.Req.Context(), form.GRecaptchaResponse) + valid, err = recaptcha.Verify(ctx, form.GRecaptchaResponse) case setting.HCaptcha: - valid, err = hcaptcha.Verify(ctx.Req.Context(), form.HcaptchaResponse) + valid, err = hcaptcha.Verify(ctx, form.HcaptchaResponse) default: ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType)) return @@ -1153,9 +1153,9 @@ func SignUpPost(ctx *context.Context) { case setting.ImageCaptcha: valid = context.GetImageCaptcha().VerifyReq(ctx.Req) case setting.ReCaptcha: - valid, err = recaptcha.Verify(ctx.Req.Context(), form.GRecaptchaResponse) + valid, err = recaptcha.Verify(ctx, form.GRecaptchaResponse) case setting.HCaptcha: - valid, err = hcaptcha.Verify(ctx.Req.Context(), form.HcaptchaResponse) + valid, err = hcaptcha.Verify(ctx, form.HcaptchaResponse) default: ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType)) return @@ -1191,7 +1191,7 @@ func SignUpPost(ctx *context.Context) { ctx.RenderWithErr(password.BuildComplexityError(ctx), tplSignUp, &form) return } - pwned, err := password.IsPwned(ctx.Req.Context(), form.Password) + pwned, err := password.IsPwned(ctx, form.Password) if pwned { errMsg := ctx.Tr("auth.password_pwned") if err != nil { @@ -1620,7 +1620,7 @@ func ResetPasswdPost(ctx *context.Context) { ctx.Data["Err_Password"] = true ctx.RenderWithErr(password.BuildComplexityError(ctx), tplResetPassword, nil) return - } else if pwned, err := password.IsPwned(ctx.Req.Context(), passwd); pwned || err != nil { + } else if pwned, err := password.IsPwned(ctx, passwd); pwned || err != nil { errMsg := ctx.Tr("auth.password_pwned") if err != nil { log.Error(err.Error()) diff --git a/routers/user/auth_openid.go b/routers/user/auth_openid.go index b1dfc6ada..1a73a08c4 100644 --- a/routers/user/auth_openid.go +++ b/routers/user/auth_openid.go @@ -385,13 +385,13 @@ func RegisterOpenIDPost(ctx *context.Context) { ctx.ServerError("", err) return } - valid, err = recaptcha.Verify(ctx.Req.Context(), form.GRecaptchaResponse) + valid, err = recaptcha.Verify(ctx, form.GRecaptchaResponse) case setting.HCaptcha: if err := ctx.Req.ParseForm(); err != nil { ctx.ServerError("", err) return } - valid, err = hcaptcha.Verify(ctx.Req.Context(), form.HcaptchaResponse) + valid, err = hcaptcha.Verify(ctx, form.HcaptchaResponse) default: ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType)) return diff --git a/routers/user/setting/account.go b/routers/user/setting/account.go index e12d63ee0..48ab37d93 100644 --- a/routers/user/setting/account.go +++ b/routers/user/setting/account.go @@ -58,7 +58,7 @@ func AccountPost(ctx *context.Context) { ctx.Flash.Error(ctx.Tr("form.password_not_match")) } else if !password.IsComplexEnough(form.Password) { ctx.Flash.Error(password.BuildComplexityError(ctx)) - } else if pwned, err := password.IsPwned(ctx.Req.Context(), form.Password); pwned || err != nil { + } else if pwned, err := password.IsPwned(ctx, form.Password); pwned || err != nil { errMsg := ctx.Tr("auth.password_pwned") if err != nil { log.Error(err.Error()) diff --git a/services/archiver/archiver.go b/services/archiver/archiver.go index 359fc8b62..dfa6334d9 100644 --- a/services/archiver/archiver.go +++ b/services/archiver/archiver.go @@ -76,7 +76,7 @@ func (aReq *ArchiveRequest) IsComplete() bool { func (aReq *ArchiveRequest) WaitForCompletion(ctx *context.Context) bool { select { case <-aReq.cchan: - case <-ctx.Req.Context().Done(): + case <-ctx.Done(): } return aReq.IsComplete() @@ -92,7 +92,7 @@ func (aReq *ArchiveRequest) TimedWaitForCompletion(ctx *context.Context, dur tim case <-time.After(dur): timeout = true case <-aReq.cchan: - case <-ctx.Req.Context().Done(): + case <-ctx.Done(): } return aReq.IsComplete(), timeout