diff --git a/Dockerfile b/Dockerfile index 1cebdc7..a6371ec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,11 +10,5 @@ FROM alpine:3.12 LABEL maintainer="maintainers@kolaente.dev" COPY --from=build-env /go/src/kolaente.dev/konrad/docker-db-backup / -COPY --from=postgres:14-alpine /usr/local/bin/pg_dump /usr/local/bin/pg_dump14 -COPY --from=postgres:13-alpine /usr/local/bin/pg_dump /usr/local/bin/pg_dump13 -COPY --from=postgres:12-alpine /usr/local/bin/pg_dump /usr/local/bin/pg_dump12 -COPY --from=postgres:11-alpine /usr/local/bin/pg_dump /usr/local/bin/pg_dump11 -COPY --from=postgres:10-alpine /usr/local/bin/pg_dump /usr/local/bin/pg_dump10 -COPY --from=mariadb:10 /usr/bin/mysqldump /usr/local/bin/mysqldump CMD ["/docker-db-backup"] diff --git a/dump.go b/dump.go index 4b0512c..f7e4f70 100644 --- a/dump.go +++ b/dump.go @@ -2,11 +2,12 @@ package main import ( "github.com/docker/docker/api/types" + "github.com/docker/docker/client" "strings" ) type Dumper interface { - Dump() error + Dump(c *client.Client) error } func NewDumperFromContainer(container *types.ContainerJSON) Dumper { @@ -22,12 +23,12 @@ func NewDumperFromContainer(container *types.ContainerJSON) Dumper { return nil } -func dumpAllDatabases() error { +func dumpAllDatabases(c *client.Client) error { lock.Lock() defer lock.Unlock() for _, dumper := range store { - err := dumper.Dump() + err := dumper.Dump(c) if err != nil { return err } diff --git a/dump_mysql.go b/dump_mysql.go index 0db5984..c88b387 100644 --- a/dump_mysql.go +++ b/dump_mysql.go @@ -3,6 +3,7 @@ package main import ( "fmt" "github.com/docker/docker/api/types" + "github.com/docker/docker/client" ) type MysqlDumper struct { @@ -49,10 +50,10 @@ func (m *MysqlDumper) buildDumpArgs() []string { return append(args, "--port", port, "-h", host, db) } -func (m *MysqlDumper) Dump() error { +func (m *MysqlDumper) Dump(c *client.Client) error { fmt.Printf("Dumping mysql database from container %s...\n", m.Container.Name) args := m.buildDumpArgs() - return runAndSaveCommand(getDumpFilename(m.Container.Name), "mysqldump", args...) + return runAndSaveCommandInContainer(getDumpFilename(m.Container.Name), c, m.Container, "mysqldump", args...) } diff --git a/dump_postgres.go b/dump_postgres.go index 7f32dae..a5a0774 100644 --- a/dump_postgres.go +++ b/dump_postgres.go @@ -3,7 +3,7 @@ package main import ( "fmt" "github.com/docker/docker/api/types" - "strings" + "github.com/docker/docker/client" ) type PostgresDumper struct { @@ -44,23 +44,10 @@ func (d *PostgresDumper) buildConnStr() string { return fmt.Sprintf("postgresql://%s:%s@%s:%s/%s", user, pw, host, port, db) } -func findPgVersion(env []string) string { - for _, s := range env { - if strings.HasPrefix(s, "PG_MAJOR=") { - return strings.TrimPrefix(s, "PG_MAJOR=") - } - } - - return "" -} - -func (d *PostgresDumper) Dump() error { +func (d *PostgresDumper) Dump(c *client.Client) error { fmt.Printf("Dumping postgres database from container %s...\n", d.Container.Name) connStr := d.buildConnStr() - // The postgres version must match the one the db server is running - pgVersion := findPgVersion(d.Container.Config.Env) - - return runAndSaveCommand(getDumpFilename(d.Container.Name), "pg_dump"+pgVersion, "--dbname", connStr) + return runAndSaveCommandInContainer(getDumpFilename(d.Container.Name), c, d.Container, "pg_dump", "--dbname", connStr) } diff --git a/dump_postgres_test.go b/dump_postgres_test.go index aff2f65..8e31a7f 100644 --- a/dump_postgres_test.go +++ b/dump_postgres_test.go @@ -115,28 +115,3 @@ func TestPostgresDumper_buildConnStr(t *testing.T) { }) } } - -func TestFindPGVersionFromEnv(t *testing.T) { - t.Run("no PG_MAJOR", func(t *testing.T) { - pgVersion := findPgVersion([]string{}) - if pgVersion != "" { - t.Errorf("Version is not empty") - } - }) - t.Run("pg 14", func(t *testing.T) { - pgVersion := findPgVersion([]string{ - "POSTGRES_PASSWORD=test", - "POSTGRES_USER=user", - "POSTGRES_DB=test", - "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/lib/postgresql/14/bin", - "GOSU_VERSION=1.14", - "LANG=en_US.utf8", - "PG_MAJOR=14", - "PG_VERSION=14.1-1.pgdg110+1", - "PGDATA=/var/lib/postgresql/data", - }) - if pgVersion != "14" { - t.Errorf("Version is not 14") - } - }) -} diff --git a/main.go b/main.go index 25a227d..6769237 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,7 @@ func main() { storeContainers(c, containers) - err = dumpAllDatabases() + err = dumpAllDatabases(c) if err != nil { // TODO: Only log errors while dumping dbs log.Fatalf("Could not dump databases: %s", err) diff --git a/save.go b/save.go index b1b3484..216b7d1 100644 --- a/save.go +++ b/save.go @@ -2,44 +2,71 @@ package main import ( "bytes" + "context" "fmt" "io" "os" - "os/exec" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" ) -func runAndSaveCommand(filename, command string, args ...string) error { - c := exec.Command(command, args...) - - //fmt.Printf("Running %s\n\n", c.String()) - +func runAndSaveCommandInContainer(filename string, c *client.Client, container *types.ContainerJSON, command string, args ...string) error { f, err := os.Create(filename) if err != nil { return err } defer f.Close() - stdout, err := c.StdoutPipe() + ctx := context.Background() + + config := types.ExecConfig{ + AttachStderr: true, + AttachStdout: true, + Cmd: append([]string{command}, args...), + } + + r, err := c.ContainerExecCreate(ctx, container.ID, config) if err != nil { return err } - var stderr bytes.Buffer - c.Stderr = &stderr - - err = c.Start() + resp, err := c.ContainerExecAttach(ctx, r.ID, types.ExecStartCheck{}) if err != nil { return err } + defer resp.Close() - _, err = io.Copy(f, stdout) + // read the output + var outBuf, errBuf bytes.Buffer + outputDone := make(chan error) + + go func() { + // StdCopy demultiplexes the stream into two buffers + _, err = stdcopy.StdCopy(&outBuf, &errBuf, resp.Reader) + outputDone <- err + }() + + select { + case err := <-outputDone: + if err != nil { + return err + } + break + + case <-ctx.Done(): + return ctx.Err() + } + + _, err = c.ContainerExecInspect(ctx, r.ID) if err != nil { + fmt.Printf(errBuf.String()) return err } - err = c.Wait() + _, err = io.Copy(f, &outBuf) if err != nil { - fmt.Printf(stderr.String()) return err }