Skip to content

Commit

Permalink
fix: update panic if model is not ptr (#6037)
Browse files Browse the repository at this point in the history
* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: update panic if model is not ptr

* fix: raise an error if the value is not addressable

* fix: return
  • Loading branch information
black-06 authored Feb 18, 2023
1 parent 42fc75c commit e66a059
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 4 deletions.
13 changes: 11 additions & 2 deletions callbacks/callmethod.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++
}
case reflect.Struct:
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
if db.Statement.ReflectValue.CanAddr() {
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
}
}
}
4 changes: 3 additions & 1 deletion callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
}
}
case reflect.Struct:
Expand Down
2 changes: 1 addition & 1 deletion schema/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i)
elemKey := elem.Interface()
if elem.Kind() != reflect.Ptr {
if elem.Kind() != reflect.Ptr && elem.CanAddr() {
elemKey = elem.Addr().Interface()
}

Expand Down
52 changes: 52 additions & 0 deletions tests/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,55 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes)
}
}

type Product5 struct {
gorm.Model
Name string
}

var beforeUpdateCall int

func (p *Product5) BeforeUpdate(*gorm.DB) error {
beforeUpdateCall = beforeUpdateCall + 1
return nil
}

func TestUpdateCallbacks(t *testing.T) {
DB.Migrator().DropTable(&Product5{})
DB.AutoMigrate(&Product5{})

p := Product5{Name: "unique_code"}
DB.Model(&Product5{}).Create(&p)

err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error
if err != nil {
t.Fatalf("should update success, but got err %v", err)
}
if beforeUpdateCall != 1 {
t.Fatalf("before update should be called")
}

err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error
if !errors.Is(err, gorm.ErrInvalidValue) {
t.Fatalf("should got RecordNotFound, but got %v", err)
}
if beforeUpdateCall != 1 {
t.Fatalf("before update should not be called")
}

err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error
if err != nil {
t.Fatalf("should update success, but got err %v", err)
}
if beforeUpdateCall != 2 {
t.Fatalf("before update should be called")
}

err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error
if !errors.Is(err, gorm.ErrInvalidValue) {
t.Fatalf("should got RecordNotFound, but got %v", err)
}
if beforeUpdateCall != 2 {
t.Fatalf("before update should not be called")
}
}

0 comments on commit e66a059

Please sign in to comment.