Skip to content

Commit

Permalink
Merge pull request #349 from PzaThief/master
Browse files Browse the repository at this point in the history
Fix session context corrupted problem
  • Loading branch information
ybkuroki authored Nov 12, 2023
2 parents c9f2a30 + 9a1c9b5 commit b07a881
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 81 deletions.
12 changes: 6 additions & 6 deletions controller/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (controller *accountController) GetLoginAccount(c echo.Context) error {
if !controller.context.GetConfig().Extension.SecurityEnabled {
return c.JSON(http.StatusOK, controller.dummyAccount)
}
return c.JSON(http.StatusOK, controller.context.GetSession().GetAccount())
return c.JSON(http.StatusOK, controller.context.GetSession().GetAccount(c))
}

// Login is the method to login using username and password by http post.
Expand All @@ -79,14 +79,14 @@ func (controller *accountController) Login(c echo.Context) error {
}

sess := controller.context.GetSession()
if account := sess.GetAccount(); account != nil {
if account := sess.GetAccount(c); account != nil {
return c.JSON(http.StatusOK, account)
}

authenticate, a := controller.service.AuthenticateByUsernameAndPassword(dto.UserName, dto.Password)
if authenticate {
_ = sess.SetAccount(a)
_ = sess.Save()
_ = sess.SetAccount(c, a)
_ = sess.Save(c)
return c.JSON(http.StatusOK, a)
}
return c.NoContent(http.StatusUnauthorized)
Expand All @@ -102,7 +102,7 @@ func (controller *accountController) Login(c echo.Context) error {
// @Router /auth/logout [post]
func (controller *accountController) Logout(c echo.Context) error {
sess := controller.context.GetSession()
_ = sess.SetAccount(nil)
_ = sess.Delete()
_ = sess.SetAccount(c, nil)
_ = sess.Delete(c)
return c.NoContent(http.StatusOK)
}
55 changes: 55 additions & 0 deletions controller/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package controller

import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/ybkuroki/go-webapp-sample/config"
"github.com/ybkuroki/go-webapp-sample/container"
"github.com/ybkuroki/go-webapp-sample/test"
)

type sessionController struct {
container container.Container
}

func TestSessionRace_Success(t *testing.T) {
sessionKey := "Key"
router, container := test.PrepareForControllerTest(true)
session := sessionController{container: container}

router.GET(config.API+"1", func(c echo.Context) error {
_ = session.container.GetSession().SetValue(c, sessionKey, 1)
_ = session.container.GetSession().Save(c)
time.Sleep(3 * time.Second)
return c.String(http.StatusOK, session.container.GetSession().GetValue(c, sessionKey))
})
router.GET(config.API+"2", func(c echo.Context) error {
_ = session.container.GetSession().SetValue(c, sessionKey, 2)
_ = session.container.GetSession().Save(c)
return c.String(http.StatusOK, session.container.GetSession().GetValue(c, sessionKey))
})

req1 := httptest.NewRequest("GET", config.API+"1", nil)
req2 := httptest.NewRequest("GET", config.API+"2", nil)
rec1 := httptest.NewRecorder()
rec2 := httptest.NewRecorder()

go func() {
router.ServeHTTP(rec1, req1)
}()

go func() {
time.Sleep(1 * time.Second)
router.ServeHTTP(rec2, req2)
}()

time.Sleep(5 * time.Second)

assert.Equal(t, "1", rec1.Body.String())
assert.Equal(t, "2", rec2.Body.String())
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func main() {
logger.GetZapLogger().Infof("Loaded messages.properties")

rep := repository.NewBookRepository(logger, conf)
sess := session.NewSession()
sess := session.NewSession(logger, conf)
container := container.NewContainer(rep, sess, conf, messages, logger, env)

migration.CreateDatabase(container)
Expand Down
40 changes: 5 additions & 35 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@ package middleware

import (
"embed"
"fmt"
"io"
"net/http"
"regexp"
"strconv"

"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
echomd "github.com/labstack/echo/v4/middleware"
"github.com/valyala/fasttemplate"
"github.com/ybkuroki/go-webapp-sample/container"
"gopkg.in/boj/redistore.v1"
)

// InitLoggerMiddleware initialize a middleware for logger.
Expand All @@ -26,23 +23,9 @@ func InitLoggerMiddleware(e *echo.Echo, container container.Container) {
// InitSessionMiddleware initialize a middleware for session management.
func InitSessionMiddleware(e *echo.Echo, container container.Container) {
conf := container.GetConfig()
logger := container.GetLogger()

e.Use(SessionMiddleware(container))

e.Use(session.Middleware(container.GetSession().GetStore()))
if conf.Extension.SecurityEnabled {
if conf.Redis.Enabled {
logger.GetZapLogger().Infof("Try redis connection")
address := fmt.Sprintf("%s:%s", conf.Redis.Host, conf.Redis.Port)
store, err := redistore.NewRediStore(conf.Redis.ConnectionPoolSize, "tcp", address, "", []byte("secret"))
if err != nil {
logger.GetZapLogger().Errorf("Failure redis connection")
}
e.Use(session.Middleware(store))
logger.GetZapLogger().Infof(fmt.Sprintf("Success redis connection, %s", address))
} else {
e.Use(session.Middleware(sessions.NewCookieStore([]byte("secret"))))
}
e.Use(AuthenticationMiddleware(container))
}
}
Expand All @@ -63,7 +46,7 @@ func RequestLoggerMiddleware(container container.Container) echo.MiddlewareFunc
case "remote_ip":
return w.Write([]byte(c.RealIP()))
case "account_name":
if account := container.GetSession().GetAccount(); account != nil {
if account := container.GetSession().GetAccount(c); account != nil {
return w.Write([]byte(account.Name))
}
return w.Write([]byte("None"))
Expand Down Expand Up @@ -99,19 +82,6 @@ func ActionLoggerMiddleware(container container.Container) echo.MiddlewareFunc {
}
}

// SessionMiddleware is a middleware for setting a context to a session.
func SessionMiddleware(container container.Container) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
container.GetSession().SetContext(c)
if err := next(c); err != nil {
c.Error(err)
}
return nil
}
}
}

// StaticContentsMiddleware is the middleware for loading the static files.
func StaticContentsMiddleware(e *echo.Echo, container container.Container, staticFile embed.FS) {
conf := container.GetConfig()
Expand Down Expand Up @@ -155,16 +125,16 @@ func hasAuthorization(c echo.Context, container container.Container) bool {
if equalPath(currentPath, container.GetConfig().Security.ExculdePath) {
return true
}
account := container.GetSession().GetAccount()
account := container.GetSession().GetAccount(c)
if account == nil {
return false
}
if account.Authority.Name == "Admin" && equalPath(currentPath, container.GetConfig().Security.AdminPath) {
_ = container.GetSession().Save()
_ = container.GetSession().Save(c)
return true
}
if account.Authority.Name == "User" && equalPath(currentPath, container.GetConfig().Security.UserPath) {
_ = container.GetSession().Save()
_ = container.GetSession().Save(c)
return true
}
return false
Expand Down
89 changes: 51 additions & 38 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package session

import (
"encoding/json"
"net/http"
"fmt"

"github.com/gorilla/sessions"
echoSession "github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/ybkuroki/go-webapp-sample/config"
"github.com/ybkuroki/go-webapp-sample/logger"
"github.com/ybkuroki/go-webapp-sample/model"
"gopkg.in/boj/redistore.v1"
)

const (
Expand All @@ -18,79 +20,92 @@ const (
)

type session struct {
context echo.Context
store sessions.Store
}

// Session represents a interface for accessing the session on the application.
type Session interface {
SetContext(c echo.Context)
Get() *sessions.Session
Save() error
Delete() error
SetValue(key string, value interface{}) error
GetValue(key string) string
SetAccount(account *model.Account) error
GetAccount() *model.Account
GetStore() sessions.Store

Get(c echo.Context) *sessions.Session
Save(c echo.Context) error
Delete(c echo.Context) error
SetValue(c echo.Context, key string, value interface{}) error
GetValue(c echo.Context, key string) string
SetAccount(c echo.Context, account *model.Account) error
GetAccount(c echo.Context) *model.Account
}

// NewSession is constructor.
func NewSession() Session {
return &session{context: nil}
func NewSession(logger logger.Logger, conf *config.Config) Session {
if !conf.Redis.Enabled {
logger.GetZapLogger().Infof("use CookieStore for session")
return &session{sessions.NewCookieStore([]byte("secret"))}
}

logger.GetZapLogger().Infof("use redis for session")
logger.GetZapLogger().Infof("Try redis connection")
address := fmt.Sprintf("%s:%s", conf.Redis.Host, conf.Redis.Port)
store, err := redistore.NewRediStore(conf.Redis.ConnectionPoolSize, "tcp", address, "", []byte("secret"))
if err != nil {
logger.GetZapLogger().Panicf("Failure redis connection, %s", err.Error())
}
logger.GetZapLogger().Infof(fmt.Sprintf("Success redis connection, %s", address))
return &session{store: store}
}

// SetContext sets the context of echo framework to the session.
func (s *session) SetContext(c echo.Context) {
s.context = c
func (s *session) GetStore() sessions.Store {
return s.store
}

// Get returns a session for the current request.
func (s *session) Get() *sessions.Session {
sess, _ := echoSession.Get(sessionStr, s.context)
func (s *session) Get(c echo.Context) *sessions.Session {
sess, _ := s.store.Get(c.Request(), sessionStr)
return sess
}

// Save saves the current session.
func (s *session) Save() error {
sess := s.Get()
func (s *session) Save(c echo.Context) error {
sess := s.Get(c)
sess.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
}
return s.saveSession(sess)
return s.saveSession(c, sess)
}

// Delete the current session.
func (s *session) Delete() error {
sess := s.Get()
func (s *session) Delete(c echo.Context) error {
sess := s.Get(c)
sess.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
MaxAge: -1,
}
return s.saveSession(sess)
return s.saveSession(c, sess)
}

func (s *session) saveSession(sess *sessions.Session) error {
if err := sess.Save(s.context.Request(), s.context.Response()); err != nil {
return s.context.NoContent(http.StatusInternalServerError)
func (s *session) saveSession(c echo.Context, sess *sessions.Session) error {
if err := sess.Save(c.Request(), c.Response()); err != nil {
return fmt.Errorf("error occurred while save session")
}
return nil
}

// SetValue sets a key and a value.
func (s *session) SetValue(key string, value interface{}) error {
sess := s.Get()
func (s *session) SetValue(c echo.Context, key string, value interface{}) error {
sess := s.Get(c)
bytes, err := json.Marshal(value)
if err != nil {
return s.context.NoContent(http.StatusInternalServerError)
return fmt.Errorf("json marshal error while set value in session")
}
sess.Values[key] = string(bytes)
return nil
}

// GetValue returns value of session.
func (s *session) GetValue(key string) string {
sess := s.Get()
func (s *session) GetValue(c echo.Context, key string) string {
sess := s.Get(c)
if sess != nil {
if v, ok := sess.Values[key]; ok {
data, result := v.(string)
Expand All @@ -102,14 +117,12 @@ func (s *session) GetValue(key string) string {
return ""
}

// SetAccount sets account data in session.
func (s *session) SetAccount(account *model.Account) error {
return s.SetValue(Account, account)
func (s *session) SetAccount(c echo.Context, account *model.Account) error {
return s.SetValue(c, Account, account)
}

// GetAccount returns account object of session.
func (s *session) GetAccount() *model.Account {
if v := s.GetValue(Account); v != "" {
func (s *session) GetAccount(c echo.Context) *model.Account {
if v := s.GetValue(c, Account); v != "" {
a := &model.Account{}
_ = json.Unmarshal([]byte(v), a)
return a
Expand Down
2 changes: 1 addition & 1 deletion test/unittest_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func createConfig(isSecurity bool) *config.Config {

func initContainer(conf *config.Config, logger logger.Logger) container.Container {
rep := repository.NewBookRepository(logger, conf)
sess := session.NewSession()
sess := session.NewSession(logger, conf)
messages := map[string]string{
"ValidationErrMessageBookTitle": "Please enter the title with 3 to 50 characters.",
"ValidationErrMessageBookISBN": "Please enter the ISBN with 10 to 20 characters."}
Expand Down

0 comments on commit b07a881

Please sign in to comment.