# Documentation
This document describes a PostgreSQL database wrapper for Go built on pgx v5. The wrapper simplifies database access in three ways: it provides helpers that execute queries and map the results directly into Go structs, it supports transactions, and it runs every operation on a caller-supplied `pgxpool` connection pool.
The wrapper is split into several Go files within the `postgres` package:

- `wrapper.go`: the `Wrapper` type and the `DB` interface
- `utils.go`: utility functions for mapping query results to structs
- `tx_wrapper.go`: the transaction wrapper
package postgres
import (
"context"
"errors"
"reflect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
The `DB` interface abstracts the basic database operations (`QueryRow`, `Query`, `Exec`). This makes code easier to test and lets it run either directly on the pool or inside a transaction.
type DB interface {
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error)
}
The `Wrapper` struct holds the pgx connection pool (`*pgxpool.Pool`) that it uses for every query.
// Wrapper runs queries and struct-mapping helpers directly on a pgx
// connection pool. Construct it with NewWrapper.
type Wrapper struct{ pool *pgxpool.Pool }
`NewWrapper` creates a new `Wrapper` that uses the provided connection pool.
// NewWrapper returns a Wrapper whose operations all run on pool.
// The pool is stored as-is; nothing in Wrapper closes it, so the caller
// remains responsible for its lifetime.
func NewWrapper(pool *pgxpool.Pool) *Wrapper {
	w := new(Wrapper)
	w.pool = pool
	return w
}
Executes a query that is expected to return a single row.
// QueryRow forwards a query that is expected to return a single row to the
// pool. It returns the resulting pgx.Row for the caller to Scan.
func (w *Wrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
	pool := w.pool
	return pool.QueryRow(ctx, sql, args...)
}
Executes a query that returns multiple rows.
// Query forwards a query that returns multiple rows to the pool. The caller
// owns the returned rows and must close them.
func (w *Wrapper) Query(ctx context.Context, sql string, args ...interface{}) (rows pgx.Rows, err error) {
	rows, err = w.pool.Query(ctx, sql, args...)
	return rows, err
}
Executes a command that doesn't return rows (e.g., INSERT, UPDATE, DELETE).
// Exec forwards a statement that returns no rows (INSERT, UPDATE, DELETE,
// DDL, ...) to the pool. It returns the command tag, e.g. for RowsAffected.
func (w *Wrapper) Exec(ctx context.Context, sql string, args ...interface{}) (tag pgconn.CommandTag, err error) {
	tag, err = w.pool.Exec(ctx, sql, args...)
	return tag, err
}
Executes a query that returns a single row and scans it into the provided struct.
// Get executes a query that is expected to return a single row and scans it
// into dest, which must be a non-nil pointer to a struct.
//
// Result columns are matched to struct fields by name: the field's `db` tag,
// or the Go field name when the tag is empty (see CollectFields). This is the
// same rule Select uses, so the order of the SELECT list does not need to
// follow the order of the struct fields. Every result column must map to a
// field, otherwise an error is returned. Struct fields without a matching
// column are left unchanged.
//
// If the query returns no rows, Get returns pgx.ErrNoRows. Rows after the
// first are ignored. Database errors are returned unchanged.
func (w *Wrapper) Get(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
	// Validate before sending anything, so an invalid dest never causes the
	// statement to execute (relevant for INSERT ... RETURNING and similar).
	destVal := reflect.ValueOf(dest)
	if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Struct {
		return errors.New("dest must be a pointer to a struct")
	}

	rows, err := w.pool.Query(ctx, sqlStr, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Map by the columns the query actually returned rather than by struct
	// field order. Positional mapping silently swaps same-typed columns
	// whenever the SELECT list is ordered differently from the struct.
	fieldDescriptions := rows.FieldDescriptions()
	columns := make([]string, len(fieldDescriptions))
	for i, fd := range fieldDescriptions {
		columns[i] = string(fd.Name)
	}
	fields, err := StructFieldsPointers(dest, columns)
	if err != nil {
		return err
	}

	if !rows.Next() {
		// Distinguish a failed query from an empty result.
		if err := rows.Err(); err != nil {
			return err
		}
		return pgx.ErrNoRows
	}
	if err := rows.Scan(fields...); err != nil {
		return err
	}
	// Close before checking Err so that any error reported while the command
	// finishes is returned instead of being lost.
	rows.Close()
	return rows.Err()
}
Executes a query that returns multiple rows and scans them into a slice of structs.
// Select executes a query that returns multiple rows and appends one element
// per row to the slice that dest points to.
//
// dest must be a pointer to a slice whose elements are structs or pointers to
// structs. New rows are appended after any elements the slice already holds.
// Each result column is matched to a struct field by name: the `db` tag, or
// the Go field name when the tag is empty. A column with no matching field
// causes an error.
//
// dest is validated before the query is sent, so an invalid dest never causes
// the statement to run. The appended rows are stored into *dest only after
// every row has been scanned successfully. On error, *dest keeps its previous
// contents.
func (w *Wrapper) Select(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
	destVal := reflect.ValueOf(dest)
	if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Slice {
		return errors.New("dest must be a pointer to a slice")
	}
	sliceVal := destVal.Elem()
	elemType := sliceVal.Type().Elem()

	// Determine whether the slice holds structs or pointers to structs.
	ptrToStruct := false
	switch {
	case elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct:
		ptrToStruct = true
		elemType = elemType.Elem()
	case elemType.Kind() != reflect.Struct:
		return errors.New("slice elements must be structs or pointers to structs")
	}

	rows, err := w.pool.Query(ctx, sqlStr, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Column names from the query result drive the field mapping.
	fieldDescriptions := rows.FieldDescriptions()
	columns := make([]string, len(fieldDescriptions))
	for i, fd := range fieldDescriptions {
		columns[i] = string(fd.Name)
	}

	// Accumulate into a copy of the slice header; *dest is updated only once
	// all rows have been read without error.
	acc := sliceVal
	for rows.Next() {
		elemPtr := reflect.New(elemType)
		fields, err := StructFieldsPointers(elemPtr.Interface(), columns)
		if err != nil {
			return err
		}
		if err := rows.Scan(fields...); err != nil {
			return err
		}
		if ptrToStruct {
			acc = reflect.Append(acc, elemPtr)
		} else {
			acc = reflect.Append(acc, elemPtr.Elem())
		}
	}
	// Check for errors that ended the iteration early.
	if err := rows.Err(); err != nil {
		return err
	}
	sliceVal.Set(acc)
	return nil
}
Starts a new transaction.
// Begin starts a transaction on the pool with default options. It returns the
// transaction wrapped in a TxWrapper, or a nil TxWrapper and the error.
func (w *Wrapper) Begin(ctx context.Context) (*TxWrapper, error) {
	var wrapped *TxWrapper
	tx, err := w.pool.Begin(ctx)
	if err == nil {
		wrapped = &TxWrapper{tx: tx}
	}
	return wrapped, err
}
Starts a new transaction with specific options.
// BeginTx starts a transaction on the pool using txOptions (isolation level,
// access mode, ...). It returns the transaction wrapped in a TxWrapper, or a
// nil TxWrapper and the error.
func (w *Wrapper) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (*TxWrapper, error) {
	var wrapped *TxWrapper
	tx, err := w.pool.BeginTx(ctx, txOptions)
	if err == nil {
		wrapped = &TxWrapper{tx: tx}
	}
	return wrapped, err
}
These utility functions assist in mapping database query results to Go structs based on struct tags.
Creates a slice of pointers to the struct fields corresponding to the given column names.
// StructFieldsPointers returns one pointer per entry in columns, each
// addressing the field of *strct that CollectFields maps to that column name.
// The result is in the same order as columns, so it can be passed directly to
// Scan.
//
// strct must be a pointer to a struct. If any column has no matching field,
// the function returns an error and a nil slice.
func StructFieldsPointers(strct interface{}, columns []string) ([]interface{}, error) {
	ptr := reflect.ValueOf(strct)
	if !(ptr.Kind() == reflect.Ptr && ptr.Elem().Kind() == reflect.Struct) {
		return nil, errors.New("input must be a pointer to a struct")
	}

	byColumn := map[string]reflect.Value{}
	CollectFields(ptr.Elem(), "", byColumn)

	targets := make([]interface{}, 0, len(columns))
	for _, name := range columns {
		field, found := byColumn[name]
		if !found {
			return nil, fmt.Errorf("no matching struct field found for column %s", name)
		}
		// Fields reached through ptr.Elem() are addressable, so Addr is safe.
		targets = append(targets, field.Addr().Interface())
	}
	return targets, nil
}
Recursively collects fields from a struct, including embedded structs, and maps them by their column names.
// CollectFields recursively walks the struct value v and records every
// settable leaf field in fieldMap, keyed by its column name.
//
// Column naming rules:
//   - the name is the field's `db` tag, or the Go field name if the tag is empty;
//   - if prefix is non-empty and the field is not embedded (anonymous), the
//     name becomes prefix + "_" + name;
//   - a struct-typed field is descended into, using its column name as the
//     new prefix. If it is tagged `db:"-"`, its fields are flattened using
//     the current prefix instead;
//   - any other field tagged `db:"-"` is skipped, as are fields that are not
//     settable (for example unexported fields).
//
// Some struct types are scanned as a single database value and are therefore
// recorded as leaf fields rather than descended into: time.Time, and any
// struct whose pointer has a `Scan(src) error` method (the shape of
// database/sql.Scanner). Descending into them would find no settable fields
// (time.Time) or split one column into several.
//
// v should be addressable (e.g. reflect.ValueOf(&s).Elem()). Otherwise
// CanSet reports false for every field and nothing is collected.
//
// NOTE(review): an embedded struct without a tag uses its type name as the
// prefix for its fields (e.g. "Base_id"). Tag it `db:"-"` to flatten it.
func CollectFields(v reflect.Value, prefix string, fieldMap map[string]reflect.Value) {
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := t.Field(i)
		fieldValue := v.Field(i)
		if !fieldValue.CanSet() {
			continue
		}

		// nested is true for structs whose fields should be mapped
		// individually, and false for structs scanned as one column value.
		nested := false
		if fieldValue.Kind() == reflect.Struct {
			ft := field.Type
			isTime := ft.PkgPath() == "time" && ft.Name() == "Time"
			// Method.Type from a reflect.Type includes the receiver as In(0).
			scan, hasScan := reflect.PointerTo(ft).MethodByName("Scan")
			isScanner := hasScan && scan.Type.NumIn() == 2 &&
				scan.Type.NumOut() == 1 && scan.Type.Out(0) == errorType
			nested = !isTime && !isScanner
		}

		tag := field.Tag.Get("db")
		if tag == "-" {
			// Nested structs tagged "-" are flattened without adding a prefix.
			if nested {
				CollectFields(fieldValue, prefix, fieldMap)
			}
			continue
		}
		if tag == "" {
			tag = field.Name
		}

		colName := tag
		if prefix != "" && !field.Anonymous {
			colName = prefix + "_" + tag
		}

		if nested {
			CollectFields(fieldValue, colName, fieldMap)
		} else {
			fieldMap[colName] = fieldValue
		}
	}
}
Returns the list of column names expected in the query result, derived from the struct's `db` tags.
// GetColumnNames returns the column names that *dest maps to, in struct field
// order, following the naming rules of CollectColumnNames.
//
// dest must be a pointer to a struct; otherwise an error is returned. A struct
// with no mappable fields yields a nil slice.
func GetColumnNames(dest interface{}) ([]string, error) {
	rv := reflect.ValueOf(dest)
	isStructPtr := rv.Kind() == reflect.Ptr && rv.Elem().Kind() == reflect.Struct
	if !isStructPtr {
		return nil, errors.New("dest must be a pointer to a struct")
	}

	var names []string
	CollectColumnNames(rv.Elem(), "", &names)
	return names, nil
}
Recursively collects column names from a struct, including embedded structs, based on `db` tags.
// CollectColumnNames recursively walks the struct value v and appends the
// column name of every settable leaf field to *columns, in struct field order.
// It uses the same naming rules as CollectFields, so the names it produces
// match the keys CollectFields records:
//   - the name is the field's `db` tag, or the Go field name if the tag is empty;
//   - if prefix is non-empty and the field is not embedded (anonymous), the
//     name becomes prefix + "_" + name;
//   - a struct-typed field is descended into, using its column name as the
//     new prefix. If it is tagged `db:"-"`, its fields are flattened using
//     the current prefix instead;
//   - any other field tagged `db:"-"` is skipped, as are fields that are not
//     settable (for example unexported fields).
//
// time.Time and structs whose pointer has a `Scan(src) error` method (the
// shape of database/sql.Scanner) are treated as single columns rather than
// descended into.
//
// v should be addressable (e.g. reflect.ValueOf(&s).Elem()). Otherwise
// CanSet reports false for every field and nothing is collected.
func CollectColumnNames(v reflect.Value, prefix string, columns *[]string) {
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := t.Field(i)
		fieldValue := v.Field(i)
		if !fieldValue.CanSet() {
			continue
		}

		// nested is true for structs whose fields map to separate columns,
		// and false for structs scanned as one column value.
		nested := false
		if fieldValue.Kind() == reflect.Struct {
			ft := field.Type
			isTime := ft.PkgPath() == "time" && ft.Name() == "Time"
			// Method.Type from a reflect.Type includes the receiver as In(0).
			scan, hasScan := reflect.PointerTo(ft).MethodByName("Scan")
			isScanner := hasScan && scan.Type.NumIn() == 2 &&
				scan.Type.NumOut() == 1 && scan.Type.Out(0) == errorType
			nested = !isTime && !isScanner
		}

		tag := field.Tag.Get("db")
		if tag == "-" {
			// Nested structs tagged "-" are flattened without adding a prefix.
			if nested {
				CollectColumnNames(fieldValue, prefix, columns)
			}
			continue
		}
		if tag == "" {
			tag = field.Name
		}

		colName := tag
		if prefix != "" && !field.Anonymous {
			colName = prefix + "_" + tag
		}

		if nested {
			CollectColumnNames(fieldValue, colName, columns)
		} else {
			*columns = append(*columns, colName)
		}
	}
}
The `TxWrapper` struct provides the same methods as `Wrapper`, but runs them inside a transaction.
package postgres
import (
"context"
"errors"
"reflect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
`TxWrapper` wraps a `pgx.Tx` to provide transactional operations.
// TxWrapper runs queries and struct-mapping helpers inside a single pgx
// transaction. It is obtained from Wrapper.Begin or Wrapper.BeginTx and
// finished with Commit or Rollback.
type TxWrapper struct{ tx pgx.Tx }
Executes a query that is expected to return a single row within the transaction.
// QueryRow forwards a query that is expected to return a single row to the
// transaction. It returns the resulting pgx.Row for the caller to Scan.
func (tw *TxWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
	tx := tw.tx
	return tx.QueryRow(ctx, sql, args...)
}
Executes a query that returns multiple rows within the transaction.
// Query forwards a query that returns multiple rows to the transaction. The
// caller owns the returned rows and must close them.
func (tw *TxWrapper) Query(ctx context.Context, sql string, args ...interface{}) (rows pgx.Rows, err error) {
	rows, err = tw.tx.Query(ctx, sql, args...)
	return rows, err
}
Executes a command that doesn't return rows within the transaction.
// Exec forwards a statement that returns no rows to the transaction. It
// returns the command tag, e.g. for RowsAffected.
func (tw *TxWrapper) Exec(ctx context.Context, sql string, args ...interface{}) (tag pgconn.CommandTag, err error) {
	tag, err = tw.tx.Exec(ctx, sql, args...)
	return tag, err
}
Executes a query that returns a single row and scans it into the provided struct within the transaction.
// Get executes, within the transaction, a query that is expected to return a
// single row, and scans that row into dest. dest must be a non-nil pointer to
// a struct.
//
// Result columns are matched to struct fields by name: the field's `db` tag,
// or the Go field name when the tag is empty. The order of the SELECT list
// therefore does not matter. Every result column must map to a field,
// otherwise an error is returned. Struct fields without a matching column are
// left unchanged.
//
// If the query returns no rows, Get returns pgx.ErrNoRows. Rows after the
// first are ignored. Database errors are returned unchanged.
func (tw *TxWrapper) Get(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
	// Validate before sending anything, so an invalid dest never causes the
	// statement to execute inside the transaction.
	destVal := reflect.ValueOf(dest)
	if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Struct {
		return errors.New("dest must be a pointer to a struct")
	}

	rows, err := tw.tx.Query(ctx, sqlStr, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Map by the columns the query actually returned, not by struct order.
	fieldDescriptions := rows.FieldDescriptions()
	columns := make([]string, len(fieldDescriptions))
	for i, fd := range fieldDescriptions {
		columns[i] = string(fd.Name)
	}
	fields, err := StructFieldsPointers(dest, columns)
	if err != nil {
		return err
	}

	if !rows.Next() {
		// Distinguish a failed query from an empty result.
		if err := rows.Err(); err != nil {
			return err
		}
		return pgx.ErrNoRows
	}
	if err := rows.Scan(fields...); err != nil {
		return err
	}
	// Close before checking Err so that any error reported while the command
	// finishes is returned instead of being lost.
	rows.Close()
	return rows.Err()
}
Executes a query that returns multiple rows and scans them into a slice of structs within the transaction.
// Select executes, within the transaction, a query that returns multiple rows,
// and appends one element per row to the slice that dest points to.
//
// dest must be a pointer to a slice whose elements are structs or pointers to
// structs. New rows are appended after any elements the slice already holds.
// Each result column is matched to a struct field by name: the `db` tag, or
// the Go field name when the tag is empty. A column with no matching field
// causes an error.
//
// dest is validated before the query is sent. The appended rows are stored
// into *dest only after every row has been scanned successfully. On error,
// *dest keeps its previous contents.
func (tw *TxWrapper) Select(ctx context.Context, dest interface{}, sqlStr string, args ...interface{}) error {
	destVal := reflect.ValueOf(dest)
	if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Slice {
		return errors.New("dest must be a pointer to a slice")
	}
	sliceVal := destVal.Elem()
	elemType := sliceVal.Type().Elem()

	// Determine whether the slice holds structs or pointers to structs.
	ptrToStruct := false
	switch {
	case elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct:
		ptrToStruct = true
		elemType = elemType.Elem()
	case elemType.Kind() != reflect.Struct:
		return errors.New("slice elements must be structs or pointers to structs")
	}

	rows, err := tw.tx.Query(ctx, sqlStr, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Column names from the query result drive the field mapping.
	fieldDescriptions := rows.FieldDescriptions()
	columns := make([]string, len(fieldDescriptions))
	for i, fd := range fieldDescriptions {
		columns[i] = string(fd.Name)
	}

	// Accumulate into a copy of the slice header; *dest is updated only once
	// all rows have been read without error.
	acc := sliceVal
	for rows.Next() {
		elemPtr := reflect.New(elemType)
		fields, err := StructFieldsPointers(elemPtr.Interface(), columns)
		if err != nil {
			return err
		}
		if err := rows.Scan(fields...); err != nil {
			return err
		}
		if ptrToStruct {
			acc = reflect.Append(acc, elemPtr)
		} else {
			acc = reflect.Append(acc, elemPtr.Elem())
		}
	}
	// Check for errors that ended the iteration early.
	if err := rows.Err(); err != nil {
		return err
	}
	sliceVal.Set(acc)
	return nil
}
Commits the transaction.
// Commit commits the wrapped transaction and returns any error reported by
// pgx.
func (tw *TxWrapper) Commit(ctx context.Context) (err error) {
	err = tw.tx.Commit(ctx)
	return
}
Rolls back the transaction.
// Rollback aborts the wrapped transaction and returns any error reported by
// pgx.
func (tw *TxWrapper) Rollback(ctx context.Context) error {
	tx := tw.tx
	return tx.Rollback(ctx)
}
The examples below show how to use the wrapper to interact with a PostgreSQL database.
import (
"context"
"log"
"github.com/jackc/pgx/v5/pgxpool"
)
// main shows how to create a connection pool and wrap it. The pool is closed
// by the deferred pool.Close when main returns.
func main() {
	ctx := context.Background()
	pool, err := pgxpool.New(ctx, "postgresql://user:password@localhost:5432/dbname")
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close()

	db := postgres.NewWrapper(pool)
	// Use db for database operations (db.Get, db.Select, db.Exec, db.Begin).
	// The blank assignment keeps this skeleton compiling until db is used,
	// because Go rejects local variables that are declared but never used.
	_ = db
}
// User is an example model. Each field's db tag names the result column it is
// mapped to: id, name and email.
type User struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
}
// Get: scan a single row into a struct.
var user User
err := db.Get(ctx, &user, "SELECT id, name, email FROM users WHERE id = $1", 1)
if err != nil {
	// Handle error
}

// Select: scan every row into a slice of structs.
// err was already declared above, so it is reassigned rather than redeclared.
var users []User
err = db.Select(ctx, &users, "SELECT id, name, email FROM users")
if err != nil {
	// Handle error
}

// Exec: run a statement that returns no rows and inspect the command tag.
// The email address is an illustrative placeholder.
commandTag, err := db.Exec(ctx, "UPDATE users SET email = $1 WHERE id = $2", "new.email@example.com", 1)
if err != nil {
	// Handle error
}
if commandTag.RowsAffected() != 1 {
	// Handle case where no rows or multiple rows were updated
}
tx, err := db.Begin(ctx)
if err != nil {
// Handle error
}
defer func() {
if err != nil {
_ = tx.Rollback(ctx)
} else {
err = tx.Commit(ctx)
}
}()
// Perform transactional operations using tx
err = tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", amount, fromAccountID)
if err != nil {
return err
}
err = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", amount, toAccountID)
if err != nil {
return err
}
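- **Struct tags:** the `db` struct tag maps a struct field to a database column. If the tag is omitted, the field name is used. Unexported fields are always skipped.
- **Ignoring fields:** fields tagged `db:"-"` are ignored during mapping. The exception is a nested struct field tagged `db:"-"`: it is flattened rather than ignored, and its own fields are mapped without a prefix (see the example below).
- **Nested structs:** nested structs are supported. A nested field's column name is the parent field's tag (or name) joined to the nested field's tag with an underscore.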
// Address is an example nested struct with two mapped columns.
type Address struct {
Street string `db:"street"`
City string `db:"city"`
}
// User embeds Address as a named field tagged db:"-". A struct field tagged
// "-" is flattened, so Address's fields map to unprefixed columns and the
// expected columns are id, name, street and city. Tagging the field
// db:"address" instead would map them to address_street and address_city.
type User struct {
ID int `db:"id"`
Name string `db:"name"`
Address Address `db:"-"`
}
- With the `db:"-"` tag on `Address`, the query must return the columns `id`, `name`, `street` and `city`.
- If the field is tagged `db:"address"` instead, the nested columns become `address_street` and `address_city`.

Use whichever form suits your situation.
The wrapper methods return errors that should be handled appropriately:

- **Invalid inputs:** an error is returned if `dest` is not a pointer to a struct (`Get`) or a pointer to a slice of structs or struct pointers (`Select`).
- **Field mapping errors:** an error is returned if the query's result columns and the struct's fields do not match.
- **Database errors:** errors from the database, such as connection failures, query syntax errors, or constraint violations, are returned unchanged.
All the necessary documentation can be found here