Documentation

Arlandaren edited this page Feb 26, 2025 · 3 revisions

Documentation for PgxWrappy

This documentation describes PgxWrappy, a PostgreSQL database wrapper written in Go on top of pgx v5. The wrapper simplifies database interactions with helper methods that execute queries, map results directly into Go structs, and run transactions over a pgxpool connection pool.


Overview

The wrapper is structured into multiple Go files within the postgres package:

  1. Wrapper and DB interface (wrapper.go)
  2. Utility functions for struct mapping (utils.go)
  3. Transaction wrapper (tx_wrapper.go)

1. Wrapper and DB Interface (wrapper.go)

Package Import

package postgres

import (
    "context"
    "errors"
    "reflect"

    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgconn"
    "github.com/jackc/pgx/v5/pgxpool"
)

DB Interface

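The DB interface abstracts the three basic database operations. Both Wrapper and TxWrapper provide these methods, so code written against DB can run with or without a transaction and can be given a test double in unit tests.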

type DB interface {
    QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
    Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
    Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error)
}

Wrapper Struct

The Wrapper struct contains an instance of a connection pool to the PostgreSQL database.

type Wrapper struct {
    pool *pgxpool.Pool
}

Constructor

Creates a new Wrapper instance with the provided connection pool.

func NewWrapper(pool *pgxpool.Pool) *Wrapper {
    return &Wrapper{pool: pool}
}

Database Operation Methods

QueryRow

Executes a query that is expected to return a single row.

func (w *Wrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
    return w.pool.QueryRow(ctx, sql, args...)
}

Query

Executes a query that returns multiple rows.

func (w *Wrapper) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
    return w.pool.Query(ctx, sql, args...)
}

Exec

Executes a command that doesn't return rows (e.g., INSERT, UPDATE, DELETE).

func (w *Wrapper) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) {
    return w.pool.Exec(ctx, sql, args...)
}

Utility Methods for Struct Mapping

Get

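Executes a query that returns a single row and scans it into the provided struct. The scan is positional: the destination pointers are built from the struct's fields in declaration order, not from the result's column names, so the query must return its columns in that same order.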

func (w *Wrapper) Get(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
    // Ensure dest is a pointer to a struct
    destVal := reflect.ValueOf(dest)
    if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Struct {
        return errors.New("dest must be a pointer to a struct")
    }

    // Obtain expected column names from the struct
    columns, err := GetColumnNames(dest)
    if err != nil {
        return err
    }

    // Get pointers to the struct fields
    fields, err := StructFieldsPointers(dest, columns)
    if err != nil {
        return err
    }

    // Execute the query
    row := w.pool.QueryRow(ctx, sqlStr, args...)

    // Scan the row into the struct fields
    if err := row.Scan(fields...); err != nil {
        return err
    }

    return nil
}

Select

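Executes a query that returns multiple rows and scans them into a slice of structs or of pointers to structs. Columns are matched to fields by name, so every column in the result must correspond to a struct field.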

func (w *Wrapper) Select(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
    // Execute the query
    rows, err := w.pool.Query(ctx, sqlStr, args...)
    if err != nil {
        return err
    }
    defer rows.Close()

    // Ensure dest is a pointer to a slice
    destVal := reflect.ValueOf(dest)
    if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Slice {
        return errors.New("dest must be a pointer to a slice")
    }

    sliceVal := destVal.Elem()
    elemType := sliceVal.Type().Elem()

    // Determine if the slice is of structs or pointers to structs
    ptrToStruct := false
    if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
        ptrToStruct = true
        elemType = elemType.Elem()
    } else if elemType.Kind() != reflect.Struct {
        return errors.New("slice elements must be structs or pointers to structs")
    }

    // Get column names from the query result
    fieldDescriptions := rows.FieldDescriptions()
    columns := make([]string, len(fieldDescriptions))
    for i, fd := range fieldDescriptions {
        columns[i] = string(fd.Name)
    }

    // Iterate over each row
    for rows.Next() {
        elemPtr := reflect.New(elemType)

        // Get pointers to struct fields based on columns
        fields, err := StructFieldsPointers(elemPtr.Interface(), columns)
        if err != nil {
            return err
        }

        // Scan the row into the struct fields
        if err := rows.Scan(fields...); err != nil {
            return err
        }

        // Append the struct to the slice
        if ptrToStruct {
            sliceVal.Set(reflect.Append(sliceVal, elemPtr))
        } else {
            sliceVal.Set(reflect.Append(sliceVal, elemPtr.Elem()))
        }
    }

    // Check for errors during iteration
    if err := rows.Err(); err != nil {
        return err
    }

    return nil
}
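
The slice handling in Select can be tried without a database. The following is a minimal sketch of that part only, using the standard reflect package: appendRow, setAlice, and this User type are hypothetical names for illustration, and the fill callback stands in for rows.Scan.

```go
package main

import (
    "errors"
    "fmt"
    "reflect"
)

// User is a hypothetical row type used only in this sketch.
type User struct {
    ID   int
    Name string
}

// appendRow allocates one new element for the slice behind dest, lets fill
// populate it, and appends it. Like Select, it accepts *[]T and *[]*T.
func appendRow(dest interface{}, fill func(elemPtr interface{})) error {
    destVal := reflect.ValueOf(dest)
    if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Slice {
        return errors.New("dest must be a pointer to a slice")
    }
    sliceVal := destVal.Elem()
    elemType := sliceVal.Type().Elem()

    ptrToStruct := false
    if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
        ptrToStruct = true
        elemType = elemType.Elem()
    } else if elemType.Kind() != reflect.Struct {
        return errors.New("slice elements must be structs or pointers to structs")
    }

    elemPtr := reflect.New(elemType) // a *T pointing at a zero T
    fill(elemPtr.Interface())

    if ptrToStruct {
        sliceVal.Set(reflect.Append(sliceVal, elemPtr)) // append the *T itself
    } else {
        sliceVal.Set(reflect.Append(sliceVal, elemPtr.Elem())) // append a copy of T
    }
    return nil
}

// setAlice stands in for rows.Scan: it writes fixed values into the new element.
func setAlice(p interface{}) {
    u := p.(*User)
    u.ID, u.Name = 1, "alice"
}

func main() {
    var byValue []User
    var byPointer []*User
    _ = appendRow(&byValue, setAlice)
    _ = appendRow(&byPointer, setAlice)
    fmt.Println(byValue[0].Name, byPointer[0].Name) // prints "alice alice"
}
```

Allocating with reflect.New for both slice kinds keeps a single scan path: the element is always filled through a pointer, and only the final append differs.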

Transaction Methods

Begin

Starts a new transaction.

func (w *Wrapper) Begin(ctx context.Context) (*TxWrapper, error) {
    tx, err := w.pool.Begin(ctx)
    if err != nil {
        return nil, err
    }
    return &TxWrapper{tx: tx}, nil
}

BeginTx

Starts a new transaction with specific options.

func (w *Wrapper) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (*TxWrapper, error) {
    tx, err := w.pool.BeginTx(ctx, txOptions)
    if err != nil {
        return nil, err
    }
    return &TxWrapper{tx: tx}, nil
}

2. Utility Functions for Struct Mapping (utils.go)

These utility functions assist in mapping database query results to Go structs based on struct tags. They use the standard errors, fmt, and reflect packages.

StructFieldsPointers

Creates a slice of pointers to the struct fields corresponding to the given column names.

func StructFieldsPointers(strct interface{}, columns []string) ([]interface{}, error) {
    v := reflect.ValueOf(strct)

    if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
        return nil, errors.New("input must be a pointer to a struct")
    }

    // Map to hold field names and their values
    fieldMap := make(map[string]reflect.Value)
    CollectFields(v.Elem(), "", fieldMap)

    // Create a slice of field pointers corresponding to the columns
    fields := make([]interface{}, len(columns))
    for i, col := range columns {
        fieldVal, ok := fieldMap[col]
        if !ok {
            return nil, fmt.Errorf("no matching struct field found for column %s", col)
        }
        fields[i] = fieldVal.Addr().Interface()
    }

    return fields, nil
}
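
To see why returning pointers in column order is enough for Scan to fill a struct, here is a simplified sketch for flat structs only. fieldPointers and fakeScan are hypothetical stand-ins for StructFieldsPointers and a positional Scan; neither is part of the wrapper or of pgx.

```go
package main

import (
    "fmt"
    "reflect"
)

// User is a hypothetical flat struct with exported fields and db tags.
type User struct {
    ID   int    `db:"id"`
    Name string `db:"name"`
}

// fieldPointers returns one pointer per requested column, in column order.
// Unlike StructFieldsPointers, it does not descend into nested structs.
func fieldPointers(dest interface{}, columns []string) ([]interface{}, error) {
    v := reflect.ValueOf(dest)
    if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
        return nil, fmt.Errorf("input must be a pointer to a struct")
    }
    v = v.Elem()
    t := v.Type()

    // Index the fields by column name (db tag, or field name if untagged).
    byColumn := make(map[string]reflect.Value)
    for i := 0; i < v.NumField(); i++ {
        tag := t.Field(i).Tag.Get("db")
        if tag == "" {
            tag = t.Field(i).Name
        }
        byColumn[tag] = v.Field(i)
    }

    ptrs := make([]interface{}, len(columns))
    for i, col := range columns {
        f, ok := byColumn[col]
        if !ok {
            return nil, fmt.Errorf("no matching struct field found for column %s", col)
        }
        ptrs[i] = f.Addr().Interface()
    }
    return ptrs, nil
}

// fakeScan mimics a positional Scan: the i-th value is written through the
// i-th pointer.
func fakeScan(ptrs []interface{}, values []interface{}) {
    for i, p := range ptrs {
        reflect.ValueOf(p).Elem().Set(reflect.ValueOf(values[i]))
    }
}

func main() {
    var u User
    // The column order differs from the field order; the pointers follow the columns.
    ptrs, err := fieldPointers(&u, []string{"name", "id"})
    if err != nil {
        panic(err)
    }
    fakeScan(ptrs, []interface{}{"bob", 7})
    fmt.Println(u) // prints "{7 bob}"
}
```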

CollectFields

Recursively collects fields from a struct, including embedded structs, and maps them by their column names.

func CollectFields(v reflect.Value, prefix string, fieldMap map[string]reflect.Value) {
    t := v.Type()

    for i := 0; i < v.NumField(); i++ {
        field := t.Field(i)
        fieldValue := v.Field(i)

        if !fieldValue.CanSet() {
            continue
        }

        tag := field.Tag.Get("db")
        if tag == "-" {
            if fieldValue.Kind() == reflect.Struct {
                // A struct tagged "-" is flattened: recurse without adding a prefix
                CollectFields(fieldValue, prefix, fieldMap)
            }
            continue
        }
        if tag == "" {
            tag = field.Name
        }

        var colName string
        if prefix != "" && !field.Anonymous {
            colName = prefix + "_" + tag
        } else {
            colName = tag
        }

        if fieldValue.Kind() == reflect.Struct {
            CollectFields(fieldValue, colName, fieldMap)
        } else {
            fieldMap[colName] = fieldValue
        }
    }
}

GetColumnNames

Returns a list of column names expected in the query result based on the struct's db tags.

func GetColumnNames(dest interface{}) ([]string, error) {
    var columns []string
    destVal := reflect.ValueOf(dest)
    if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Struct {
        return nil, errors.New("dest must be a pointer to a struct")
    }
    CollectColumnNames(destVal.Elem(), "", &columns)
    return columns, nil
}

CollectColumnNames

Recursively collects column names from a struct, including embedded structs, based on db tags.

func CollectColumnNames(v reflect.Value, prefix string, columns *[]string) {
    t := v.Type()
    for i := 0; i < v.NumField(); i++ {
        field := t.Field(i)
        fieldValue := v.Field(i)

        if !fieldValue.CanSet() {
            continue
        }

        tag := field.Tag.Get("db")
        if tag == "-" {
            if fieldValue.Kind() == reflect.Struct {
                CollectColumnNames(fieldValue, prefix, columns)
            }
            continue
        }
        if tag == "" {
            tag = field.Name
        }

        var colName string
        if prefix != "" && !field.Anonymous {
            colName = prefix + "_" + tag
        } else {
            colName = tag
        }

        if fieldValue.Kind() == reflect.Struct {
            CollectColumnNames(fieldValue, colName, columns)
        } else {
            *columns = append(*columns, colName)
        }
    }
}
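
The naming rules are easiest to check by running them. The sketch below copies the logic of CollectColumnNames into a self-contained program; columnNames, columnsOf, FlatUser, and PrefixedUser are hypothetical names chosen for this illustration.

```go
package main

import (
    "fmt"
    "reflect"
)

type Address struct {
    Street string `db:"street"`
    City   string `db:"city"`
}

// FlatUser tags the nested struct with "-", so its columns get no prefix.
type FlatUser struct {
    ID      int     `db:"id"`
    Name    string  `db:"name"`
    Address Address `db:"-"`
}

// PrefixedUser tags the nested struct with "address", which becomes a prefix.
type PrefixedUser struct {
    ID      int     `db:"id"`
    Name    string  `db:"name"`
    Address Address `db:"address"`
}

// columnNames follows the same rules as CollectColumnNames but returns the
// names instead of appending through a pointer.
func columnNames(v reflect.Value, prefix string) []string {
    var cols []string
    t := v.Type()
    for i := 0; i < v.NumField(); i++ {
        field, fv := t.Field(i), v.Field(i)
        if !fv.CanSet() {
            continue // unexported or non-addressable field
        }
        tag := field.Tag.Get("db")
        if tag == "-" {
            if fv.Kind() == reflect.Struct {
                cols = append(cols, columnNames(fv, prefix)...)
            }
            continue
        }
        if tag == "" {
            tag = field.Name
        }
        name := tag
        if prefix != "" && !field.Anonymous {
            name = prefix + "_" + tag
        }
        if fv.Kind() == reflect.Struct {
            cols = append(cols, columnNames(fv, name)...)
        } else {
            cols = append(cols, name)
        }
    }
    return cols
}

// columnsOf takes a pointer to a struct so that its fields are settable.
func columnsOf(dest interface{}) []string {
    return columnNames(reflect.ValueOf(dest).Elem(), "")
}

func main() {
    fmt.Println(columnsOf(&FlatUser{}))     // prints "[id name street city]"
    fmt.Println(columnsOf(&PrefixedUser{})) // prints "[id name address_street address_city]"
}
```

The names come out in field declaration order, depth first. Because Get scans positionally, this is also the order in which a query passed to Get should list its columns.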

3. Transaction Wrapper (tx_wrapper.go)

The TxWrapper struct provides the same methods as the Wrapper but allows operations within a transaction context.

Package Import

package postgres

import (
    "context"
    "errors"
    "reflect"

    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgconn"
)

TxWrapper Struct

Wraps a pgx.Tx to provide transactional operations.

type TxWrapper struct {
    tx pgx.Tx
}

Database Operation Methods within a Transaction

QueryRow

Executes a query that is expected to return a single row within the transaction.

func (tw *TxWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
    return tw.tx.QueryRow(ctx, sql, args...)
}

Query

Executes a query that returns multiple rows within the transaction.

func (tw *TxWrapper) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
    return tw.tx.Query(ctx, sql, args...)
}

Exec

Executes a command that doesn't return rows within the transaction.

func (tw *TxWrapper) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) {
    return tw.tx.Exec(ctx, sql, args...)
}

Utility Methods for Struct Mapping within a Transaction

Get

Executes a query that returns a single row and scans it into the provided struct within the transaction.

func (tw *TxWrapper) Get(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
    // Similar implementation as the Wrapper's Get method
    // ...
}

Select

Executes a query that returns multiple rows and scans them into a slice of structs within the transaction.

func (tw *TxWrapper) Select(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
    // Similar implementation as the Wrapper's Select method
    // ...
}

Transaction Control Methods

Commit

Commits the transaction.

func (tw *TxWrapper) Commit(ctx context.Context) error {
    return tw.tx.Commit(ctx)
}

Rollback

Rolls back the transaction.

func (tw *TxWrapper) Rollback(ctx context.Context) error {
    return tw.tx.Rollback(ctx)
}

Usage Examples

Below are examples of how to utilize the wrapper to interact with the PostgreSQL database.

Connecting to the Database

import (
    "context"
    "log"

    "github.com/jackc/pgx/v5/pgxpool"
    // Also import this wrapper's postgres package, using the path from your module.
)

func main() {
    ctx := context.Background()
    pool, err := pgxpool.New(ctx, "postgresql://user:password@localhost:5432/dbname")
    if err != nil {
        log.Fatal(err)
    }
    defer pool.Close()

    db := postgres.NewWrapper(pool)
    // Use db for database operations
}

Defining a Struct for Mapping

type User struct {
    ID    int    `db:"id"`
    Name  string `db:"name"`
    Email string `db:"email"`
}

Retrieving a Single Record

Using Get

var user User
err := db.Get(ctx, &user, "SELECT id, name, email FROM users WHERE id = $1", 1)
if err != nil {
    // Handle error
}

Retrieving Multiple Records

Using Select

var users []User
err := db.Select(ctx, &users, "SELECT id, name, email FROM users")
if err != nil {
    // Handle error
}

Executing a Command

Using Exec

commandTag, err := db.Exec(ctx, "UPDATE users SET email = $1 WHERE id = $2", "user@example.com", 1)
if err != nil {
    // Handle error
}

if commandTag.RowsAffected() != 1 {
    // Handle case where no rows or multiple rows were updated
}

Using Transactions

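// The function name and parameter types are illustrative. err is a named result
// so that the deferred function sees the final error and can report a failed Commit.
func transfer(ctx context.Context, db *postgres.Wrapper, amount, fromAccountID, toAccountID int) (err error) {
    tx, err := db.Begin(ctx)
    if err != nil {
        return err
    }

    defer func() {
        if err != nil {
            _ = tx.Rollback(ctx)
        } else {
            err = tx.Commit(ctx)
        }
    }()

    // Perform transactional operations using tx.
    // Exec returns a command tag and an error; only the error is needed here.
    _, err = tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", amount, fromAccountID)
    if err != nil {
        return err
    }

    _, err = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", amount, toAccountID)
    return err
}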
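
The deferred commit-or-rollback pattern depends on err being a named result of the enclosing function. The sketch below isolates that mechanism without a database: fakeTx, runInTx, and errWork are hypothetical, and the fake methods drop the context argument that the real Commit and Rollback take.

```go
package main

import (
    "errors"
    "fmt"
)

// errWork is a hypothetical error used to simulate a failed operation.
var errWork = errors.New("work failed")

// fakeTx is a hypothetical stand-in for TxWrapper. It only records which
// control method was called and never talks to a database.
type fakeTx struct {
    committed  bool
    rolledBack bool
}

func (t *fakeTx) Commit() error   { t.committed = true; return nil }
func (t *fakeTx) Rollback() error { t.rolledBack = true; return nil }

// runInTx runs work, then commits on success or rolls back on failure.
// Because err is a named result, "return work()" assigns it before the
// deferred function runs, and the deferred function can overwrite it with
// the result of Commit.
func runInTx(tx *fakeTx, work func() error) (err error) {
    defer func() {
        if err != nil {
            _ = tx.Rollback()
            return
        }
        err = tx.Commit()
    }()
    return work()
}

func main() {
    ok, failed := &fakeTx{}, &fakeTx{}
    _ = runInTx(ok, func() error { return nil })
    err := runInTx(failed, func() error { return errWork })
    fmt.Println(ok.committed, failed.rolledBack, err) // prints "true true work failed"
}
```

With an ordinary local err instead of a named result, the rollback decision would still work, but an error returned by Commit would be lost.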

Notes on Struct Tags and Nested Structs

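  • Struct Tags: The db struct tag maps a struct field to a database column. If the tag is omitted, the Go field name is used as the column name, unchanged.
  • Ignoring Fields: A non-struct field tagged db:"-" is ignored during mapping, as are unexported fields. A struct field tagged db:"-" is not ignored: its inner fields are mapped without a prefix.
  • Nested Structs: The utility functions support nested structs. A nested column name is the nested field's tag (or field name), an underscore, and the inner field's tag (or field name).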

Example with Nested Structs

type Address struct {
    Street  string `db:"street"`
    City    string `db:"city"`
}

type User struct {
    ID      int     `db:"id"`
    Name    string  `db:"name"`
    Address Address `db:"-"`
}

  • With db:"-" on the Address field, the nested struct is flattened, and the expected columns in the query are: id, name, street, city.
  • With db:"address" instead, the nested columns take the tag as a prefix, and the expected columns are: id, name, address_street, address_city.

Use whichever form suits your situation.


Error Handling

The wrapper methods return errors that should be handled appropriately:

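  • Invalid Inputs: Get returns an error if dest is not a pointer to a struct. Select returns an error if dest is not a pointer to a slice of structs or of pointers to structs.
  • Field Mapping Errors: Select returns an error if a result column has no matching struct field. In Get, a mismatch between the struct's fields and the returned columns typically surfaces as an error from Scan.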
  • Database Errors: Errors from the database, such as connection issues, query syntax errors, or constraint violations.
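
As an illustration of the invalid-input case, the following sketch repeats the check at the top of Get in a standalone function and shows which arguments it rejects. checkGetDest and this User type are hypothetical names used only here.

```go
package main

import (
    "errors"
    "fmt"
    "reflect"
)

// User is a hypothetical destination type.
type User struct {
    ID int
}

// checkGetDest applies the same test as the first lines of Get.
func checkGetDest(dest interface{}) error {
    v := reflect.ValueOf(dest)
    if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
        return errors.New("dest must be a pointer to a struct")
    }
    return nil
}

func main() {
    var u User
    var users []User
    var nilUser *User

    fmt.Println(checkGetDest(&u) == nil)     // true: pointer to a struct
    fmt.Println(checkGetDest(u) == nil)      // false: struct passed by value
    fmt.Println(checkGetDest(&users) == nil) // false: pointer to a slice
    fmt.Println(checkGetDest(nilUser) == nil) // false: nil pointer has no struct behind it
}
```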