-
Notifications
You must be signed in to change notification settings - Fork 0
/
incept.go
276 lines (224 loc) · 6.67 KB
/
incept.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
//go:build darwin || freebsd || linux
// +build darwin freebsd linux
package incept
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"syscall"
"time"
)
const (
// defaultShutdownGraceTime is the shutdownGraceTime that New assigns before
// functional options run; shutdownChild uses that value as the timeout
// between sending the terminate signal and sending the kill signal
defaultShutdownGraceTime = 15 * time.Second
// envChildMarker is the name of the environment variable that fork adds
// (with the value TRUE) to the environment of a spawned process; IsChild
// reports true when it is set to a non-empty value
envChildMarker = "INCEPT_CHILD"
// binaryBackupFilename is the file name, placed in the directory of the
// binary, that New uses to build the backup path of the running binary
binaryBackupFilename = ".replace.tmp"
)
// Incept denotes an instance of a process reloader
type Incept struct {
// shutdownGraceTime is how long shutdownChild waits for the child to exit
// after sending signalTerm before it sends signalKill
shutdownGraceTime time.Duration
// p is the child process currently tracked by the parent; it is set in
// handleSignal after every successful fork
p *os.Process
// argv0 is the path of the binary as resolved by getBinaryPaths; Update
// replaces the file at this path and Binary reads it
argv0 string
// workingDir is the working directory resolved by getBinaryPaths in New
workingDir string
// binaryBackupPath is where Update moves the running binary before writing
// the new one; handleSignal removes it once the old child was shut down
binaryBackupPath string
// signalTerm is the signal shutdownChild sends first (New sets SIGTERM)
signalTerm os.Signal
// signalKill is the signal shutdownChild sends when the child did not exit
// cleanly within the grace time (New sets SIGKILL)
signalKill os.Signal
// exitFn is called by handleSignal with the exit status of a terminated
// child (New sets os.Exit)
exitFn func(code int)
}
// New instantiates a new Incept instance (it should be called as early as
// possible during program start-up).
//
// When called from the parent / main process (the child marker environment
// variable is unset), New hands over to handleSignal, which forks a child
// process and then supervises it; New returns whatever handleSignal returns.
// When called from the child process, New returns the instance right away and
// the child serves as the replaceable process for any future updates.
//
// The functional options are applied in order on top of the default
// parameters. An error is returned if the binary path or the working directory
// cannot be determined.
func New(options ...func(*Incept)) (*Incept, error) {
	// Resolve the binary path and the working directory first
	binaryPath, workDir, err := getBinaryPaths()
	if err != nil {
		return nil, err
	}

	// Populate a fresh instance with the default parameters
	inst := new(Incept)
	inst.shutdownGraceTime = defaultShutdownGraceTime
	inst.argv0 = binaryPath
	inst.workingDir = workDir
	inst.binaryBackupPath = filepath.Join(filepath.Dir(binaryPath), binaryBackupFilename)
	inst.signalTerm = syscall.SIGTERM
	inst.signalKill = syscall.SIGKILL
	inst.exitFn = os.Exit

	// Let the caller override the defaults
	for idx := range options {
		options[idx](inst)
	}

	// The child simply continues; the parent forks a child and supervises it
	if inst.IsChild() {
		return inst, nil
	}
	return inst.handleSignal()
}
// IsChild reports whether this is a child process, i.e. whether the child
// marker environment variable is set to a non-empty value.
func (i *Incept) IsChild() bool {
	marker := os.Getenv(envChildMarker)
	return len(marker) > 0
}
// Update performs the update, given the new binary to load and an optional
// list of functions to execute prior to the replacement (e.g. a web server
// shutdown).
//
// The currently running binary is first moved to the backup path and the new
// binary is then written to the original path with the permission bits of the
// old one. If writing the new binary fails, the backup is moved back so that a
// binary remains available at the original path.
//
// On success the shutdown functions and the restart request to the parent
// process are run in a separate goroutine. A nil return therefore only means
// that the binary on disk was replaced; errors from that goroutine are printed
// to standard output and not returned.
func (i *Incept) Update(binary []byte, shutdownFn ...func() error) error {
	// Perform a stat() call to extract the file permissions of the current
	// binary for transfer to the new one
	stat, err := os.Stat(i.argv0)
	if err != nil {
		return err
	}
	perm := stat.Mode().Perm()

	// Rename the currently running binary to a temporary file
	if err := os.Rename(i.argv0, i.binaryBackupPath); err != nil {
		return fmt.Errorf("failed to rename existing binary: %w", err)
	}

	// Write the new binary. WriteFile applies the process umask when it creates
	// the file, so the mode is set explicitly afterwards to really carry over
	// the permission bits of the old binary
	writeErr := os.WriteFile(i.argv0, binary, perm)
	if writeErr == nil {
		writeErr = os.Chmod(i.argv0, perm)
	}
	if writeErr != nil {
		// Roll back: move the previous binary into place again (this also
		// replaces a partially written file) instead of leaving the original
		// path empty or corrupt
		if restoreErr := os.Rename(i.binaryBackupPath, i.argv0); restoreErr != nil {
			return fmt.Errorf("failed to write new binary: %w (restoring the previous binary failed: %v)", writeErr, restoreErr)
		}
		return fmt.Errorf("failed to write new binary: %w", writeErr)
	}

	// Ensure the update is performed after returning from this method
	// TODO: Either handle errors properly somehow or implement a zero-downtime way of replacing
	// the binary in case there is a webserver (otherwise the in-line execution here would kill
	// any existing connection)
	defer func() {
		// TODO: This is probably still racy and only works because the return -> potential server handler
		// is much faster than the execution of the shutdownFns. Maybe there are better ways
		go func() {
			if err := i.triggerUpdate(shutdownFn...); err != nil {
				fmt.Println("got error shutting down:", err)
			}
		}()
	}()

	return nil
}
// Binary returns the contents of the binary file found at the path this
// instance was started from.
func (i *Incept) Binary() ([]byte, error) {
	contents, err := os.ReadFile(i.argv0)
	return contents, err
}
/////////////////////////////////////////////////////////////
// handleSignal is the supervisor loop of the parent process. It forks the
// first child and then reacts to two signals:
//
//   - SIGCHLD: the child terminated (or was terminated). The exit status of the
//     child is passed to exitFn and, should exitFn return, the instance is
//     returned.
//   - SIGUSR2: the child asks to be restarted. A new child is forked, the old
//     one is shut down, the backup of the old binary is removed and the new
//     child becomes the tracked one.
//
// Any failure along the way is returned as an error.
func (i *Incept) handleSignal() (*Incept, error) {
	signalChild := make(chan os.Signal, 1)
	defer close(signalChild)

	signal.Notify(signalChild, syscall.SIGUSR2, syscall.SIGCHLD)
	// Deferred calls run in reverse order, so notifications are stopped before
	// the channel is closed
	defer signal.Stop(signalChild)

	p, err := i.fork()
	if err != nil {
		return nil, err
	}
	i.p = p

	for {
		// Process incoming signal
		// TODO: Make OS specific and handle in extra method
		s := <-signalChild
		switch s {

		// If SIGCHLD was received, the child terminated (or was terminated). Propagate
		// child return value and exit
		case syscall.SIGCHLD:
			var ws syscall.WaitStatus
			pid, err := syscall.Wait4(i.p.Pid, &ws, syscall.WNOHANG, nil)
			if err != nil {
				return nil, fmt.Errorf("error handling SIGCHLD: %w", err)
			}
			// With WNOHANG a pid of 0 means the tracked child has not terminated
			// (e.g. the SIGCHLD was caused by a stop / continue of the child). The
			// wait status is not filled in then, so keep supervising instead of
			// exiting with a bogus status of 0
			if pid == 0 {
				continue
			}
			i.exitFn(ws.ExitStatus())
			return i, nil

		// If SIGUSR2 was received, the child indicates that it wants to be restarted
		// Fork a new child and terminate the old one
		case syscall.SIGUSR2:
			next, err := i.fork()
			if err != nil {
				return nil, fmt.Errorf("error forking after SIGUSR2: %w", err)
			}
			if err := i.shutdownChild(); err != nil {
				return nil, fmt.Errorf("error shutting down child after SIGUSR2: %w", err)
			}

			// Consume the next notification, expected to be the SIGCHLD caused by
			// the shutdown of the old child, so it is not mistaken for the new one
			<-signalChild

			// Remove the old binary
			if err := os.RemoveAll(i.binaryBackupPath); err != nil {
				return nil, fmt.Errorf("error removing backup binary after SIGUSR2: %w", err)
			}
			i.p = next
		}
	}
}
// triggerUpdate runs the given shutdown handlers in order, stopping at and
// returning the first error, and afterwards sends SIGUSR2 to the parent
// process to indicate that this child is ready to be replaced.
func (i *Incept) triggerUpdate(shutdownFn ...func() error) error {
	// Run the shutdown handlers, if there are any
	for idx := 0; idx < len(shutdownFn); idx++ {
		if err := shutdownFn[idx](); err != nil {
			return err
		}
	}

	// Tell the parent / master process that the replacement may happen now
	parentPID := os.Getppid()
	return syscall.Kill(parentPID, syscall.SIGUSR2)
}
// fork starts a new process from the binary path resolved at call time, using
// the arguments of the current process, the current working directory, the
// file descriptors from getFDs and the current environment extended by the
// child marker variable. It returns the handle of the started process.
func (i *Incept) fork() (*os.Process, error) {
	binaryPath, workDir, err := getBinaryPaths()
	if err != nil {
		return nil, err
	}

	attr := &os.ProcAttr{
		Dir: workDir,
		// The marker lets the new process recognize itself as the child
		Env:   append(os.Environ(), fmt.Sprintf("%s=TRUE", envChildMarker)),
		Files: getFDs(),
		Sys:   &sysProcAttr,
	}

	child, err := os.StartProcess(binaryPath, os.Args, attr)
	if err != nil {
		return nil, err
	}
	return child, nil
}
// shutdownChild stops the currently tracked child process. It sends the
// terminate signal and waits up to the shutdown grace time for the child to
// exit. If waiting fails or the grace time elapses, the kill signal is sent
// and the result of sending it is returned. A nil return means the child
// exited within the grace time or the kill signal was delivered.
func (i *Incept) shutdownChild() error {
	// Pin the process handle so the waiter below keeps referring to this child
	// even if the tracked process is replaced later on
	proc := i.p

	if err := proc.Signal(i.signalTerm); err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), i.shutdownGraceTime)
	defer cancel()

	// Buffered, so the waiter can always hand over its result and finish, even
	// when nobody is receiving anymore after a timeout (an unbuffered channel
	// would block that goroutine forever)
	result := make(chan error, 1)
	go func() {
		_, err := proc.Wait()
		result <- err
	}()

	select {
	case err := <-result:
		if err == nil {
			return nil
		}
		// Waiting failed: fall through to the kill signal
	case <-ctx.Done():
		// Grace time elapsed: fall through to the kill signal
	}

	return proc.Signal(i.signalKill)
}
// VerifyBinaryChecksum provides a simple helper to cross-check a provided
// binary against a known (or side-channeled) checksum.
//
// data is the binary to verify and expectedChecksum is the hex encoded SHA-256
// digest it is expected to have. The hex digits may be lower, upper or mixed
// case. It returns nil if the digests match and an error naming both checksums
// otherwise (this includes an expectedChecksum that is not valid hex).
func VerifyBinaryChecksum(data []byte, expectedChecksum []byte) error {
	sum := sha256.Sum256(data)
	checksum := hex.EncodeToString(sum[:])

	// Normalize the expected value via a decode / encode round trip, so that
	// the comparison does not depend on the case of the hex digits
	expected, err := hex.DecodeString(string(expectedChecksum))
	if err != nil || hex.EncodeToString(expected) != checksum {
		return fmt.Errorf("mismatching binary checksums: expected `%s`, got `%s`", expectedChecksum, checksum)
	}

	return nil
}
// getFDs returns the files handed to a forked process: standard input,
// standard output and standard error, in that order.
func getFDs() []*os.File {
	fds := make([]*os.File, 0, 3)
	fds = append(fds, os.Stdin, os.Stdout, os.Stderr)
	return fds
}
// getBinaryPaths resolves the path of the running binary by looking up the
// first program argument, checks that a file exists at that path and
// determines the current working directory. The first error encountered is
// returned together with whatever was resolved up to that point.
func getBinaryPaths() (argv0 string, wd string, err error) {
	if argv0, err = exec.LookPath(os.Args[0]); err != nil {
		return argv0, "", err
	}

	// Make sure the resolved path actually points to an existing file
	if _, err = os.Stat(argv0); err != nil {
		return argv0, "", err
	}

	wd, err = os.Getwd()
	return argv0, wd, err
}