diff --git a/.gitignore b/.gitignore index 676ea6d..183998c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ /test_ed25519.pub /test_server.key /test_server.pub + +/cmd/server/configs/allowlists/*.json +!/cmd/server/configs/allowlists/example.json diff --git a/README.md b/README.md index 836ad02..5c97ff7 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,12 @@ The service currently has two main credential-exchanging endpoints: Things to note: - Only Azure AD B2C is supported as a source of exchangable user ID tokens at the moment, see [the server `main.go`](/cmd/server/main.go) and the [`azjwt` package](/azure/azjwt/azjwt.go) for more details. +- The allowlists themselves (i.e. in `cmd/server/configs/allowlists/{local,dev}.json`, are **not** included in this repo. If you're deploying the actual RMI service, get these from one of the developers. ## Running the Credential Service +Before running the service locally, you'll need an allowlist at `cmd/server/configs/allowlists/local.json`. You can create one based on the `example.json` in the same directory. + Run the server against an Azure AD B2C instance: ```bash diff --git a/allowlist/BUILD.bazel b/allowlist/BUILD.bazel index 9ada01c..380dc5b 100644 --- a/allowlist/BUILD.bazel +++ b/allowlist/BUILD.bazel @@ -11,4 +11,5 @@ go_test( name = "allowlist_test", srcs = ["allowlist_test.go"], embed = [":allowlist"], + deps = ["@com_github_google_go_cmp//cmp"], ) diff --git a/allowlist/allowlist.go b/allowlist/allowlist.go index aa35ce0..6bbcc9a 100644 --- a/allowlist/allowlist.go +++ b/allowlist/allowlist.go @@ -2,32 +2,143 @@ package allowlist import ( + "encoding/json" + "errors" "fmt" + "os" "strings" ) +type config struct { + Format string `json:"format"` + Allowlist []*AllowlistEntry `json:"allowlist"` +} + +// AllowlistEntry maps some entity (a domain or email) to a list of authorized sites. +type AllowlistEntry struct { + // Only one of Domain or Email may be set + Domain string `json:"domain"` + Email string `json:"email"` + + // If empty, all sites are allowed. + // This isn't a "fail closed" default, but I think that's fine at this stage. + Sites []string `json:"sites"` +} + +type Site string + +const ( + SiteOPGEE = Site("OPGEE") + SitePACTA = Site("PACTA") +) + +type Entity struct { + // If true, AllowedSites is ignored + AllowAllSites bool + AllowedSites []Site +} + type Checker struct { - allowedDomains map[string]bool + allowedDomains map[string]*Entity + allowedEmails map[string]*Entity } -func NewChecker(allowedDomains []string) *Checker { - m := make(map[string]bool) - for _, ad := range allowedDomains { - m[ad] = true +func NewCheckerFromConfigFile(fn string) (*Checker, error) { + f, err := os.Open(fn) + if err != nil { + return nil, fmt.Errorf("failed to open allowlist config file: %w", err) + } + defer f.Close() + + var cfg config + if err := json.NewDecoder(f).Decode(&cfg); err != nil { + return nil, fmt.Errorf("failed to decode allowlist config: %w", err) + } + + return newChecker(&cfg) +} + +func newChecker(cfg *config) (*Checker, error) { + switch cfg.Format { + case "v1": + // Valid, continue + case "": + return nil, errors.New("config file had no 'format' field, which is required") + default: + return nil, fmt.Errorf("unknown format %q", cfg.Format) + } + + allowedDomains := make(map[string]*Entity) + allowedEmails := make(map[string]*Entity) + for i, ae := range cfg.Allowlist { + if ae.Domain != "" && ae.Email != "" { + return nil, fmt.Errorf("allowlist entry specified both a domain (%q) and an email (%q), which isn't allowed", ae.Domain, ae.Email) + } + if ae.Domain == "" && ae.Email == "" { + return nil, fmt.Errorf("allowlist entry at index %d did not specify a domain or email", i) + } + entity, err := parseEntity(ae.Sites) + if err != nil { + return nil, fmt.Errorf("failed to parse sites for entry at index %d: %w", i, err) + } + if ae.Domain != "" { + allowedDomains[strings.ToLower(ae.Domain)] = entity + } + if ae.Email != "" { + allowedEmails[strings.ToLower(ae.Email)] = entity + } } return &Checker{ - allowedDomains: m, + allowedDomains: allowedDomains, + allowedEmails: allowedEmails, + }, nil +} + +func parseEntity(inp []string) (*Entity, error) { + if len(inp) == 0 { + return &Entity{AllowAllSites: true}, nil + } + var sites []Site + for _, s := range inp { + st, err := parseSite(s) + if err != nil { + return nil, fmt.Errorf("failed to parse entity %q: %w", s, err) + } + sites = append(sites, st) + } + return &Entity{AllowedSites: sites}, nil +} + +func parseSite(inp string) (Site, error) { + switch inp { + case "OPGEE": + return SiteOPGEE, nil + case "PACTA": + return SitePACTA, nil + default: + return "", errors.New("unknown site") } } // Check returns if the email is of an allowlisted domain, and errors if the // email is incorrectly formatted. Subdomains are not handled specially, only // exact matches are allowed. -func (c *Checker) Check(email string) (bool, error) { +func (c *Checker) Check(email string) (*Entity, error) { + email = strings.ToLower(email) + + // First, check the email + if tmp, ok := c.allowedEmails[email]; ok { + return tmp, nil + } + _, domain, ok := strings.Cut(email, "@") if !ok { - return false, fmt.Errorf("email %q was missing '@'", email) + return nil, fmt.Errorf("email %q was missing '@'", email) + } + + if tmp, ok := c.allowedDomains[domain]; ok { + return tmp, nil } - return c.allowedDomains[domain], nil + return nil, nil } diff --git a/allowlist/allowlist_test.go b/allowlist/allowlist_test.go index 9c8dd85..2c13163 100644 --- a/allowlist/allowlist_test.go +++ b/allowlist/allowlist_test.go @@ -1,33 +1,82 @@ package allowlist -import "testing" +import ( + "testing" -func TestCheck(t *testing.T) { - allowedDomain := "example.com" + "github.com/google/go-cmp/cmp" +) - c := NewChecker([]string{allowedDomain}) +var exampleConfig = &config{ + Format: "v1", + Allowlist: []*AllowlistEntry{ + &AllowlistEntry{Domain: "example.com"}, // Can access any site + &AllowlistEntry{Domain: "only-opgee.com", Sites: []string{"OPGEE"}}, // Can only access OPGEE + &AllowlistEntry{Email: "test@only-pacta.com", Sites: []string{"PACTA"}}, // Only test@ can access PACTA + }, +} - allowed, err := c.Check("allowed@example.com") +func TestCheck(t *testing.T) { + c, err := newChecker(exampleConfig) if err != nil { - t.Fatalf("Check: %v", err) + t.Fatalf("failed to init checker: %v", err) + } + + tests := []struct { + desc string + email string + want *Entity + }{ + { + desc: "allowed on any site", + email: "allowed@example.com", + want: &Entity{AllowAllSites: true}, + }, + { + desc: "domain not in the allowlist", + email: "denied@example.net", + want: nil, + }, + { + desc: "domain allowlisted for OPGEE", + email: "any-email@only-opgee.com", + want: &Entity{AllowedSites: []Site{SiteOPGEE}}, + }, + { + desc: "email allowlisted for PACTA", + email: "test@only-pacta.com", + want: &Entity{AllowedSites: []Site{SitePACTA}}, + }, + { + desc: "different email allowlisted for PACTA", + email: "not-allowed@only-pacta.com", + want: nil, + }, } - if !allowed { - t.Error("Check said email was not allowed, expected allowed") + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + got, err := c.Check(test.email) + if err != nil { + t.Fatalf("Check: %v", err) + } + if diff := cmp.Diff(test.want, got); diff != "" { + t.Errorf("unexpected Check() results (-want +got)\n%s", diff) + } + }) } +} - allowed, err = c.Check("denied@example.net") +func TestCheck_Error(t *testing.T) { + c, err := newChecker(exampleConfig) if err != nil { - t.Fatalf("Check: %v", err) - } - if allowed { - t.Error("Check said email was allowed, expected not allowed") + t.Fatalf("failed to init checker: %v", err) } - allowed, err = c.Check("malformed.biz") + entity, err := c.Check("malformed.biz") if err == nil { t.Fatal("Check returned no error for invalid email address") } - if allowed { - t.Error("Check said invalid email was allowed") + if entity != nil { + t.Errorf("Check said invalid email was allowed: %+v", entity) } } diff --git a/azure/azjwt/BUILD.bazel b/azure/azjwt/BUILD.bazel index 6b0b21b..06fc4f1 100644 --- a/azure/azjwt/BUILD.bazel +++ b/azure/azjwt/BUILD.bazel @@ -7,7 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//allowlist", - "//emailctx", + "//tokenctx", "@com_github_go_chi_jwtauth_v5//:jwtauth", "@com_github_lestrrat_go_jwx_v2//jwk", "@com_github_lestrrat_go_jwx_v2//jwt", diff --git a/azure/azjwt/azjwt.go b/azure/azjwt/azjwt.go index 18b3d49..9715ed9 100644 --- a/azure/azjwt/azjwt.go +++ b/azure/azjwt/azjwt.go @@ -10,7 +10,7 @@ import ( "net/http" "github.com/RMI/credential-service/allowlist" - "github.com/RMI/credential-service/emailctx" + "github.com/RMI/credential-service/tokenctx" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" @@ -138,7 +138,7 @@ func (a *Auth) Authenticator(next http.Handler) http.Handler { } // Now, check against the allowlist - allowedEmails, err := a.checkEmailAllowed(token) + allowedEmails, entity, err := a.checkEmailAllowed(token) if err != nil { a.logger.Warn("token failed allowlist check", zap.Error(err)) if errors.Is(err, errNotAllowlisted) { @@ -150,7 +150,8 @@ func (a *Auth) Authenticator(next http.Handler) http.Handler { } // Add the email to the context so that it can be used by the handler - ctx := emailctx.AddEmailsToContext(r.Context(), allowedEmails) + ctx := tokenctx.AddEmailsToContext(r.Context(), allowedEmails) + ctx = tokenctx.AddAllowlistEntityToContext(ctx, entity) // Token is authenticated, pass it through next.ServeHTTP(w, r.WithContext(ctx)) } @@ -212,47 +213,61 @@ func (a *Auth) parseAndVerify(ctx context.Context, r *http.Request) (jwt.Token, var errNotAllowlisted = errors.New("email isn't allowlisted") -func (a *Auth) checkEmailAllowed(tkn jwt.Token) ([]string, error) { +func (a *Auth) checkEmailAllowed(tkn jwt.Token) ([]string, *allowlist.Entity, error) { // See https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference emailsVal, ok := tkn.Get("emails") if !ok { - return nil, errors.New("token didn't contain an 'emails' claim") + return nil, nil, errors.New("token didn't contain an 'emails' claim") } emailsI, ok := emailsVal.([]any) if !ok { - return nil, fmt.Errorf("'emails' claim in token had unexpected type %T", emailsVal) + return nil, nil, fmt.Errorf("'emails' claim in token had unexpected type %T", emailsVal) } var emails []string for i, ei := range emailsI { email, ok := ei.(string) if !ok { - return nil, fmt.Errorf("email %d from 'emails' claim in token had unexpected type %T", i, ei) + return nil, nil, fmt.Errorf("email %d from 'emails' claim in token had unexpected type %T", i, ei) } emails = append(emails, email) } // If one of their emails is allowed, consider them allowed. - allowed := a.allowedEmails(emails) + allowed, entity := a.allowedEmails(emails) if len(allowed) == 0 { - return nil, errNotAllowlisted + return nil, nil, errNotAllowlisted } - return allowed, nil + return allowed, entity, nil } -func (a *Auth) allowedEmails(emails []string) []string { - var result []string +func (a *Auth) allowedEmails(emails []string) ([]string, *allowlist.Entity) { + var ( + outEmails []string + allowAllSites bool + sites []allowlist.Site + ) for _, email := range emails { - allowed, err := a.allowlist.Check(email) + entity, err := a.allowlist.Check(email) if err != nil { a.logger.Warn("failed to check allowlist", zap.String("email", email), zap.Error(err)) continue } - // We don't return early on success, just to parse and validate all emails in the token. - if allowed { - result = append(result, email) + if entity == nil { + continue + } + if entity.AllowAllSites { + allowAllSites = true } + sites = append(sites, entity.AllowedSites...) + outEmails = append(outEmails, email) + } + var entity *allowlist.Entity + if allowAllSites { + entity = &allowlist.Entity{AllowAllSites: true} + } else { + entity = &allowlist.Entity{AllowedSites: sites} } - return result + return outEmails, entity } diff --git a/cmd/server/configs/allowlists/example.json b/cmd/server/configs/allowlists/example.json new file mode 100644 index 0000000..b157084 --- /dev/null +++ b/cmd/server/configs/allowlists/example.json @@ -0,0 +1,9 @@ +{ + "format": "v1", + "allowlist": [ + {"domain": "siliconally.org"}, + {"domain": "rmi.org"}, + {"domain": "opgee-only.not-a-domain", "sites": ["OPGEE"]}, + {"email": "only-this-email@pacta-only.not-a-domain", "sites": ["PACTA"]} + ] +} diff --git a/cmd/server/configs/dev.conf b/cmd/server/configs/dev.conf index 416d03f..4fd6e85 100644 --- a/cmd/server/configs/dev.conf +++ b/cmd/server/configs/dev.conf @@ -1,10 +1,10 @@ env dev allowed_cors_origins https://*.dev.rmi.siliconally.dev +allowlist_file cmd/server/configs/dev.json port 80 use_local_jwts false enable_credential_test_api true -allowed_domains siliconally.org,rmi.org,plevin.com cookie_domain dev.rmi.siliconally.dev diff --git a/cmd/server/configs/local.conf b/cmd/server/configs/local.conf index 4abb66e..556e90a 100644 --- a/cmd/server/configs/local.conf +++ b/cmd/server/configs/local.conf @@ -1,11 +1,10 @@ env local allowed_cors_origins http://localhost:3000 +allowlist_file cmd/server/configs/local.json use_local_jwts true enable_credential_test_api true -allowed_domains siliconally.org,rmi.org - secret_auth_private_key_id 2023-08-11 secret_auth_private_key_data -----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEINj77iKqwAKJHb0I0XVr8OhvQMpO6SVkmCGlNb9epwUO\n-----END PRIVATE KEY----- diff --git a/cmd/server/main.go b/cmd/server/main.go index 515ec94..b478b41 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -77,7 +77,7 @@ func run(args []string) error { cookieDomain = fs.String("cookie_domain", "", "Domain to return in the cookie response") - allowedDomains flagext.StringList + allowlistFile = fs.String("allowlist_file", "", "JSON-formatted file containing the allowlist") allowedCORSOrigins flagext.StringList minLogLevel zapcore.Level = zapcore.WarnLevel @@ -90,7 +90,6 @@ func run(args []string) error { azureADClientID = fs.String("secret_azure_ad_client_id", "", "The client ID the users are authenticating against") azureADTenantID = fs.String("secret_azure_ad_tenant_id", "", "The ID of the tenant user tokens should come from") ) - fs.Var(&allowedDomains, "allowed_domains", "A comma-separated list of domains that are allowed to get valid credentials") fs.Var(&allowedCORSOrigins, "allowed_cors_origins", "A comma-separated list of CORS origins to allow traffic from") fs.Var(&minLogLevel, "min_log_level", "If set, retains logs at the given level and above. Options: 'debug', 'info', 'warn', 'error', 'dpanic', 'panic', 'fatal' - default warn.") @@ -226,10 +225,14 @@ func run(args []string) error { zap.String("user_flow", sec.AzureAD.UserFlow), zap.String("client_id", sec.AzureAD.ClientID), ) + checker, err := allowlist.NewCheckerFromConfigFile(*allowlistFile) + if err != nil { + return fmt.Errorf("failed to init allowlist checker: %w", err) + } // Accept Microsoft-issued JWTs azJWTAuth, err := azjwt.NewAuth(ctx, &azjwt.Config{ Logger: logger, - Allowlist: allowlist.NewChecker(allowedDomains), + Allowlist: checker, Tenant: sec.AzureAD.TenantName, TenantID: sec.AzureAD.TenantID, Policy: sec.AzureAD.UserFlow, diff --git a/cmd/server/usersrv/BUILD.bazel b/cmd/server/usersrv/BUILD.bazel index 3dc4307..626fb94 100644 --- a/cmd/server/usersrv/BUILD.bazel +++ b/cmd/server/usersrv/BUILD.bazel @@ -6,8 +6,9 @@ go_library( importpath = "github.com/RMI/credential-service/cmd/server/usersrv", visibility = ["//visibility:public"], deps = [ - "//emailctx", + "//allowlist", "//openapi:user_generated", + "//tokenctx", "@com_github_go_chi_jwtauth_v5//:jwtauth", "@com_github_google_uuid//:uuid", "@com_github_lestrrat_go_jwx_v2//jwa", @@ -22,9 +23,10 @@ go_test( srcs = ["usrsrv_test.go"], embed = [":usersrv"], deps = [ - "//emailctx", + "//allowlist", "//keyutil", "//openapi:user_generated", + "//tokenctx", "@com_github_go_chi_jwtauth_v5//:jwtauth", "@com_github_google_go_cmp//cmp", "@com_github_google_uuid//:uuid", diff --git a/cmd/server/usersrv/usersrv.go b/cmd/server/usersrv/usersrv.go index 3440d93..265a1dd 100644 --- a/cmd/server/usersrv/usersrv.go +++ b/cmd/server/usersrv/usersrv.go @@ -5,13 +5,15 @@ package usersrv import ( + "bytes" "context" "errors" "fmt" "net/http" "time" - "github.com/RMI/credential-service/emailctx" + "github.com/RMI/credential-service/allowlist" + "github.com/RMI/credential-service/tokenctx" "github.com/RMI/credential-service/openapi/user" "github.com/go-chi/jwtauth/v5" "github.com/google/uuid" @@ -26,7 +28,7 @@ type TokenIssuer struct { Now func() time.Time } -func (t *TokenIssuer) IssueToken(userID string, emails []string, exp time.Time) (string, string, error) { +func (t *TokenIssuer) IssueToken(userID string, emails []string, ae *allowlist.Entity, exp time.Time) (string, string, error) { now := t.Now() id := uuid.NewString() builder := jwt.NewBuilder(). @@ -39,6 +41,13 @@ func (t *TokenIssuer) IssueToken(userID string, emails []string, exp time.Time) if len(emails) > 0 { builder = builder.Claim("emails", emails) } + if ae != nil { + if ae.AllowAllSites { + builder = builder.Claim("sites", "all") + } else { + builder = builder.Claim("sites", formatSites(ae.AllowedSites)) + } + } tkn, err := builder.Build() if err != nil { return "", "", fmt.Errorf("failed to build token: %w", err) @@ -50,6 +59,17 @@ func (t *TokenIssuer) IssueToken(userID string, emails []string, exp time.Time) return string(dat), id, nil } +func formatSites(sites []allowlist.Site) string { + var buf bytes.Buffer + for i, s := range sites { + buf.WriteString(string(s)) + if i < len(sites)-1 { + buf.WriteRune(',') + } + } + return buf.String() +} + type Server struct { Issuer *TokenIssuer Logger *zap.Logger @@ -101,12 +121,19 @@ func (s *Server) exchangeToken(ctx context.Context, opts ...exchangeOption) (str return "", "", time.Time{}, fmt.Errorf("failed to get auth service JWT to exchange for service-issued JWT: %w", err) } - var emails []string + emails, err := tokenctx.EmailsFromContext(ctx) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to get emails from context: %w", err) + } + + var emailsClaim []string if eOpts.includeEmails { - emails, err = emailctx.EmailsFromContext(ctx) - if err != nil { - return "", "", time.Time{}, fmt.Errorf("failed to get email from context: %w", err) - } + emailsClaim = emails + } + + ae, err := tokenctx.AllowlistEntityFromContext(ctx) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to get allowlist entity from context: %w", err) } var exp time.Time @@ -134,7 +161,7 @@ func (s *Server) exchangeToken(ctx context.Context, opts ...exchangeOption) (str return "", "", time.Time{}, fmt.Errorf("'sub' claim in source JWT was of type %T, expected a string", sub) } - tkn, id, err := s.Issuer.IssueToken(subStr, emails, exp) + tkn, id, err := s.Issuer.IssueToken(subStr, emailsClaim, ae, exp) if err != nil { return "", "", time.Time{}, fmt.Errorf("failed to sign token: %w", err) } diff --git a/cmd/server/usersrv/usrsrv_test.go b/cmd/server/usersrv/usrsrv_test.go index 01dcbbf..9fd89a4 100644 --- a/cmd/server/usersrv/usrsrv_test.go +++ b/cmd/server/usersrv/usrsrv_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" - "github.com/RMI/credential-service/emailctx" + "github.com/RMI/credential-service/allowlist" "github.com/RMI/credential-service/keyutil" "github.com/RMI/credential-service/openapi/user" + "github.com/RMI/credential-service/tokenctx" "github.com/go-chi/jwtauth/v5" "github.com/google/go-cmp/cmp" "github.com/google/uuid" @@ -36,7 +37,8 @@ func TestLogin(t *testing.T) { emails := []string{"test@allowed.example.com"} tkn.Set("emails", emails) ctx = jwtauth.NewContext(ctx, tkn, nil) - ctx = emailctx.AddEmailsToContext(ctx, emails) + ctx = tokenctx.AddEmailsToContext(ctx, emails) + ctx = tokenctx.AddAllowlistEntityToContext(ctx, &allowlist.Entity{AllowAllSites: true}) got, err := srv.Login(ctx, user.LoginRequestObject{}) if err != nil { @@ -45,7 +47,7 @@ func TestLogin(t *testing.T) { want := user.Login200Response{ Headers: user.Login200ResponseHeaders{ - SetCookie: "jwt=eyJhbGciOiJFZERTQSIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOlsicm1pLm9yZyJdLCJlbWFpbHMiOlsidGVzdEBhbGxvd2VkLmV4YW1wbGUuY29tIl0sImV4cCI6MTIzNTQzMTg4LCJpYXQiOjEyMzQ1Njc4OSwianRpIjoiMDE5NGZkYzItZmEyZi00Y2MwLTgxZDMtZmYxMjA0NWI3M2M4IiwibmJmIjoxMjM0NTY3MjksInN1YiI6InVzZXIxMjMifQ.aJFKyWQ2035ziql5GxjtN6kn4bqc2w-q4_C_EH4cKAkFuybh3zDGf8TS-kC_w0NUL-y3U5xgJ_xdEJWqLEz0Ag; Path=/; Expires=Fri, 30 Nov 1973 21:33:08 GMT; HttpOnly; Secure; SameSite=Lax", + SetCookie: "jwt=eyJhbGciOiJFZERTQSIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOlsicm1pLm9yZyJdLCJlbWFpbHMiOlsidGVzdEBhbGxvd2VkLmV4YW1wbGUuY29tIl0sImV4cCI6MTIzNTQzMTg4LCJpYXQiOjEyMzQ1Njc4OSwianRpIjoiMDE5NGZkYzItZmEyZi00Y2MwLTgxZDMtZmYxMjA0NWI3M2M4IiwibmJmIjoxMjM0NTY3MjksInNpdGVzIjoiYWxsIiwic3ViIjoidXNlcjEyMyJ9.9H3r8uV66-ANKPwAOBcxy2s7EuDNDzF3e4i6AwPRAvhciQV58AZQuGapqEAI3dmV1_wwKerWJ22D4uQDsGLkCQ; Path=/; Expires=Fri, 30 Nov 1973 21:33:08 GMT; HttpOnly; Secure; SameSite=Lax", }, } @@ -62,9 +64,12 @@ func TestCreateAPIKey(t *testing.T) { exp := time.Date(9999, time.January, 1, 0, 0, 0, 0, time.UTC) tkn := jwt.New() tkn.Set("sub", "user123") - tkn.Set("emails", []any{"test@allowed.example.com"}) tkn.Set("exp", exp) + emails := []string{"test@allowed.example.com"} + tkn.Set("emails", emails) ctx = jwtauth.NewContext(ctx, tkn, nil) + ctx = tokenctx.AddEmailsToContext(ctx, emails) + ctx = tokenctx.AddAllowlistEntityToContext(ctx, &allowlist.Entity{AllowAllSites: true}) got, err := srv.CreateAPIKey(ctx, user.CreateAPIKeyRequestObject{}) if err != nil { @@ -74,7 +79,7 @@ func TestCreateAPIKey(t *testing.T) { want := user.CreateAPIKey200JSONResponse{ Id: "6e4ff95f-f662-45ee-a82a-bdf44a2d0b75", ExpiresAt: &exp, - Key: "eyJhbGciOiJFZERTQSIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOlsicm1pLm9yZyJdLCJleHAiOjI1MzM3MDc2NDgwMCwiaWF0IjoxMjM0NTY3ODksImp0aSI6IjZlNGZmOTVmLWY2NjItNDVlZS1hODJhLWJkZjQ0YTJkMGI3NSIsIm5iZiI6MTIzNDU2NzI5LCJzdWIiOiJ1c2VyMTIzIn0.sf7SbHOWGvW3mHadEsz64penWakt6KtlAs6z6EyYKcQRIiHeqMmoN6nycYnFjQ1RxD22IytFUDi_45Udi6UQCg", + Key: "eyJhbGciOiJFZERTQSIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOlsicm1pLm9yZyJdLCJleHAiOjI1MzM3MDc2NDgwMCwiaWF0IjoxMjM0NTY3ODksImp0aSI6IjZlNGZmOTVmLWY2NjItNDVlZS1hODJhLWJkZjQ0YTJkMGI3NSIsIm5iZiI6MTIzNDU2NzI5LCJzaXRlcyI6ImFsbCIsInN1YiI6InVzZXIxMjMifQ.Y3cAam6VOQ_5L7CxGeNx1r0oaNZylL1CTVP-rwp1NKQKmcpjh76ysipH6vd0o14mcVbQMAZ6YgXXQgMPnTTzDA", } if diff := cmp.Diff(want, got); diff != "" { diff --git a/emailctx/BUILD.bazel b/emailctx/BUILD.bazel deleted file mode 100644 index e367988..0000000 --- a/emailctx/BUILD.bazel +++ /dev/null @@ -1,8 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -go_library( - name = "emailctx", - srcs = ["emailctx.go"], - importpath = "github.com/RMI/credential-service/emailctx", - visibility = ["//visibility:public"], -) diff --git a/emailctx/emailctx.go b/emailctx/emailctx.go deleted file mode 100644 index a4321e2..0000000 --- a/emailctx/emailctx.go +++ /dev/null @@ -1,24 +0,0 @@ -package emailctx - -import ( - "context" - "fmt" -) - -type emailsContextKey struct{} - -func AddEmailsToContext(ctx context.Context, emails []string) context.Context { - return context.WithValue(ctx, emailsContextKey{}, emails) -} - -func EmailsFromContext(ctx context.Context) ([]string, error) { - v := ctx.Value(emailsContextKey{}) - if v == nil { - return nil, fmt.Errorf("no email found in context") - } - emails, ok := v.([]string) - if !ok { - return nil, fmt.Errorf("wrong type for email in context: %T", v) - } - return emails, nil -} diff --git a/tokenctx/BUILD.bazel b/tokenctx/BUILD.bazel new file mode 100644 index 0000000..d2d0a9f --- /dev/null +++ b/tokenctx/BUILD.bazel @@ -0,0 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "tokenctx", + srcs = ["tokenctx.go"], + importpath = "github.com/RMI/credential-service/tokenctx", + visibility = ["//visibility:public"], + deps = ["//allowlist"], +) diff --git a/tokenctx/tokenctx.go b/tokenctx/tokenctx.go new file mode 100644 index 0000000..dae93f1 --- /dev/null +++ b/tokenctx/tokenctx.go @@ -0,0 +1,44 @@ +package tokenctx + +import ( + "context" + "fmt" + + "github.com/RMI/credential-service/allowlist" +) + +type emailsContextKey struct{} + +func AddEmailsToContext(ctx context.Context, emails []string) context.Context { + return context.WithValue(ctx, emailsContextKey{}, emails) +} + +func EmailsFromContext(ctx context.Context) ([]string, error) { + v := ctx.Value(emailsContextKey{}) + if v == nil { + return nil, fmt.Errorf("no email found in context") + } + emails, ok := v.([]string) + if !ok { + return nil, fmt.Errorf("wrong type for email in context: %T", v) + } + return emails, nil +} + +type allowlistContextKey struct{} + +func AddAllowlistEntityToContext(ctx context.Context, ae *allowlist.Entity) context.Context { + return context.WithValue(ctx, allowlistContextKey{}, ae) +} + +func AllowlistEntityFromContext(ctx context.Context) (*allowlist.Entity, error) { + v := ctx.Value(allowlistContextKey{}) + if v == nil { + return nil, fmt.Errorf("no allowlist entity found in context") + } + entity, ok := v.(*allowlist.Entity) + if !ok { + return nil, fmt.Errorf("wrong type for allowlist entity in context: %T", v) + } + return entity, nil +}