diff --git a/go.mod b/go.mod index b0c286c17b8..fe0c5a5e293 100644 --- a/go.mod +++ b/go.mod @@ -179,6 +179,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jaegertracing/jaeger-idl v0.6.0 // indirect github.com/jcmturner/aescts/v2 v2.0.0 // indirect github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index 29cbb301ed5..1a4b2b9e7a8 100644 --- a/go.sum +++ b/go.sum @@ -658,7 +658,6 @@ github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8 github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= diff --git a/retriever/postgresqlretriever/postgres.go b/retriever/postgresqlretriever/postgres.go new file mode 100644 index 00000000000..b2c2daa4ecf --- /dev/null +++ b/retriever/postgresqlretriever/postgres.go @@ -0,0 +1,64 @@ +package postgresqlretriever + +import ( + "context" + "sync" + + "github.com/jackc/pgx/v5/pgxpool" +) + +type poolEntry struct { + pool *pgxpool.Pool + refCount int +} + +var ( + mu sync.Mutex + poolMap = make(map[string]*poolEntry) +) + +// GetPool returns a pool for a given URI, creating it if needed. +func GetPool(ctx context.Context, uri string) (*pgxpool.Pool, error) { + mu.Lock() + defer mu.Unlock() + + // If already exists, bump refCount + if entry, ok := poolMap[uri]; ok { + entry.refCount++ + return entry.pool, nil + } + + // Create a new pool + pool, err := pgxpool.New(ctx, uri) + if err != nil { + return nil, err + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, err + } + + poolMap[uri] = &poolEntry{ + pool: pool, + refCount: 1, + } + + return pool, nil +} + +// ReleasePool decreases refCount and closes/removes when it hits zero. +func ReleasePool(ctx context.Context, uri string) { + mu.Lock() + defer mu.Unlock() + + entry, ok := poolMap[uri] + if !ok { + return // nothing to do + } + + entry.refCount-- + if entry.refCount <= 0 { + entry.pool.Close() + delete(poolMap, uri) + } +} diff --git a/retriever/postgresqlretriever/postgres_test.go b/retriever/postgresqlretriever/postgres_test.go new file mode 100644 index 00000000000..2287b5e2ad6 --- /dev/null +++ b/retriever/postgresqlretriever/postgres_test.go @@ -0,0 +1,85 @@ +package postgresqlretriever_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + rr "github.com/thomaspoignant/go-feature-flag/retriever/postgresqlretriever" +) + +func TestGetPool_MultipleURIsAndReuse(t *testing.T) { + ctx := context.Background() + + // Setup Container + req := testcontainers.ContainerRequest{ + Image: "postgres:15-alpine", + ExposedPorts: []string{"5432/tcp"}, + Env: map[string]string{"POSTGRES_PASSWORD": "password"}, + // This waits until the log says the system is ready, preventing connection errors + WaitingFor: wait.ForAll( + wait.ForLog("database system is ready to accept connections").WithStartupTimeout(10*time.Second), + wait.ForListeningPort("5432/tcp").WithStartupTimeout(10*time.Second), + ), + } + + pg, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + assert.NoError(t, err) + defer pg.Terminate(ctx) + + endpoint, err := pg.Endpoint(ctx, "") + assert.NoError(t, err) + baseURI := "postgres://postgres:password@" + endpoint + "/postgres?sslmode=disable" + + // Different URIs + uri1 := baseURI + "&application_name=A" + uri2 := baseURI + "&application_name=B" + + // First Calls (RefCount = 1) + pool1a, err := rr.GetPool(ctx, uri1) + assert.NoError(t, err) + assert.NotNil(t, pool1a) + + pool2a, err := rr.GetPool(ctx, uri2) + assert.NoError(t, err) + assert.NotNil(t, pool2a) + + // Verify distinct pools + assert.NotEqual(t, pool1a, pool2a) + + // Reuse Logic (RefCount = 2 for URI1) + pool1b, err := rr.GetPool(ctx, uri1) + assert.NoError(t, err) + assert.Equal(t, pool1a, pool1b, "Should return exact same pool instance") + + // Release Logic + // URI1 RefCount: 2 -> 1 + rr.ReleasePool(ctx, uri1) + + // URI1 RefCount: 1 -> 0 (Closed & Removed) + rr.ReleasePool(ctx, uri1) + + // Recreation Logic + // URI1 should now create a NEW pool + pool1c, err := rr.GetPool(ctx, uri1) + assert.NoError(t, err) + assert.NotEqual(t, pool1a, pool1c, "Should be a new pool instance after full release") + + // Cleanup new pool + rr.ReleasePool(ctx, uri1) + + // URI2 Cleanup verification + rr.ReleasePool(ctx, uri2) // RefCount -> 0 + + pool2b, err := rr.GetPool(ctx, uri2) + assert.NoError(t, err) + assert.NotEqual(t, pool2a, pool2b, "URI2 should be recreated") + + rr.ReleasePool(ctx, uri2) +} diff --git a/retriever/postgresqlretriever/retriever.go b/retriever/postgresqlretriever/retriever.go index 7a8fa5c484a..e31b3747bb4 100644 --- a/retriever/postgresqlretriever/retriever.go +++ b/retriever/postgresqlretriever/retriever.go @@ -7,6 +7,7 @@ import ( "log/slog" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/thomaspoignant/go-feature-flag/retriever" "github.com/thomaspoignant/go-feature-flag/utils" "github.com/thomaspoignant/go-feature-flag/utils/fflog" @@ -26,7 +27,7 @@ type Retriever struct { logger *fflog.FFLogger status retriever.Status columns map[string]string - conn *pgx.Conn + pool *pgxpool.Pool flagset *string } @@ -40,23 +41,19 @@ func (r *Retriever) Init(ctx context.Context, logger *fflog.FFLogger, flagset *s r.columns = r.getColumnNames() r.flagset = flagset - if r.conn == nil { + if r.pool == nil { r.logger.Info("Initializing PostgreSQL retriever") r.logger.Debug("Using columns", "columns", r.columns) - conn, err := pgx.Connect(ctx, r.URI) + pool, err := GetPool(ctx, r.URI) if err != nil { r.status = retriever.RetrieverError return err } - if err := conn.Ping(ctx); err != nil { - r.status = retriever.RetrieverError - return err - } - - r.conn = conn + r.pool = pool } + r.status = retriever.RetrieverReady return nil } @@ -71,37 +68,39 @@ func (r *Retriever) Status() retriever.Status { // Shutdown closes the database connection. func (r *Retriever) Shutdown(ctx context.Context) error { - if r.conn == nil { - return nil + if r.pool != nil { + ReleasePool(ctx, r.URI) + r.pool = nil } - return r.conn.Close(ctx) + return nil } // Retrieve fetches flag configuration from PostgreSQL. func (r *Retriever) Retrieve(ctx context.Context) ([]byte, error) { - if r.conn == nil { - return nil, fmt.Errorf("database connection is not initialized") + if r.pool == nil { + return nil, fmt.Errorf("database connection pool is not initialized") } // Build the query using the configured table and column names query := r.buildQuery() // Build the arguments for the query - args := []interface{}{} + args := []any{} if r.getFlagset() != "" { - args = []interface{}{r.getFlagset()} + // If a flagset is defined, it becomes the first ($1) argument. + args = []any{r.getFlagset()} } r.logger.Debug("Executing PostgreSQL query", slog.String("query", query), slog.Any("args", args)) - rows, err := r.conn.Query(ctx, query, args...) + rows, err := r.pool.Query(ctx, query, args...) if err != nil { return nil, fmt.Errorf("failed to execute query: %w", err) } defer rows.Close() // Map to store flag configurations with flag_name as key - flagConfigs := make(map[string]interface{}) + flagConfigs := make(map[string]any) for rows.Next() { var flagName string @@ -121,6 +120,7 @@ func (r *Retriever) Retrieve(ctx context.Context) ([]byte, error) { flagConfigs[flagName] = config } + // Check for any errors that occurred during row iteration if err := rows.Err(); err != nil { return nil, fmt.Errorf("rows iteration error: %w", err) } diff --git a/retriever/postgresqlretriever/retriever_test.go b/retriever/postgresqlretriever/retriever_test.go index dbe40a157e5..baabede471c 100644 --- a/retriever/postgresqlretriever/retriever_test.go +++ b/retriever/postgresqlretriever/retriever_test.go @@ -378,7 +378,7 @@ func TestRetrieverDataHandlingErrors(t *testing.T) { assert.Contains(t, err.Error(), "failed to execute query") } }) - } +} // TestRetrieverEdgeCases tests additional edge cases and boundary conditions func TestRetrieverEdgeCases(t *testing.T) {