165 lines
4.5 KiB
Go
165 lines
4.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"time"
|
|
|
|
"ordr-api/auth"
|
|
"ordr-api/dto"
|
|
|
|
"github.com/gin-contrib/sessions"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt"
|
|
|
|
"github.com/google/go-querystring/query"
|
|
)
|
|
|
|
func TokenIsNotExpired(access_token string) bool {
|
|
parser := new(jwt.Parser)
|
|
|
|
token, err := parser.Parse(access_token, nil)
|
|
if err != nil && err.Error() != "no Keyfunc was provided." {
|
|
log.Fatal(err.Error())
|
|
}
|
|
|
|
claims, claims_ok := token.Claims.(jwt.MapClaims)
|
|
if !claims_ok {
|
|
log.Fatal("failed to read claims")
|
|
}
|
|
|
|
exp, exp_ok := claims["exp"].(float64)
|
|
if !exp_ok {
|
|
log.Fatal("exp claim not present")
|
|
}
|
|
|
|
return time.Now().Before(time.Unix(int64(exp), 0))
|
|
}
|
|
|
|
func GetUserProfile(context *gin.Context) {
|
|
|
|
session := sessions.Default(context)
|
|
access_token := session.Get("access_token")
|
|
user_profile_client := http.Client{}
|
|
|
|
user_profile_url, err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + os.Getenv("AUTH0_USER_INFO_ENDPOINT"))
|
|
|
|
if err != nil {
|
|
log.Println("Failed to build user profile url " + err.Error())
|
|
context.Abort()
|
|
return
|
|
}
|
|
|
|
user_profile_request, user_profile_request_error := http.NewRequest("GET", user_profile_url.String(), nil)
|
|
if user_profile_request_error != nil {
|
|
log.Println("Failed to initialize validation request", user_profile_request_error.Error())
|
|
context.AbortWithStatus(http.StatusInternalServerError)
|
|
}
|
|
|
|
user_profile_request.Header.Add("Accept", "application/json")
|
|
user_profile_request.Header.Add("Authorization", "Bearer "+access_token.(string))
|
|
|
|
user_profile_response, user_profile_response_err := user_profile_client.Do(user_profile_request)
|
|
if user_profile_response_err != nil {
|
|
log.Println("Failed to validate user")
|
|
context.AbortWithStatus(http.StatusInternalServerError)
|
|
}
|
|
|
|
defer user_profile_response.Body.Close()
|
|
user_profile_bytes, _ := io.ReadAll(user_profile_response.Body)
|
|
|
|
var user_profile dto.UserProfileResponse
|
|
json.Unmarshal(user_profile_bytes, &user_profile)
|
|
context.Set("user_profile", user_profile)
|
|
context.Next()
|
|
}
|
|
|
|
func HandleRefreshToken(session sessions.Session) bool {
|
|
token_url, token_url_err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + os.Getenv("AUTH0_TOKEN_ENDPOINT"))
|
|
if token_url_err != nil {
|
|
log.Fatal("Failed to parse token url", token_url_err.Error())
|
|
}
|
|
|
|
refresh_token := session.Get("refresh_token")
|
|
if refresh_token == nil {
|
|
return false
|
|
}
|
|
refresh_request_dto := dto.RefreshTokenRequest{
|
|
GrantType: "refresh_token",
|
|
ClientId: os.Getenv("AUTH0_CLIENT_ID"),
|
|
ClientSecret: os.Getenv("AUTH0_CLIENT_SECRET"),
|
|
RefreshToken: refresh_token.(string),
|
|
}
|
|
|
|
refresh_body, marshal_err := query.Values(refresh_request_dto)
|
|
if marshal_err != nil {
|
|
log.Fatal("failed to marshal object: ", marshal_err)
|
|
}
|
|
|
|
refresh_client := http.Client{}
|
|
refresh_request, refresh_request_err := http.NewRequest("POST", token_url.String(), bytes.NewReader([]byte(refresh_body.Encode())))
|
|
if refresh_request_err != nil {
|
|
log.Fatal("Failed to create a request: ", refresh_request_err.Error())
|
|
}
|
|
refresh_request.Header.Add("content-type", "application/x-www-form-urlencoded")
|
|
refresh_response, refresh_response_err := refresh_client.Do(refresh_request)
|
|
if refresh_response_err != nil {
|
|
log.Fatal("Failed to fetch refresh token and response: ", refresh_response_err.Error())
|
|
}
|
|
defer refresh_response.Body.Close()
|
|
|
|
if refresh_response.StatusCode == http.StatusOK {
|
|
var response_body dto.RefreshTokenResponse
|
|
json_decoder := json.NewDecoder(refresh_response.Body)
|
|
content_unmarshal_err := json_decoder.Decode(&response_body)
|
|
|
|
if content_unmarshal_err != nil {
|
|
log.Println("Faileed to unmarshal data")
|
|
return false
|
|
}
|
|
|
|
session.Set("access_token", response_body.AccessToken)
|
|
session.Set("refresh_token", response_body.RefreshToken)
|
|
session.Options(sessions.Options{Path: "/"})
|
|
session_save_error := session.Save()
|
|
|
|
if session_save_error != nil {
|
|
log.Println("Failed to set session: " + session_save_error.Error())
|
|
return false
|
|
}
|
|
return true
|
|
} else {
|
|
return false
|
|
}
|
|
}
|
|
|
|
func IsAuthenticated(auth *auth.Authenticator) gin.HandlerFunc {
|
|
return func(context *gin.Context) {
|
|
session := sessions.Default(context)
|
|
|
|
refresh_token := session.Get("refresh_token")
|
|
|
|
access_token := session.Get("access_token")
|
|
|
|
log.Printf("%s", refresh_token)
|
|
|
|
if access_token != nil {
|
|
if TokenIsNotExpired(access_token.(string)) {
|
|
context.Next()
|
|
return
|
|
}
|
|
}
|
|
|
|
if !HandleRefreshToken(session) {
|
|
context.AbortWithStatus(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
context.Next()
|
|
}
|
|
}
|