Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a more granular, site-based allowlist #14

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@
/test_ed25519.pub
/test_server.key
/test_server.pub

/cmd/server/configs/allowlists/*.json
!/cmd/server/configs/allowlists/example.json
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ The service currently has two main credential-exchanging endpoints:
Things to note:

- Only Azure AD B2C is supported as a source of exchangable user ID tokens at the moment, see [the server `main.go`](/cmd/server/main.go) and the [`azjwt` package](/azure/azjwt/azjwt.go) for more details.
- The allowlists themselves (i.e. in `cmd/server/configs/allowlists/{local,dev}.json`, are **not** included in this repo. If you're deploying the actual RMI service, get these from one of the developers.

## Running the Credential Service

Before running the service locally, you'll need an allowlist at `cmd/server/configs/allowlists/local.json`. You can create one based on the `example.json` in the same directory.

Run the server against an Azure AD B2C instance:

```bash
Expand Down
1 change: 1 addition & 0 deletions allowlist/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ go_test(
name = "allowlist_test",
srcs = ["allowlist_test.go"],
embed = [":allowlist"],
deps = ["@com_github_google_go_cmp//cmp"],
)
129 changes: 120 additions & 9 deletions allowlist/allowlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,143 @@
package allowlist

import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
)

type config struct {
Format string `json:"format"`
Allowlist []*AllowlistEntry `json:"allowlist"`
}

// AllowlistEntry maps some entity (a domain or email) to a list of authorized sites.
type AllowlistEntry struct {
// Only one of Domain or Email may be set
Domain string `json:"domain"`
Email string `json:"email"`

// If empty, all sites are allowed.
// This isn't a "fail closed" default, but I think that's fine at this stage.
Sites []string `json:"sites"`
}

type Site string

const (
SiteOPGEE = Site("OPGEE")
SitePACTA = Site("PACTA")
)

type Entity struct {
// If true, AllowedSites is ignored
AllowAllSites bool
AllowedSites []Site
}

type Checker struct {
allowedDomains map[string]bool
allowedDomains map[string]*Entity
allowedEmails map[string]*Entity
}

func NewChecker(allowedDomains []string) *Checker {
m := make(map[string]bool)
for _, ad := range allowedDomains {
m[ad] = true
func NewCheckerFromConfigFile(fn string) (*Checker, error) {
f, err := os.Open(fn)
if err != nil {
return nil, fmt.Errorf("failed to open allowlist config file: %w", err)
}
defer f.Close()

var cfg config
if err := json.NewDecoder(f).Decode(&cfg); err != nil {
return nil, fmt.Errorf("failed to decode allowlist config: %w", err)
}

return newChecker(&cfg)
}

func newChecker(cfg *config) (*Checker, error) {
switch cfg.Format {
case "v1":
// Valid, continue
case "":
return nil, errors.New("config file had no 'format' field, which is required")
default:
return nil, fmt.Errorf("unknown format %q", cfg.Format)
}

allowedDomains := make(map[string]*Entity)
allowedEmails := make(map[string]*Entity)
for i, ae := range cfg.Allowlist {
if ae.Domain != "" && ae.Email != "" {
return nil, fmt.Errorf("allowlist entry specified both a domain (%q) and an email (%q), which isn't allowed", ae.Domain, ae.Email)
}
if ae.Domain == "" && ae.Email == "" {
return nil, fmt.Errorf("allowlist entry at index %d did not specify a domain or email", i)
}
entity, err := parseEntity(ae.Sites)
if err != nil {
return nil, fmt.Errorf("failed to parse sites for entry at index %d: %w", i, err)
}
if ae.Domain != "" {
allowedDomains[strings.ToLower(ae.Domain)] = entity
}
if ae.Email != "" {
allowedEmails[strings.ToLower(ae.Email)] = entity
}
}
return &Checker{
allowedDomains: m,
allowedDomains: allowedDomains,
allowedEmails: allowedEmails,
}, nil
}

func parseEntity(inp []string) (*Entity, error) {
if len(inp) == 0 {
return &Entity{AllowAllSites: true}, nil
}
var sites []Site
for _, s := range inp {
st, err := parseSite(s)
if err != nil {
return nil, fmt.Errorf("failed to parse entity %q: %w", s, err)
}
sites = append(sites, st)
}
return &Entity{AllowedSites: sites}, nil
}

func parseSite(inp string) (Site, error) {
switch inp {
case "OPGEE":
return SiteOPGEE, nil
case "PACTA":
return SitePACTA, nil
default:
return "", errors.New("unknown site")
}
}

// Check returns if the email is of an allowlisted domain, and errors if the
// email is incorrectly formatted. Subdomains are not handled specially, only
// exact matches are allowed.
func (c *Checker) Check(email string) (bool, error) {
func (c *Checker) Check(email string) (*Entity, error) {
email = strings.ToLower(email)

// First, check the email
if tmp, ok := c.allowedEmails[email]; ok {
return tmp, nil
}

_, domain, ok := strings.Cut(email, "@")
if !ok {
return false, fmt.Errorf("email %q was missing '@'", email)
return nil, fmt.Errorf("email %q was missing '@'", email)
}

if tmp, ok := c.allowedDomains[domain]; ok {
return tmp, nil
}

return c.allowedDomains[domain], nil
return nil, nil
}
81 changes: 65 additions & 16 deletions allowlist/allowlist_test.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,82 @@
package allowlist

import "testing"
import (
"testing"

func TestCheck(t *testing.T) {
allowedDomain := "example.com"
"github.com/google/go-cmp/cmp"
)

c := NewChecker([]string{allowedDomain})
var exampleConfig = &config{
Format: "v1",
Allowlist: []*AllowlistEntry{
&AllowlistEntry{Domain: "example.com"}, // Can access any site
&AllowlistEntry{Domain: "only-opgee.com", Sites: []string{"OPGEE"}}, // Can only access OPGEE
&AllowlistEntry{Email: "test@only-pacta.com", Sites: []string{"PACTA"}}, // Only test@ can access PACTA
},
}

allowed, err := c.Check("allowed@example.com")
func TestCheck(t *testing.T) {
c, err := newChecker(exampleConfig)
if err != nil {
t.Fatalf("Check: %v", err)
t.Fatalf("failed to init checker: %v", err)
}

tests := []struct {
desc string
email string
want *Entity
}{
{
desc: "allowed on any site",
email: "allowed@example.com",
want: &Entity{AllowAllSites: true},
},
{
desc: "domain not in the allowlist",
email: "denied@example.net",
want: nil,
},
{
desc: "domain allowlisted for OPGEE",
email: "any-email@only-opgee.com",
want: &Entity{AllowedSites: []Site{SiteOPGEE}},
},
{
desc: "email allowlisted for PACTA",
email: "test@only-pacta.com",
want: &Entity{AllowedSites: []Site{SitePACTA}},
},
{
desc: "different email allowlisted for PACTA",
email: "not-allowed@only-pacta.com",
want: nil,
},
}
if !allowed {
t.Error("Check said email was not allowed, expected allowed")

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
got, err := c.Check(test.email)
if err != nil {
t.Fatalf("Check: %v", err)
}
if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("unexpected Check() results (-want +got)\n%s", diff)
}
})
}
}

allowed, err = c.Check("denied@example.net")
func TestCheck_Error(t *testing.T) {
c, err := newChecker(exampleConfig)
if err != nil {
t.Fatalf("Check: %v", err)
}
if allowed {
t.Error("Check said email was allowed, expected not allowed")
t.Fatalf("failed to init checker: %v", err)
}

allowed, err = c.Check("malformed.biz")
entity, err := c.Check("malformed.biz")
if err == nil {
t.Fatal("Check returned no error for invalid email address")
}
if allowed {
t.Error("Check said invalid email was allowed")
if entity != nil {
t.Errorf("Check said invalid email was allowed: %+v", entity)
}
}
2 changes: 1 addition & 1 deletion azure/azjwt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//allowlist",
"//emailctx",
"//tokenctx",
"@com_github_go_chi_jwtauth_v5//:jwtauth",
"@com_github_lestrrat_go_jwx_v2//jwk",
"@com_github_lestrrat_go_jwx_v2//jwt",
Expand Down
49 changes: 32 additions & 17 deletions azure/azjwt/azjwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"net/http"

"github.com/RMI/credential-service/allowlist"
"github.com/RMI/credential-service/emailctx"
"github.com/RMI/credential-service/tokenctx"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
Expand Down Expand Up @@ -138,7 +138,7 @@ func (a *Auth) Authenticator(next http.Handler) http.Handler {
}

// Now, check against the allowlist
allowedEmails, err := a.checkEmailAllowed(token)
allowedEmails, entity, err := a.checkEmailAllowed(token)
if err != nil {
a.logger.Warn("token failed allowlist check", zap.Error(err))
if errors.Is(err, errNotAllowlisted) {
Expand All @@ -150,7 +150,8 @@ func (a *Auth) Authenticator(next http.Handler) http.Handler {
}

// Add the email to the context so that it can be used by the handler
ctx := emailctx.AddEmailsToContext(r.Context(), allowedEmails)
ctx := tokenctx.AddEmailsToContext(r.Context(), allowedEmails)
ctx = tokenctx.AddAllowlistEntityToContext(ctx, entity)
// Token is authenticated, pass it through
next.ServeHTTP(w, r.WithContext(ctx))
}
Expand Down Expand Up @@ -212,47 +213,61 @@ func (a *Auth) parseAndVerify(ctx context.Context, r *http.Request) (jwt.Token,

var errNotAllowlisted = errors.New("email isn't allowlisted")

func (a *Auth) checkEmailAllowed(tkn jwt.Token) ([]string, error) {
func (a *Auth) checkEmailAllowed(tkn jwt.Token) ([]string, *allowlist.Entity, error) {
// See https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference
emailsVal, ok := tkn.Get("emails")
if !ok {
return nil, errors.New("token didn't contain an 'emails' claim")
return nil, nil, errors.New("token didn't contain an 'emails' claim")
}
emailsI, ok := emailsVal.([]any)
if !ok {
return nil, fmt.Errorf("'emails' claim in token had unexpected type %T", emailsVal)
return nil, nil, fmt.Errorf("'emails' claim in token had unexpected type %T", emailsVal)
}

var emails []string
for i, ei := range emailsI {
email, ok := ei.(string)
if !ok {
return nil, fmt.Errorf("email %d from 'emails' claim in token had unexpected type %T", i, ei)
return nil, nil, fmt.Errorf("email %d from 'emails' claim in token had unexpected type %T", i, ei)
}
emails = append(emails, email)
}

// If one of their emails is allowed, consider them allowed.
allowed := a.allowedEmails(emails)
allowed, entity := a.allowedEmails(emails)
if len(allowed) == 0 {
return nil, errNotAllowlisted
return nil, nil, errNotAllowlisted
}

return allowed, nil
return allowed, entity, nil
}

func (a *Auth) allowedEmails(emails []string) []string {
var result []string
func (a *Auth) allowedEmails(emails []string) ([]string, *allowlist.Entity) {
var (
outEmails []string
allowAllSites bool
sites []allowlist.Site
)
for _, email := range emails {
allowed, err := a.allowlist.Check(email)
entity, err := a.allowlist.Check(email)
if err != nil {
a.logger.Warn("failed to check allowlist", zap.String("email", email), zap.Error(err))
continue
}
// We don't return early on success, just to parse and validate all emails in the token.
if allowed {
result = append(result, email)
if entity == nil {
continue
}
if entity.AllowAllSites {
allowAllSites = true
}
sites = append(sites, entity.AllowedSites...)
outEmails = append(outEmails, email)
}
var entity *allowlist.Entity
if allowAllSites {
entity = &allowlist.Entity{AllowAllSites: true}
} else {
entity = &allowlist.Entity{AllowedSites: sites}
}
return result
return outEmails, entity
}
Loading
Loading