Skip to content

Commit

Permalink
Add a helper library for client APIs to verify new claims (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcspragu authored Jul 10, 2024
1 parent db78ddb commit 2da2fe1
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
13 changes: 13 additions & 0 deletions siteverify/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "siteverify",
srcs = ["siteverify.go"],
importpath = "github.com/RMI/credential-service/siteverify",
visibility = ["//visibility:public"],
deps = [
"//allowlist",
"@com_github_go_chi_jwtauth_v5//:jwtauth",
"@org_uber_go_zap//:zap",
],
)
66 changes: 66 additions & 0 deletions siteverify/siteverify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Package siteverify provides utilities for verifying a token contains the
// expected site.
package siteverify

import (
"fmt"
"net/http"
"strings"

"github.com/RMI/credential-service/allowlist"
"github.com/go-chi/jwtauth/v5"
"go.uber.org/zap"
)

func CheckSite(site allowlist.Site, logger zap.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}

if claims == nil {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

siteClaimI, ok := claims["sites"]
if !ok {
logger.Info("JWT claims had no 'sites' claim")
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

siteClaim, ok := siteClaimI.(string)
if !ok {
logger.Info("JWT 'sites' claim had unexpected type", zap.String("type", fmt.Sprintf("%T", siteClaimI)))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

if !isClaimValidForSite(siteClaim, site) {
logger.Info("JWT 'sites' claim was invalid", zap.String("claim", siteClaim))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

// Token has the correct site, pass it through
next.ServeHTTP(w, r)
})
}
}

func isClaimValidForSite(siteClaim string, target allowlist.Site) bool {
if siteClaim == "all" {
return true
}

for _, s := range strings.Split(siteClaim, ",") {
if s == string(target) {
return true
}
}
return false
}

0 comments on commit 2da2fe1

Please sign in to comment.