Playing with multiple Contexts in Go

The context package was introduced quite a while ago, but it is still a topic of discussion: should we use Values or not, how should we pass it to functions, and so on. But that is not the topic of this article.

Let’s imagine that we have several different contexts coming from different sources, and we need to treat them as a single one.

For example, the first context comes from the main() function, where we want to control the general execution of the program; the second one comes from requests or some event calls. Let’s assume that we’re working on some sort of web server with long-running background tasks, which are started by requests to the web server. What we need to cover here is graceful shutdown and cancellation of those background tasks.

Let’s write some simple task logic:

package main

import (
    "context"
    "net/http"
    "os"
    "strconv"
    "time"

    "github.com/justinas/alice"
    "github.com/rs/zerolog"
    "github.com/rs/zerolog/hlog"
    "github.com/rs/zerolog/log"
)

func main() {
    log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339Nano})

    handler := alice.New(hlog.NewHandler(log.Logger),
        hlog.URLHandler("url"),
        hlog.RemoteAddrHandler("ip"),
        hlog.UserAgentHandler("user_agent"),
        hlog.RefererHandler("referer")).
        ThenFunc(TaskHandler)
    srv := http.Server{
        Addr:    ":8080",
        Handler: handler,
    }
    log.Info().Msg("Listening HTTP on :8080")
    if err := srv.ListenAndServe(); err != nil {
        log.Fatal().Err(err).Msg("Error when running HTTP server")
    }
}

func TaskHandler(w http.ResponseWriter, r *http.Request) {
    l := hlog.FromRequest(r)
    l.Info().Msgf("%s %s", r.Method, r.URL.RequestURI())
    t := r.FormValue("time")
    if t == "" {
        w.WriteHeader(http.StatusBadRequest)
        return
    }
    duration, err := strconv.Atoi(t)
    if err != nil {
        w.WriteHeader(http.StatusBadRequest)
        return
    }
    go taskManager(l.WithContext(context.Background()), duration)
    w.Write([]byte("Started"))
}

func taskManager(ctx context.Context, duration int) {
    // Cancel the timeout context when the task returns to release its timer.
    ctx, cancel := context.WithTimeout(ctx, time.Minute)
    defer cancel()
    task(ctx, duration)
}

func task(ctx context.Context, duration int) {
    l := log.Ctx(ctx)
    l.Info().Msgf("Task %d second(s): STARTED", duration)
    select {
    case <-ctx.Done():
        l.Info().Msgf("Task %d second(s): CANCELED", duration)
    case <-time.After(time.Duration(duration) * time.Second):
        l.Info().Msgf("Task %d second(s): FINISHED", duration)
    }
}

So we have a pretty logger from Zerolog, which is attached to each request by middleware. We also have a task manager that does “a very complex job”: it adds a timeout. Let’s suppose that the task workload is a black box which we should not change, but which we can work around.

Our test cases are as follows.

We request localhost:8080/?time=duration, where duration is the sleep time in seconds.

Happy path

2019-02-22T10:52:24+02:00 INF Listening HTTP on :8080
2019-02-22T10:52:36+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T10:52:36+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T10:53:06+02:00 INF Task 30 second(s): FINISHED ip=::1 url=/?time=30 user_agent="..."

Task timeout

2019-02-21T21:02:05+02:00 INF Listening HTTP on :8080
2019-02-21T21:02:11+02:00 INF GET /?time=120 ip=::1 url=/?time=120 user_agent="..."
2019-02-21T21:02:11+02:00 INF Task 120 second(s): STARTED ip=::1 url=/?time=120 user_agent="..."
2019-02-21T21:03:11+02:00 INF Task 120 second(s): CANCELED ip=::1 url=/?time=120 user_agent="..."

SIGINT from console

2019-02-21T21:00:45+02:00 INF Listening HTTP on :8080
2019-02-21T21:00:53+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-21T21:00:53+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
^C

Ok, looks fine. We use context.Background() instead of r.Context() because the request context is canceled as soon as the handler finishes processing the request, and our task would be canceled along with it. That is why we create a new context with a timeout.
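
This is documented net/http behavior: for incoming server requests, the context is canceled when the client's connection closes or when ServeHTTP returns. The sketch below checks it with net/http/httptest by handing r.Context() to a background goroutine, which is exactly what TaskHandler avoids. requestCtxCancelledAfterReturn is a hypothetical helper name:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"time"
)

// requestCtxCancelledAfterReturn starts a goroutine from a handler, the way
// TaskHandler does, but hands it r.Context() instead of context.Background().
// It reports whether that context was canceled after the handler returned.
func requestCtxCancelledAfterReturn() bool {
	result := make(chan bool, 1)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		go func() {
			select {
			case <-ctx.Done():
				result <- true // the background task would be cut short
			case <-time.After(2 * time.Second):
				result <- false
			}
		}()
		w.Write([]byte("Started"))
	}))
	defer srv.Close()

	resp, err := http.Get(srv.URL)
	if err != nil {
		panic(err)
	}
	io.Copy(io.Discard, resp.Body)
	resp.Body.Close()
	return <-result
}

func main() {
	fmt.Println("request context canceled:", requestCtxCancelledAfterReturn())
}
```

On Go 1.21 and later, context.WithoutCancel(r.Context()) is a standard-library alternative to context.Background() in this spot: it drops the request's cancellation but keeps its values.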

Let’s add a graceful shutdown with a timeout of 10 seconds.

package main

import (
    "context"
    "net/http"
    "os"
    "os/signal"
    "strconv"
    "time"

    "github.com/justinas/alice"
    "github.com/rs/zerolog"
    "github.com/rs/zerolog/hlog"
    "github.com/rs/zerolog/log"
)

func main() {
    log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339Nano})

    handler := alice.New(hlog.NewHandler(log.Logger),
        hlog.URLHandler("url"),
        hlog.RemoteAddrHandler("ip"),
        hlog.UserAgentHandler("user_agent"),
        hlog.RefererHandler("referer")).
        ThenFunc(TaskHandler)
    srv := http.Server{
        Addr:    ":8080",
        Handler: handler,
    }

    go func() {
        log.Info().Msg("Listening HTTP on :8080")
        if err := srv.ListenAndServe(); err != http.ErrServerClosed {
            // Error starting or closing listener:
            log.Fatal().Err(err).Msg("Error when running HTTP server")
        }
    }()

    sigint := make(chan os.Signal, 1)
    signal.Notify(sigint, os.Interrupt)
    <-sigint
    // We received an interrupt signal, shut down.
    log.Info().Msg("Shutting down...")
    ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    defer cancel()
    if err := srv.Shutdown(ctx); err != nil {
        // Error from closing listeners, or context timeout:
        log.Error().Err(err).Msg("HTTP server shutdown error")
    }
    log.Info().Msg("Server has been stopped")
}

func TaskHandler(w http.ResponseWriter, r *http.Request) {
    l := hlog.FromRequest(r)
    l.Info().Msgf("%s %s", r.Method, r.URL.RequestURI())
    t := r.FormValue("time")
    if t == "" {
        w.WriteHeader(http.StatusBadRequest)
        return
    }
    duration, err := strconv.Atoi(t)
    if err != nil {
        w.WriteHeader(http.StatusBadRequest)
        return
    }
    go taskManager(l.WithContext(context.Background()), duration)
    w.Write([]byte("Started"))
}

func taskManager(ctx context.Context, duration int) {
    // Cancel the timeout context when the task returns to release its timer.
    ctx, cancel := context.WithTimeout(ctx, time.Minute)
    defer cancel()
    task(ctx, duration)
}

func task(ctx context.Context, duration int) {
    l := log.Ctx(ctx)
    l.Info().Msgf("Task %d second(s): STARTED", duration)
    select {
    case <-ctx.Done():
        l.Info().Msgf("Task %d second(s): CANCELED", duration)
    case <-time.After(time.Duration(duration) * time.Second):
        l.Info().Msgf("Task %d second(s): FINISHED", duration)
    }
}

Cool… Now the web server waits for unfinished requests before allowing the application to stop. But what have we missed here?

2019-02-21T21:04:46+02:00 INF Listening HTTP on :8080
2019-02-21T21:04:53+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-21T21:04:53+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
^C
2019-02-21T21:04:57+02:00 INF Shutting down...
2019-02-21T21:04:59+02:00 INF Server has been stopped

All our background jobs are rudely interrupted without any warning: the task above was STARTED but never reported FINISHED or CANCELED. This is bad; they should be finished gracefully. Moreover, we already have a classic interface for canceling all our jobs: ctx context.Context. But this context is created in the handler and the task manager, and we cannot seamlessly combine it with some global cancellation mechanism.
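
The root cause is that http.Server.Shutdown only waits for active connections to become idle, that is, for handlers to return; it does not track goroutines those handlers started. The sketch below illustrates this under a few assumptions: it listens on an ephemeral port on 127.0.0.1, the helper name is hypothetical, the "task" is a 300 ms sleep, and atomic.Bool requires Go 1.19+.

```go
package main

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync/atomic"
	"time"
)

// backgroundDoneBeforeShutdownReturned reports whether a goroutine started
// by a handler had finished by the time srv.Shutdown returned.
func backgroundDoneBeforeShutdownReturned() bool {
	var finished atomic.Bool

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		go func() {
			time.Sleep(300 * time.Millisecond) // the "long running" task
			finished.Store(true)
		}()
		w.Write([]byte("Started"))
	})

	ln, err := net.Listen("tcp", "127.0.0.1:0") // ephemeral port
	if err != nil {
		panic(err)
	}
	srv := &http.Server{Handler: mux}
	go srv.Serve(ln)

	resp, err := http.Get("http://" + ln.Addr().String() + "/")
	if err != nil {
		panic(err)
	}
	io.Copy(io.Discard, resp.Body)
	resp.Body.Close()

	// Same idea as in main(): give in-flight requests up to 10 seconds.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		panic(err)
	}
	return finished.Load()
}

func main() {
	fmt.Println("task finished before Shutdown returned:", backgroundDoneBeforeShutdownReturned())
}
```

It reports false: Shutdown completes while the background task is still sleeping, which matches the log above.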

We should keep in mind that context.Context is an interface, so we can replace it with our own implementation that does more and solves our issue. In other words, we can merge the existing contexts into one. How do we do that?

main.go

package main

import (
    "context"
    "net/http"
    "os"
    "os/signal"
    "strconv"
    "sync"
    "time"

    "github.com/justinas/alice"
    "github.com/rs/zerolog"
    "github.com/rs/zerolog/hlog"

    "github.com/rs/zerolog/log"
)

var (
    mainCtx         context.Context
    mu              sync.RWMutex
    shutdownTimeout = time.Second * 10
)

func main() {
    var mainCancel context.CancelFunc
    mainCtx, mainCancel = context.WithCancel(context.Background())

    log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339Nano})

    handler := alice.New(hlog.NewHandler(log.Logger),
        hlog.URLHandler("url"),
        hlog.RemoteAddrHandler("ip"),
        hlog.UserAgentHandler("user_agent"),
        hlog.RefererHandler("referer")).
        ThenFunc(TaskHandler)
    srv := http.Server{
        Addr:    ":8080",
        Handler: handler,
    }

    go func() {
        log.Info().Msg("Listening HTTP on :8080")
        if err := srv.ListenAndServe(); err != http.ErrServerClosed {
            // Error starting or closing listener:
            log.Fatal().Err(err).Msg("Error when running HTTP server")
        }
    }()

    sigint := make(chan os.Signal, 1)
    signal.Notify(sigint, os.Interrupt)
    <-sigint
    // We received an interrupt signal, shut down.
    log.Info().Msg("Shutting down...")

    // Time before forced task cancellation
    timeoutCh := time.After(shutdownTimeout)

    // Timeout for HTTP server only
    ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
    defer cancel()
    if err := srv.Shutdown(ctx); err != nil {
        // Error from closing listeners, or context timeout:
        log.Error().Err(err).Msg("HTTP server shutdown error")
    }

    tasksCancelled := make(chan struct{})
    go func() {
        // Mutex will wait for all tasks to unlock their read locks
        // and then will block new tasks to start
        mu.Lock()
        // Notify main thread that all tasks are stopped
        close(tasksCancelled)
        mu.Unlock()
    }()

    // Waiting until tasks cancelled or time.After event came
    select {
    case <-tasksCancelled:
    case <-timeoutCh:
        log.Info().Msg("Forcing shutdown")
        // Cancel tasks
        mainCancel()
        // Wait until all done
        <-tasksCancelled
    }

    log.Info().Msg("Server has been stopped")
}

func TaskHandler(w http.ResponseWriter, r *http.Request) {
    l := hlog.FromRequest(r)
    l.Info().Msgf("%s %s", r.Method, r.URL.RequestURI())
    t := r.FormValue("time")
    if t == "" {
        w.WriteHeader(http.StatusBadRequest)
        return
    }
    duration, err := strconv.Atoi(t)
    if err != nil {
        w.WriteHeader(http.StatusBadRequest)
        return
    }
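    // Take the read lock before starting the goroutine, so that a task
    // started just before shutdown cannot slip past mu.Lock() in main.
    mu.RLock()
    go taskManager(l.WithContext(context.Background()), duration)
    w.Write([]byte("Started"))
}

func taskManager(ctx context.Context, duration int) {
    // Release the read lock taken in TaskHandler when the task ends.
    defer mu.RUnlock()
    ctx, cancel := context.WithTimeout(ctx, time.Minute)
    defer cancel()
    ctx = mergeContexts(mainCtx, ctx)
    task(ctx, duration)
}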

func task(ctx context.Context, duration int) {
    l := log.Ctx(ctx)
    l.Info().Msgf("Task %d second(s): STARTED", duration)
    select {
    case <-ctx.Done():
        l.Info().Msgf("Task %d second(s): CANCELED", duration)
    case <-time.After(time.Duration(duration) * time.Second):
        l.Info().Msgf("Task %d second(s): FINISHED", duration)
    }
}

merge.go

package main

import (
    "context"
    "sync"
    "time"
)

type mergedContext struct {
    mu      sync.Mutex
    mainCtx context.Context
    ctx     context.Context
    done    chan struct{}
    err     error
}

func mergeContexts(mainCtx, ctx context.Context) context.Context {
    c := &mergedContext{mainCtx: mainCtx, ctx: ctx, done: make(chan struct{})}
    go c.run()
    return c
}

func (c *mergedContext) Done() <-chan struct{} {
    return c.done
}

func (c *mergedContext) Err() error {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.err
}

func (c *mergedContext) Deadline() (deadline time.Time, ok bool) {
    d1, ok1 := c.ctx.Deadline()
    d2, ok2 := c.mainCtx.Deadline()
    switch {
    case ok1 && ok2:
        // Both contexts have deadlines: the earlier one wins.
        if d1.Before(d2) {
            return d1, true
        }
        return d2, true
    case ok1:
        return d1, true
    case ok2:
        return d2, true
    }
    return time.Time{}, false
}

func (c *mergedContext) Value(key interface{}) interface{} {
    return c.ctx.Value(key)
}

func (c *mergedContext) run() {
    var doneCtx context.Context
    select {
    case <-c.mainCtx.Done():
        doneCtx = c.mainCtx
    case <-c.ctx.Done():
        doneCtx = c.ctx
    case <-c.done:
        return
    }

    c.mu.Lock()
    if c.err != nil {
        c.mu.Unlock()
        return
    }
    c.err = doneCtx.Err()
    c.mu.Unlock()
    close(c.done)
}

In merge.go you can see all the logic needed to merge two contexts. We combine only the cancellation and deadline logic of both, while values are taken from the second context only. It’s easy to implement any other behavior on top of that. Now canceling either context, or reaching its deadline, cancels the merged context.

In main.go we have some additional code to control cancellation.

  1. The mainCtx global variable (which in a real program should be moved from globals into a task manager object) and the mainCancel function. We use them to control cancellation.
  2. The sync.RWMutex mu, which is used as a three-position switch: free, read-locked by any number of running tasks at once, or fully locked by the shutdown code. mu.Lock() succeeds only after every running task has released its read lock, and while it is pending no new task can take a read lock.
  3. The channels timeoutCh and tasksCancelled are used to wait for the tasks to finish. If the timeout is reached, we cancel all tasks with the mainCancel function and wait for the tasksCancelled channel to be closed.

That is it. Let’s see how it works.

2019-02-22T11:20:02+02:00 INF Listening HTTP on :8080
2019-02-22T11:20:07+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:20:07+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
^C
2019-02-22T11:20:10+02:00 INF Shutting down...
2019-02-22T11:20:20+02:00 INF Forcing shutdown
2019-02-22T11:20:20+02:00 INF Task 30 second(s): CANCELED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:20:20+02:00 INF Server has been stopped

Ok, what about simultaneous tasks?

2019-02-22T11:22:03+02:00 INF Listening HTTP on :8080
2019-02-22T11:22:04+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:04+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF GET /favicon.ico ip=::1 referer=http://localhost:8080/?time=30 url=/favicon.ico user_agent="..."
2019-02-22T11:22:05+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF GET /?time=30 ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:05+02:00 INF Task 30 second(s): STARTED ip=::1 url=/?time=30 user_agent="..."
^C
2019-02-22T11:22:07+02:00 INF Shutting down...
2019-02-22T11:22:17+02:00 INF Forcing shutdown
2019-02-22T11:22:17+02:00 INF Task 30 second(s): CANCELED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:17+02:00 INF Task 30 second(s): CANCELED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:17+02:00 INF Task 30 second(s): CANCELED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:17+02:00 INF Task 30 second(s): CANCELED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:17+02:00 INF Task 30 second(s): CANCELED ip=::1 url=/?time=30 user_agent="..."
2019-02-22T11:22:17+02:00 INF Server has been stopped

Works like a charm.

As we can see, context.Context is a very powerful abstraction. We can easily change its behavior to suit our needs by reimplementing a few of its methods, and by doing so we keep a standard and convenient way to control or cancel our background tasks.

Cheers!