feat: auth
This commit is contained in:
180
api/auth/auth.go
Normal file
180
api/auth/auth.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Authenticator is used to authenticate our users.
|
||||
type Authenticator struct {
|
||||
*oidc.Provider
|
||||
oauth2.Config
|
||||
}
|
||||
|
||||
// New instantiates the *Authenticator.
|
||||
func New() (*Authenticator, error) {
|
||||
provider, err := oidc.NewProvider(
|
||||
context.Background(),
|
||||
"https://"+os.Getenv("AUTH0_DOMAIN")+"/",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf := oauth2.Config{
|
||||
ClientID: os.Getenv("AUTH0_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("AUTH0_CLIENT_SECRET"),
|
||||
RedirectURL: os.Getenv("AUTH0_CALLBACK_URL"),
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, "profile", "email"},
|
||||
}
|
||||
|
||||
return &Authenticator{
|
||||
Provider: provider,
|
||||
Config: conf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyIDToken verifies that an *oauth2.Token is a valid *oidc.IDToken.
|
||||
func (a *Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, errors.New("no id_token field in oauth2 token")
|
||||
}
|
||||
|
||||
oidcConfig := &oidc.Config{
|
||||
ClientID: a.ClientID,
|
||||
}
|
||||
|
||||
return a.Verifier(oidcConfig).Verify(ctx, rawIDToken)
|
||||
}
|
||||
|
||||
// Handler for our login.
|
||||
func LoginHandler(auth *Authenticator) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
state, err := generateRandomState()
|
||||
if err != nil {
|
||||
ctx.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Save the state inside the session.
|
||||
session := sessions.Default(ctx)
|
||||
session.Set("state", state)
|
||||
session.Options(sessions.Options{Path: "/"})
|
||||
if err := session.Save(); err != nil {
|
||||
ctx.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
audience_url := "https://" + os.Getenv("AUTH0_DOMAIN") + "/api/v2/"
|
||||
auth_url := auth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("audience", audience_url))
|
||||
ctx.Redirect(http.StatusTemporaryRedirect, auth_url)
|
||||
}
|
||||
}
|
||||
|
||||
// Handler for our logout.
|
||||
func LogoutHandler(ctx *gin.Context) {
|
||||
logoutUrl, err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + "/v2/logout")
|
||||
if err != nil {
|
||||
ctx.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if ctx.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
returnTo, err := url.Parse(scheme + "://" + ctx.Request.Host + os.Getenv("LOGOUT_CALLBACK_ENDPOINT"))
|
||||
if err != nil {
|
||||
ctx.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
parameters := url.Values{}
|
||||
parameters.Add("returnTo", returnTo.String())
|
||||
parameters.Add("client_id", os.Getenv("AUTH0_CLIENT_ID"))
|
||||
logoutUrl.RawQuery = parameters.Encode()
|
||||
|
||||
ctx.Redirect(http.StatusTemporaryRedirect, logoutUrl.String())
|
||||
}
|
||||
|
||||
func LogoutCallbackHandler(store cookie.Store) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
session := sessions.Default(ctx)
|
||||
session.Clear()
|
||||
session.Options(sessions.Options{MaxAge: -1, Path: "/"})
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
ctx.String(http.StatusInternalServerError, "Failed to clear session")
|
||||
}
|
||||
|
||||
ctx.Redirect(http.StatusSeeOther, os.Getenv("PUBLIC_LOCATION"))
|
||||
}
|
||||
}
|
||||
|
||||
func AuthenticationCallbackHandler(auth *Authenticator) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
session := sessions.Default(ctx)
|
||||
if ctx.Query("state") != session.Get("state") {
|
||||
ctx.String(http.StatusBadRequest, "Invalid state parameter.")
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange an authorization code for a token.
|
||||
token, err := auth.Exchange(ctx.Request.Context(), ctx.Query("code"))
|
||||
if err != nil {
|
||||
ctx.String(http.StatusUnauthorized, "Failed to convert an authorization code into a token.")
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := auth.VerifyIDToken(ctx.Request.Context(), token)
|
||||
if err != nil {
|
||||
ctx.String(http.StatusInternalServerError, "Failed to verify ID Token.")
|
||||
return
|
||||
}
|
||||
|
||||
var profile map[string]interface{}
|
||||
if err := idToken.Claims(&profile); err != nil {
|
||||
ctx.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println(token.ExpiresIn)
|
||||
session.Set("access_token", token.AccessToken)
|
||||
session.Set("refresh_token", token.RefreshToken)
|
||||
session.Set("profile", profile)
|
||||
if err := session.Save(); err != nil {
|
||||
ctx.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to logged in page.
|
||||
ctx.Redirect(http.StatusTemporaryRedirect, "/user")
|
||||
}
|
||||
}
|
||||
|
||||
func generateRandomState() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
state := base64.StdEncoding.EncodeToString(b)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
165
api/auth/middleware/authentication_middleware.go
Normal file
165
api/auth/middleware/authentication_middleware.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.conway.engineer/ada/ordr.git/auth"
|
||||
"git.conway.engineer/ada/ordr.git/dto"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/google/go-querystring/query"
|
||||
)
|
||||
|
||||
func TokenIsNotExpired(access_token string) bool {
|
||||
parser := new(jwt.Parser)
|
||||
|
||||
token, err := parser.Parse(access_token, nil)
|
||||
if err != nil && err.Error() != "no Keyfunc was provided." {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
claims, claims_ok := token.Claims.(jwt.MapClaims)
|
||||
if !claims_ok {
|
||||
log.Fatal("failed to read claims")
|
||||
}
|
||||
|
||||
exp, exp_ok := claims["exp"].(float64)
|
||||
if !exp_ok {
|
||||
log.Fatal("exp claim not present")
|
||||
}
|
||||
|
||||
return time.Now().Before(time.Unix(int64(exp), 0))
|
||||
}
|
||||
|
||||
func GetUserProfile(context *gin.Context) {
|
||||
|
||||
session := sessions.Default(context)
|
||||
access_token := session.Get("access_token")
|
||||
user_profile_client := http.Client{}
|
||||
|
||||
user_profile_url, err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + os.Getenv("AUTH0_USER_INFO_ENDPOINT"))
|
||||
|
||||
if err != nil {
|
||||
log.Println("Failed to build user profile url " + err.Error())
|
||||
context.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
user_profile_request, user_profile_request_error := http.NewRequest("GET", user_profile_url.String(), nil)
|
||||
if user_profile_request_error != nil {
|
||||
log.Println("Failed to initialize validation request", user_profile_request_error.Error())
|
||||
context.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
user_profile_request.Header.Add("Accept", "application/json")
|
||||
user_profile_request.Header.Add("Authorization", "Bearer "+access_token.(string))
|
||||
|
||||
user_profile_response, user_profile_response_err := user_profile_client.Do(user_profile_request)
|
||||
if user_profile_response_err != nil {
|
||||
log.Println("Failed to validate user")
|
||||
context.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer user_profile_response.Body.Close()
|
||||
user_profile_bytes, _ := io.ReadAll(user_profile_response.Body)
|
||||
|
||||
var user_profile dto.UserProfileResponse
|
||||
json.Unmarshal(user_profile_bytes, &user_profile)
|
||||
context.Set("user_profile", user_profile)
|
||||
context.Next()
|
||||
}
|
||||
|
||||
func HandleRefreshToken(session sessions.Session) bool {
|
||||
token_url, token_url_err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + os.Getenv("AUTH0_TOKEN_ENDPOINT"))
|
||||
if token_url_err != nil {
|
||||
log.Fatal("Failed to parse token url", token_url_err.Error())
|
||||
}
|
||||
|
||||
refresh_token := session.Get("refresh_token")
|
||||
refresh_request_dto := dto.RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
ClientId: os.Getenv("AUTH0_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("AUTH0_CLIENT_SECRET"),
|
||||
RefreshToken: refresh_token.(string),
|
||||
}
|
||||
|
||||
refresh_body, marshal_err := query.Values(refresh_request_dto)
|
||||
if marshal_err != nil {
|
||||
log.Fatal("failed to marshal object: ", marshal_err)
|
||||
}
|
||||
|
||||
refresh_client := http.Client{}
|
||||
refresh_request, refresh_request_err := http.NewRequest("POST", token_url.String(), bytes.NewReader([]byte(refresh_body.Encode())))
|
||||
if refresh_request_err != nil {
|
||||
log.Fatal("Failed to create a request: ", refresh_request_err.Error())
|
||||
}
|
||||
refresh_request.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
refresh_response, refresh_response_err := refresh_client.Do(refresh_request)
|
||||
if refresh_response_err != nil {
|
||||
log.Fatal("Failed to fetch refresh token and response: ", refresh_response_err.Error())
|
||||
}
|
||||
defer refresh_response.Body.Close()
|
||||
|
||||
if refresh_response.StatusCode == http.StatusOK {
|
||||
var response_body dto.RefreshTokenResponse
|
||||
json_decoder := json.NewDecoder(refresh_response.Body)
|
||||
content_unmarshal_err := json_decoder.Decode(&response_body)
|
||||
|
||||
if content_unmarshal_err != nil {
|
||||
log.Println("Faileed to unmarshal data")
|
||||
return false
|
||||
}
|
||||
|
||||
session.Set("access_token", response_body.AccessToken)
|
||||
session.Set("refresh_token", response_body.RefreshToken)
|
||||
session.Options(sessions.Options{Path: "/"})
|
||||
session_save_error := session.Save()
|
||||
|
||||
if session_save_error != nil {
|
||||
log.Println("Failed to set session: " + session_save_error.Error())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func IsAuthenticated(auth *auth.Authenticator) gin.HandlerFunc {
|
||||
return func(context *gin.Context) {
|
||||
session := sessions.Default(context)
|
||||
|
||||
if session.Get("profile") == nil {
|
||||
context.Redirect(http.StatusSeeOther, "/auth/login")
|
||||
context.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
access_token := session.Get("access_token")
|
||||
|
||||
if access_token == nil {
|
||||
context.Redirect(http.StatusSeeOther, "/auth/login")
|
||||
return
|
||||
}
|
||||
|
||||
if TokenIsNotExpired(access_token.(string)) {
|
||||
context.Next()
|
||||
} else {
|
||||
if !HandleRefreshToken(session) {
|
||||
context.String(http.StatusUnauthorized, "Failed to refresh access token")
|
||||
} else {
|
||||
context.Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user