Files
golunteer/backend/pkg/db/db.go
2025-01-07 16:37:37 +00:00

226 lines
5.2 KiB
Go

package db
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_config "github.com/johannesbuehl/golunteer/backend/pkg/config"
_logger "github.com/johannesbuehl/golunteer/backend/pkg/logger"
)
var (
	// logger is the logger exported by the logger package
	logger = _logger.Logger
	// config is the configuration exported by the config package
	config = _config.Config

	// DB is the connection to the database
	DB *sqlx.DB
)
// init builds the MySQL DSN from the database section of the configuration,
// opens the package-level connection pool DB and applies its limits.
//
// sqlx.MustOpen panics when the connection can't be set up, so a broken
// database configuration aborts the program at start-up.
func init() {
	// setup the database-connection
	sqlConfig := mysql.Config{
		AllowNativePasswords: true,
		Net:                  "tcp",
		User:                 config.Database.User,
		Passwd:               config.Database.Password,
		Addr:                 config.Database.Host,
		DBName:               config.Database.Database,
	}

	// connect to the database
	DB = sqlx.MustOpen("mysql", sqlConfig.FormatDSN())

	// keep up to 10 idle connections around, but never open more than 100 at once
	DB.SetMaxIdleConns(10)
	DB.SetMaxOpenConns(100)
	// recycle connections after a minute
	DB.SetConnMaxLifetime(time.Minute)
}
// SelectOld queries table and scans every returned row into a value of type T.
//
// The selected columns are the lower-cased names of the fields of T, in
// declaration order. where is appended verbatim after "WHERE" unless it is
// empty or "*"; args are handed to the driver as the values for the
// placeholders in where. table and where become part of the query text
// unescaped, so they must never carry untrusted input.
//
// An error is returned when T is not a struct type or when querying, scanning
// or iterating over the rows fails. On success the result is never nil.
func SelectOld[T any](table string, where string, args ...any) ([]T, error) {
	tType := reflect.TypeOf(new(T)).Elem()
	if tType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("cannot select into %s: not a struct type", tType)
	}

	// derive the column-names from the fields of T
	columns := make([]string, tType.NumField())
	for ii := range columns {
		columns[ii] = strings.ToLower(tType.Field(ii).Name)
	}

	// create the query
	completeQuery := fmt.Sprintf("SELECT %s FROM %s", strings.Join(columns, ", "), table)
	if where != "" && where != "*" {
		completeQuery = fmt.Sprintf("%s WHERE %s", completeQuery, where)
	}

	var (
		rows *sql.Rows
		err  error
	)

	// an empty args-slice is fine for the variadic call, no special case needed
	rows, err = DB.Query(completeQuery, args...)
	if err != nil {
		logger.Error().Msgf("database access failed with error %v", err)

		return nil, err
	}
	defer rows.Close()

	results := []T{}

	for rows.Next() {
		var lineResult T

		v := reflect.ValueOf(&lineResult).Elem()
		scanArgs := make([]any, len(columns))

		for ii, col := range columns {
			// address the field by its index: the column-names are lower-cased,
			// so a lookup by name would never match an exported field
			field := v.Field(ii)

			if field.CanSet() {
				scanArgs[ii] = field.Addr().Interface()
			} else {
				// unexported fields can't be written through reflection
				logger.Warn().Msgf("Field %s of struct %T can't be set", col, lineResult)

				scanArgs[ii] = new(any) // save dummy value
			}
		}

		// scan the row into the struct
		if err := rows.Scan(scanArgs...); err != nil {
			logger.Warn().Msgf("Scan-error: %v", err)

			return nil, err
		}

		results = append(results, lineResult)
	}

	if err := rows.Err(); err != nil {
		logger.Error().Msgf("rows-error: %v", err)

		return nil, err
	}

	return results, nil
}
// Insert adds one row to table.
//
// vals has to be a struct or a non-nil pointer to a struct. The lower-cased
// names of its exported fields are used as the column-names, the field-values
// are passed to the driver as placeholder-values. Unexported fields are
// skipped, since their values can't be read through reflection. table becomes
// part of the query text unescaped, so it must never carry untrusted input.
//
// An error is returned when vals is not a struct or when executing the
// statement fails.
func Insert(table string, vals any) error {
	// accept both a struct and a pointer to one
	v := reflect.Indirect(reflect.ValueOf(vals))
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("insert into %s: expected a struct, got %T", table, vals)
	}

	t := v.Type()

	// extract columns and values from vals
	columns := make([]string, 0, t.NumField())
	values := make([]any, 0, t.NumField())

	for ii := 0; ii < t.NumField(); ii++ {
		field := t.Field(ii)

		// calling Interface() on an unexported field would panic
		if !field.IsExported() {
			continue
		}

		columns = append(columns, strings.ToLower(field.Name))
		values = append(values, v.Field(ii).Interface())
	}

	// one placeholder per column
	placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(columns)), ", ")

	completeQuery := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), placeholders)

	_, err := DB.Exec(completeQuery, values...)

	return err
}
// Update changes the rows of table that match where.
//
// set and where each have to be a struct or a non-nil pointer to a struct.
// Every exported field of set becomes a "column = ?" assignment, named after
// the lower-cased field-name. Every exported, non-zero field of where becomes
// a "column = ?" condition; the conditions are combined with AND and their
// values are passed to the driver in their fmt.Sprint representation.
// table becomes part of the query text unescaped, so it must never carry
// untrusted input.
//
// An error is returned when set or where is not a struct, when there is
// nothing to set, when where yields no condition (the statement is refused
// instead of being sent without a usable WHERE clause) or when executing the
// statement fails.
func Update(table string, set, where any) error {
	// accept both structs and pointers to them
	setV := reflect.Indirect(reflect.ValueOf(set))
	whereV := reflect.Indirect(reflect.ValueOf(where))

	if setV.Kind() != reflect.Struct || whereV.Kind() != reflect.Struct {
		return fmt.Errorf("update %s: set and where must be structs, got %T and %T", table, set, where)
	}

	setT := setV.Type()
	whereT := whereV.Type()

	setColumns := make([]string, 0, setT.NumField())
	whereColumns := make([]string, 0, whereT.NumField())
	// values for the placeholders: first the ones of SET, then the ones of WHERE
	placeholderValues := make([]any, 0, setT.NumField()+whereT.NumField())

	for ii := 0; ii < setT.NumField(); ii++ {
		field := setT.Field(ii)

		// calling Interface() on an unexported field would panic
		if !field.IsExported() {
			continue
		}

		setColumns = append(setColumns, strings.ToLower(field.Name)+" = ?")
		placeholderValues = append(placeholderValues, setV.Field(ii).Interface())
	}

	if len(setColumns) == 0 {
		return fmt.Errorf("update %s: no columns to set in %T", table, set)
	}

	for ii := 0; ii < whereT.NumField(); ii++ {
		field := whereT.Field(ii)
		fieldValue := whereV.Field(ii)

		// skip empty (zero) values; append only the used fields, so neither
		// the conditions nor the values contain gaps
		if !field.IsExported() || fieldValue.IsZero() {
			continue
		}

		whereColumns = append(whereColumns, strings.ToLower(field.Name)+" = ?")
		placeholderValues = append(placeholderValues, fmt.Sprint(fieldValue.Interface()))
	}

	if len(whereColumns) == 0 {
		return fmt.Errorf("update %s: no conditions given in %T", table, where)
	}

	sets := strings.Join(setColumns, ", ")
	wheres := strings.Join(whereColumns, " AND ")

	completeQuery := fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, sets, wheres)

	_, err := DB.Exec(completeQuery, placeholderValues...)

	return err
}
// Delete removes the rows of table that match vals.
//
// vals has to be a struct or a non-nil pointer to a struct. Every exported,
// non-zero field becomes a "column = ?" condition, named after the lower-cased
// field-name; the conditions are combined with AND and their values are passed
// to the driver in their fmt.Sprint representation. table becomes part of the
// query text unescaped, so it must never carry untrusted input.
//
// An error is returned when vals is not a struct, when it yields no condition
// (the statement is refused instead of being sent without a usable WHERE
// clause) or when executing the statement fails.
func Delete(table string, vals any) error {
	// accept both a struct and a pointer to one
	v := reflect.Indirect(reflect.ValueOf(vals))
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("delete from %s: expected a struct, got %T", table, vals)
	}

	t := v.Type()

	// extract the conditions from vals
	columns := make([]string, 0, t.NumField())
	values := make([]any, 0, t.NumField())

	for ii := 0; ii < t.NumField(); ii++ {
		field := t.Field(ii)
		fieldValue := v.Field(ii)

		// skip empty (zero) values; append only the used fields, so neither
		// the conditions nor the values contain gaps
		if !field.IsExported() || fieldValue.IsZero() {
			continue
		}

		columns = append(columns, strings.ToLower(field.Name)+" = ?")
		values = append(values, fmt.Sprint(fieldValue.Interface()))
	}

	if len(columns) == 0 {
		return fmt.Errorf("delete from %s: no conditions given in %T", table, vals)
	}

	// multiple conditions have to be combined with AND, a comma is no valid separator here
	completeQuery := fmt.Sprintf("DELETE FROM %s WHERE %s", table, strings.Join(columns, " AND "))

	_, err := DB.Exec(completeQuery, values...)

	return err
}