Files
ordr/api/auth/auth.go
2025-11-17 21:07:51 -07:00

185 lines
4.8 KiB
Go

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"log"
"net/http"
"net/url"
"ordr-api/dto"
"os"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
// Authenticator is used to authenticate our users.
//
// It embeds the *oidc.Provider created in New and the oauth2.Config built
// there from the AUTH0_* environment variables. Because both are embedded,
// their methods and fields are promoted: this file calls Verifier,
// AuthCodeURL and Exchange and reads ClientID directly on an *Authenticator.
type Authenticator struct {
	*oidc.Provider
	oauth2.Config
}
// New instantiates the *Authenticator for the Auth0 tenant named by the
// AUTH0_DOMAIN environment variable.
//
// The OIDC provider is created for the issuer "https://<AUTH0_DOMAIN>/".
// The OAuth2 client is configured from:
//   - AUTH0_CLIENT_ID
//   - AUTH0_CLIENT_SECRET
//   - AUTH0_CALLBACK_URL
//
// It uses the provider's endpoint and requests the openid, offline_access,
// profile and email scopes.
//
// Any error from oidc.NewProvider is returned unchanged with a nil
// *Authenticator.
func New() (*Authenticator, error) {
	issuer := "https://" + os.Getenv("AUTH0_DOMAIN") + "/"

	provider, err := oidc.NewProvider(context.Background(), issuer)
	if err != nil {
		return nil, err
	}

	// offline_access is requested alongside the identity scopes; the login
	// callback stores the resulting refresh token in the session.
	scopes := []string{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, "profile", "email"}

	return &Authenticator{
		Provider: provider,
		Config: oauth2.Config{
			ClientID:     os.Getenv("AUTH0_CLIENT_ID"),
			ClientSecret: os.Getenv("AUTH0_CLIENT_SECRET"),
			RedirectURL:  os.Getenv("AUTH0_CALLBACK_URL"),
			Endpoint:     provider.Endpoint(),
			Scopes:       scopes,
		},
	}, nil
}
// VerifyIDToken verifies that an *oauth2.Token is a valid *oidc.IDToken.
//
// The raw ID token is read from the token's "id_token" extra field. If that
// field is absent or not a string, an error is returned. Otherwise the raw
// token is passed to the provider's verifier, configured with this
// Authenticator's ClientID. The verifier's result is returned as-is.
func (a *Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
	raw, isString := token.Extra("id_token").(string)
	if !isString {
		return nil, errors.New("no id_token field in oauth2 token")
	}

	verifier := a.Verifier(&oidc.Config{ClientID: a.ClientID})
	return verifier.Verify(ctx, raw)
}
// LoginHandler returns a gin handler that starts the Auth0 login flow.
//
// It does the following:
//   - Generates a random state value.
//   - Stores the state in the caller's session under the key "state", with
//     cookie path "/".
//   - Responds with HTTP 200 and a JSON dto.LoginRedirect. Its Location is
//     the authorization URL, which requests offline access and sets the
//     "audience" parameter to "https://<AUTH0_DOMAIN>/api/v2/".
//
// If generating the state or saving the session fails, it responds with
// 500 and the error text.
func LoginHandler(auth *Authenticator) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		state, err := generateRandomState()
		if err != nil {
			ctx.String(http.StatusInternalServerError, err.Error())
			return
		}

		// Persist the state so the callback can compare it with the value
		// echoed back on the redirect.
		sess := sessions.Default(ctx)
		sess.Set("state", state)
		sess.Options(sessions.Options{Path: "/"})
		if saveErr := sess.Save(); saveErr != nil {
			ctx.String(http.StatusInternalServerError, saveErr.Error())
			return
		}

		audienceURL := "https://" + os.Getenv("AUTH0_DOMAIN") + "/api/v2/"
		authURL := auth.AuthCodeURL(
			state,
			oauth2.AccessTypeOffline,
			oauth2.SetAuthURLParam("audience", audienceURL),
		)

		ctx.JSON(http.StatusOK, dto.LoginRedirect{
			Status:   "200 OK",
			Location: authURL,
		})
	}
}
// LogoutHandler redirects the browser (307 Temporary Redirect) to
// "https://<AUTH0_DOMAIN>/v2/logout".
//
// The redirect carries two query parameters:
//   - returnTo: this request's scheme ("https" when ctx.Request.TLS is set,
//     "http" otherwise), its Host, and the LOGOUT_CALLBACK_ENDPOINT path.
//   - client_id: AUTH0_CLIENT_ID.
//
// If either URL fails to parse, it responds with 500 and the error text.
//
// NOTE(review): returnTo is derived from the request's Host header and TLS
// state. Behind a TLS-terminating proxy the scheme will be "http". Confirm
// that the Auth0 "Allowed Logout URLs" setting constrains returnTo.
func LogoutHandler(ctx *gin.Context) {
	fail := func(err error) {
		ctx.String(http.StatusInternalServerError, err.Error())
	}

	logoutURL, err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + "/v2/logout")
	if err != nil {
		fail(err)
		return
	}

	scheme := "https"
	if ctx.Request.TLS == nil {
		scheme = "http"
	}

	returnTo, err := url.Parse(scheme + "://" + ctx.Request.Host + os.Getenv("LOGOUT_CALLBACK_ENDPOINT"))
	if err != nil {
		fail(err)
		return
	}

	logoutURL.RawQuery = url.Values{
		"returnTo":  {returnTo.String()},
		"client_id": {os.Getenv("AUTH0_CLIENT_ID")},
	}.Encode()

	ctx.Redirect(http.StatusTemporaryRedirect, logoutURL.String())
}
// LogoutCallbackHandler returns a gin handler for the URL the browser is
// sent back to after the Auth0 logout redirect.
//
// It clears every value in the caller's session and sets the session
// options to MaxAge -1 and Path "/". With net/http cookie semantics, a
// negative MaxAge deletes the cookie. It then saves the session and
// redirects with 303 See Other to PUBLIC_LOCATION.
//
// If the session cannot be saved, it responds with 500 "Failed to clear
// session" and stops there rather than also attempting the redirect.
//
// The store parameter is not used; the session is obtained through
// sessions.Default(ctx). It is kept so existing callers still compile.
func LogoutCallbackHandler(store cookie.Store) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		session := sessions.Default(ctx)
		session.Clear()
		session.Options(sessions.Options{MaxAge: -1, Path: "/"})
		if err := session.Save(); err != nil {
			ctx.String(http.StatusInternalServerError, "Failed to clear session")
			// Stop here: the 500 response is already written, so a redirect
			// would write a second status/header set on the same response.
			return
		}
		ctx.Redirect(http.StatusSeeOther, os.Getenv("PUBLIC_LOCATION"))
	}
}
// AuthenticationCallbackHandler returns a gin handler for the OAuth2
// redirect URI that completes the login started by LoginHandler.
//
// The handler proceeds in order:
//  1. The "state" query parameter must equal the "state" value stored in the
//     session; otherwise it responds 400. On a match the stored state is
//     deleted so it cannot be used again; the deletion is persisted by the
//     session save in step 4.
//  2. The "code" query parameter is exchanged for tokens (401 on failure).
//  3. The ID token is verified and its claims decoded into a map (500 on
//     failure).
//  4. The access token, refresh token and claims map are stored in the
//     session under "access_token", "refresh_token" and "profile", and the
//     session is saved (500 on failure).
//  5. The browser is redirected (307) to LOGGED_IN_REDIRECT.
//
// NOTE(review): storing a map[string]interface{} in a gob-backed session
// store presumably requires gob.Register for that type elsewhere. Storing
// tokens in a cookie store may also approach browser cookie size limits.
// Confirm both.
func AuthenticationCallbackHandler(auth *Authenticator) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		session := sessions.Default(ctx)
		if ctx.Query("state") != session.Get("state") {
			ctx.String(http.StatusBadRequest, "Invalid state parameter.")
			return
		}
		// The state is a one-time anti-CSRF value; drop it so it cannot be
		// replayed against this session.
		session.Delete("state")

		// Exchange an authorization code for a token.
		token, err := auth.Exchange(ctx.Request.Context(), ctx.Query("code"))
		if err != nil {
			ctx.String(http.StatusUnauthorized, "Failed to convert an authorization code into a token.")
			return
		}

		idToken, err := auth.VerifyIDToken(ctx.Request.Context(), token)
		if err != nil {
			ctx.String(http.StatusInternalServerError, "Failed to verify ID Token.")
			return
		}

		var profile map[string]interface{}
		if err := idToken.Claims(&profile); err != nil {
			ctx.String(http.StatusInternalServerError, err.Error())
			return
		}

		log.Printf("auth callback: access token expires_in=%d", token.ExpiresIn)

		session.Set("access_token", token.AccessToken)
		session.Set("refresh_token", token.RefreshToken)
		session.Set("profile", profile)
		if err := session.Save(); err != nil {
			ctx.String(http.StatusInternalServerError, err.Error())
			return
		}

		// Redirect to logged in page.
		ctx.Redirect(http.StatusTemporaryRedirect, os.Getenv("LOGGED_IN_REDIRECT"))
	}
}
// generateRandomState returns a random, URL-safe string for the OAuth2
// "state" parameter.
//
// It reads 32 bytes from crypto/rand and encodes them as unpadded URL-safe
// base64, which yields 43 characters from [A-Za-z0-9_-]. The URL-safe
// alphabet avoids '+', '/' and '='. Those characters must be percent-escaped
// in query strings and are easily mangled ('+' read back as a space) when the
// value round-trips through the authorization server's redirect.
//
// Any error from crypto/rand is returned with an empty string.
func generateRandomState() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}