Switched to vestigo router

Also moved db reference to data module; it now starts, but doesn't serve index.html for root yet
This commit is contained in:
Daniel J. Summers 2018-03-22 22:11:38 -05:00
parent b248f7ca7f
commit 59b5574b16
5 changed files with 43 additions and 52 deletions

View File

@ -23,6 +23,9 @@ const (
AND "lastStatus" <> 'Answered'` AND "lastStatus" <> 'Answered'`
) )
// db is a connection to the database for the entire application.
var db *sql.DB
// Settings holds the PostgreSQL configuration for myPrayerJournal. // Settings holds the PostgreSQL configuration for myPrayerJournal.
type Settings struct { type Settings struct {
Host string `json:"host"` Host string `json:"host"`
@ -35,7 +38,7 @@ type Settings struct {
/* Data Access */ /* Data Access */
// Retrieve a basic request // Retrieve a basic request
func retrieveRequest(db *sql.DB, reqID, userID string) (*Request, bool) { func retrieveRequest(reqID, userID string) (*Request, bool) {
req := Request{} req := Request{}
err := db.QueryRow(` err := db.QueryRow(`
SELECT "requestId", "enteredOn" SELECT "requestId", "enteredOn"
@ -79,8 +82,8 @@ func makeJournal(rows *sql.Rows, userID string) []JournalRequest {
} }
// AddHistory creates a history entry for a prayer request, given the status and updated text. // AddHistory creates a history entry for a prayer request, given the status and updated text.
func AddHistory(db *sql.DB, userID, reqID, status, text string) int { func AddHistory(userID, reqID, status, text string) int {
if _, ok := retrieveRequest(db, reqID, userID); !ok { if _, ok := retrieveRequest(reqID, userID); !ok {
return 404 return 404
} }
_, err := db.Exec(` _, err := db.Exec(`
@ -97,7 +100,7 @@ func AddHistory(db *sql.DB, userID, reqID, status, text string) int {
} }
// AddNew stores a new prayer request and its initial history record. // AddNew stores a new prayer request and its initial history record.
func AddNew(db *sql.DB, userID, text string) (*JournalRequest, bool) { func AddNew(userID, text string) (*JournalRequest, bool) {
id := cuid.New() id := cuid.New()
now := jsNow() now := jsNow()
tx, err := db.Begin() tx, err := db.Begin()
@ -129,8 +132,8 @@ func AddNew(db *sql.DB, userID, text string) (*JournalRequest, bool) {
} }
// AddNote adds a note to a prayer request. // AddNote adds a note to a prayer request.
func AddNote(db *sql.DB, userID, reqID, note string) int { func AddNote(userID, reqID, note string) int {
if _, ok := retrieveRequest(db, reqID, userID); !ok { if _, ok := retrieveRequest(reqID, userID); !ok {
return 404 return 404
} }
_, err := db.Exec(` _, err := db.Exec(`
@ -147,7 +150,7 @@ func AddNote(db *sql.DB, userID, reqID, note string) int {
} }
// Answered retrieves all answered requests for the given user. // Answered retrieves all answered requests for the given user.
func Answered(db *sql.DB, userID string) []JournalRequest { func Answered(userID string) []JournalRequest {
rows, err := db.Query(currentRequestSQL+ rows, err := db.Query(currentRequestSQL+
`WHERE "userId" = $1 `WHERE "userId" = $1
AND "lastStatus" = 'Answered' AND "lastStatus" = 'Answered'
@ -162,7 +165,7 @@ func Answered(db *sql.DB, userID string) []JournalRequest {
} }
// ByID retrieves a journal request by its ID. // ByID retrieves a journal request by its ID.
func ByID(db *sql.DB, userID, reqID string) (*JournalRequest, bool) { func ByID(userID, reqID string) (*JournalRequest, bool) {
req := JournalRequest{} req := JournalRequest{}
err := db.QueryRow(currentRequestSQL+ err := db.QueryRow(currentRequestSQL+
`WHERE "requestId" = $1 `WHERE "requestId" = $1
@ -178,26 +181,27 @@ func ByID(db *sql.DB, userID, reqID string) (*JournalRequest, bool) {
} }
// Connect establishes a connection to the database. // Connect establishes a connection to the database.
func Connect(s *Settings) (*sql.DB, bool) { func Connect(s *Settings) bool {
connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
s.Host, s.Port, s.User, s.Password, s.DbName) s.Host, s.Port, s.User, s.Password, s.DbName)
db, err := sql.Open("postgres", connStr) var err error
db, err = sql.Open("postgres", connStr)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil, false return false
} }
err = db.Ping() err = db.Ping()
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil, false return false
} }
log.Printf("Connected to postgres://%s@%s:%d/%s\n", s.User, s.Host, s.Port, s.DbName) log.Printf("Connected to postgres://%s@%s:%d/%s\n", s.User, s.Host, s.Port, s.DbName)
return db, true return true
} }
// FullByID retrieves a journal request, including its full history and notes. // FullByID retrieves a journal request, including its full history and notes.
func FullByID(db *sql.DB, userID, reqID string) (*JournalRequest, bool) { func FullByID(userID, reqID string) (*JournalRequest, bool) {
req, ok := ByID(db, userID, reqID) req, ok := ByID(userID, reqID)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -225,12 +229,12 @@ func FullByID(db *sql.DB, userID, reqID string) (*JournalRequest, bool) {
log.Print(hRows.Err()) log.Print(hRows.Err())
return nil, false return nil, false
} }
req.Notes = NotesByID(db, userID, reqID) req.Notes = NotesByID(userID, reqID)
return req, true return req, true
} }
// Journal retrieves the current user's active prayer journal. // Journal retrieves the current user's active prayer journal.
func Journal(db *sql.DB, userID string) []JournalRequest { func Journal(userID string) []JournalRequest {
rows, err := db.Query(journalSQL, userID) rows, err := db.Query(journalSQL, userID)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
@ -241,8 +245,8 @@ func Journal(db *sql.DB, userID string) []JournalRequest {
} }
// NotesByID retrieves the notes for a given prayer request // NotesByID retrieves the notes for a given prayer request
func NotesByID(db *sql.DB, userID, reqID string) []Note { func NotesByID(userID, reqID string) []Note {
if _, ok := retrieveRequest(db, reqID, userID); !ok { if _, ok := retrieveRequest(reqID, userID); !ok {
return nil return nil
} }
rows, err := db.Query(` rows, err := db.Query(`
@ -276,7 +280,7 @@ func NotesByID(db *sql.DB, userID, reqID string) []Note {
/* DDL */ /* DDL */
// EnsureDB makes sure we have a known state of data structures. // EnsureDB makes sure we have a known state of data structures.
func EnsureDB(db *sql.DB) { func EnsureDB() {
tableSQL := func(table string) string { tableSQL := func(table string) string {
return fmt.Sprintf(`SELECT 1 FROM pg_tables WHERE schemaname='mpj' AND tablename='%s'`, table) return fmt.Sprintf(`SELECT 1 FROM pg_tables WHERE schemaname='mpj' AND tablename='%s'`, table)
} }

View File

@ -1,13 +1,11 @@
package routes package routes
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"log" "log"
"net/http" "net/http"
"github.com/danieljsummers/myPrayerJournal/src/api/data" "github.com/danieljsummers/myPrayerJournal/src/api/data"
"github.com/julienschmidt/httprouter"
) )
/* Support */ /* Support */
@ -32,9 +30,9 @@ func sendJSON(w http.ResponseWriter, r *http.Request, result interface{}) {
/* Handlers */ /* Handlers */
func journal(w http.ResponseWriter, r *http.Request, _ httprouter.Params, db *sql.DB) { func journal(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value(ContextUserKey) user := r.Context().Value(ContextUserKey)
reqs := data.Journal(db, user.(string)) reqs := data.Journal(user.(string))
if reqs == nil { if reqs == nil {
reqs = []data.JournalRequest{} reqs = []data.JournalRequest{}
} }

View File

@ -2,13 +2,11 @@ package routes
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"net/http" "net/http"
"time"
"github.com/auth0-community/go-auth0" "github.com/auth0-community/go-auth0"
"github.com/julienschmidt/httprouter" "github.com/husobee/vestigo"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
@ -19,27 +17,14 @@ type AuthConfig struct {
ClientSecret string `json:"secret"` ClientSecret string `json:"secret"`
} }
// DBHandler extends httprouter's handler with a DB instance.
type DBHandler func(http.ResponseWriter, *http.Request, httprouter.Params, *sql.DB)
//type APIHandler func(http.ResponseWriter, *http.Request, httprouter.Params, *sql.DB, string)
// ContextKey is the type of key used in our contexts. // ContextKey is the type of key used in our contexts.
type ContextKey string type ContextKey string
// ContextUserKey is the key for the current user in the context. // ContextUserKey is the key for the current user in the context.
const ContextUserKey ContextKey = "user" const ContextUserKey ContextKey = "user"
func withDB(next DBHandler, db *sql.DB) httprouter.Handle { func withAuth(next http.HandlerFunc, cfg *AuthConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { return func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(60*time.Second))
defer cancel()
next(w, r.WithContext(ctx), p, db)
}
}
func withAuth(next DBHandler, cfg *AuthConfig) DBHandler {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params, db *sql.DB) {
secret := []byte(cfg.ClientSecret) secret := []byte(cfg.ClientSecret)
secretProvider := auth0.NewKeyProvider(secret) secretProvider := auth0.NewKeyProvider(secret)
audience := []string{fmt.Sprintf("https://%s/userinfo", cfg.Domain)} audience := []string{fmt.Sprintf("https://%s/userinfo", cfg.Domain)}
@ -61,22 +46,22 @@ func withAuth(next DBHandler, cfg *AuthConfig) DBHandler {
} }
r = r.WithContext(context.WithValue(r.Context(), ContextUserKey, values["sub"])) r = r.WithContext(context.WithValue(r.Context(), ContextUserKey, values["sub"]))
// TODO pass the user ID (sub) along; this -> doesn't work | r.Header.Add("user-id", token.Claims("sub")) // TODO pass the user ID (sub) along; this -> doesn't work | r.Header.Add("user-id", token.Claims("sub"))
next(w, r, p, db) next(w, r)
} }
} }
} }
// NewRouter returns a configured router to handle all incoming requests. // NewRouter returns a configured router to handle all incoming requests.
func NewRouter(db *sql.DB, cfg *AuthConfig) *httprouter.Router { func NewRouter(cfg *AuthConfig) *vestigo.Router {
router := httprouter.New() router := vestigo.NewRouter()
for _, route := range routes { for _, route := range routes {
if route.IsPublic { if route.IsPublic {
router.Handle(route.Method, route.Pattern, withDB(route.Func, db)) router.Add(route.Method, route.Pattern, route.Func)
} else { } else {
router.Handle(route.Method, route.Pattern, withDB(withAuth(route.Func, cfg), db)) router.Add(route.Method, route.Pattern, withAuth(route.Func, cfg))
} }
} }
// router.ServeFiles("/*filepath", http.Dir("/public")) router.Get("/*", http.FileServer(http.Dir("/public")).ServeHTTP)
return router return router
} }

View File

@ -1,12 +1,16 @@
// Package routes contains endpoint handlers for the myPrayerJournal API. // Package routes contains endpoint handlers for the myPrayerJournal API.
package routes package routes
import (
"net/http"
)
// Route is a route served in the application. // Route is a route served in the application.
type Route struct { type Route struct {
Name string Name string
Method string Method string
Pattern string Pattern string
Func DBHandler Func http.HandlerFunc
IsPublic bool IsPublic bool
} }

View File

@ -39,10 +39,10 @@ func readSettings(f string) *Settings {
func main() { func main() {
cfg := readSettings("config.json") cfg := readSettings("config.json")
db, ok := data.Connect(cfg.Data) if ok := data.Connect(cfg.Data); !ok {
if !ok {
log.Fatal("Unable to connect to database; exiting") log.Fatal("Unable to connect to database; exiting")
} }
data.EnsureDB()
log.Printf("myPrayerJournal API listening on %s", cfg.Web.Port) log.Printf("myPrayerJournal API listening on %s", cfg.Web.Port)
log.Fatal(http.ListenAndServe(cfg.Web.Port, routes.NewRouter(db, cfg.Auth))) log.Fatal(http.ListenAndServe(cfg.Web.Port, routes.NewRouter(cfg.Auth)))
} }