Skip to content

Commit

Permalink
fixed bug in augmented rankings for existing users
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Mar 20, 2024
1 parent 5538d69 commit f09d9bf
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 27 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: BayesMallows
Type: Package
Title: Bayesian Preference Learning with the Mallows Rank Model
Version: 2.1.1.9001
Version: 2.1.1.9002
Authors@R: c(person("Oystein", "Sorensen",
email = "oystein.sorensen.1985@gmail.com",
role = c("aut", "cre"),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# BayesMallows (development versions)

* Fixed a bug which caused inconsistent partial rank data to be retained from
previous timepoints when existing users update their preferences.
* Arguments random and random_limit to setup_rank_data() have been removed. A
new argument max_topological_sorts has been added instead, which captures all
previous use cases, but also allows the user to specify the number of
Expand Down
3 changes: 3 additions & 0 deletions src/missing_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ std::pair<arma::vec, bool> make_new_augmentation(
arma::mat initialize_missing_ranks(
arma::mat rankings,
const arma::umat& missing_indicator);

arma::vec initialize_missing_ranks_vec(
arma::vec rankings, const arma::uvec& missing_indicator);
24 changes: 14 additions & 10 deletions src/missing_data_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@

using namespace arma;

vec initialize_missing_ranks_vec(vec rankings, const uvec& missing_indicator) {
vec present_ranks = rankings(find(missing_indicator == 0));
uvec missing_inds = find(missing_indicator == 1);
vec a = setdiff(
regspace<vec>(1, rankings.n_elem), present_ranks);

ivec inds = Rcpp::sample(a.size(), a.size()) - 1;
vec new_ranks = a.elem(conv_to<uvec>::from(inds));
rankings(missing_inds) = new_ranks;
return rankings;
}

mat initialize_missing_ranks(mat rankings, const umat& missing_indicator) {
int n_assessors = rankings.n_cols;

for(int i = 0; i < n_assessors; ++i){
vec rank_vector = rankings.col(i);
vec present_ranks = rank_vector(find(missing_indicator.col(i) == 0));
uvec missing_inds = find(missing_indicator.col(i) == 1);
vec a = setdiff(
regspace<vec>(1, rank_vector.n_elem), present_ranks);

ivec inds = Rcpp::sample(a.size(), a.size()) - 1;
vec new_ranks = a.elem(conv_to<uvec>::from(inds));
rank_vector(missing_inds) = new_ranks;
rankings.col(i) = rank_vector;
rankings.col(i) = initialize_missing_ranks_vec(
rankings.col(i), missing_indicator.col(i));
}
return rankings;
}
Expand Down
9 changes: 6 additions & 3 deletions src/particles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ std::vector<Particle> augment_particles(
vec to_compare = dat.rankings.col(index);
uvec comparison_inds = find(to_compare > 0);
vec augmented = pvec[i].augmented_data(span::all, span(index));

pvec[i].consistent(index) =
all(to_compare(comparison_inds) == augmented(comparison_inds));
bool check = all(to_compare(comparison_inds) == augmented(comparison_inds));
pvec[i].consistent(index) = check;
if(!check) {
pvec[i].augmented_data.col(index) =
initialize_missing_ranks_vec(to_compare, dat.missing_indicator.col(index));
}
}

if(dat.num_new_obs > 0) {
Expand Down
1 change: 1 addition & 0 deletions src/smc_data_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ void SMCData::update(

for(auto index : updated_indices) {
rankings.col(updated_match[index]) = new_dat.rankings.col(updated_new[index]);
missing_indicator.col(updated_match[index]) = new_dat.missing_indicator.col(updated_new[index]);
if(augpair) {
items_above[updated_match[index]] = new_dat.items_above[updated_new[index]];
items_below[updated_match[index]] = new_dat.items_below[updated_new[index]];
Expand Down
38 changes: 25 additions & 13 deletions tests/testthat/test-smc_update_correctness.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,51 +169,63 @@ test_that("update_mallows is correct for updated partial rankings", {
set.seed(1)
user_ids <- 1:12
dat0 <- potato_visual
dat0[] <- ifelse(runif(length(dat0)) > .8, NA_real_, dat0)
dat0[] <- ifelse(runif(length(dat0)) > .5, NA_real_, dat0)

mod0 <- compute_mallows(
data = setup_rank_data(rankings = dat0),
compute_options = set_compute_options(burnin = 1000)
compute_options = set_compute_options(nmc = 5000, burnin = 2500)
)

dat1 <- potato_visual
dat1 <- ifelse(is.na(dat0) & runif(length(dat1)) > .8, NA_real_, dat1)
dat1 <- ifelse(is.na(dat0) & runif(length(dat1)) > .5, NA_real_, dat1)

mod1 <- update_mallows(
model = mod0,
new_data = setup_rank_data(rankings = dat1, user_ids = user_ids),
compute_options = set_compute_options(aug_method = "pseudo")
compute_options = set_compute_options(aug_method = "pseudo"),
smc_options = set_smc_options(n_particles = 3000, mcmc_steps = 10)
)

for (i in 1:12) {
expect_equal(
as.numeric(mod1$augmented_rankings[, i, 1][!is.na(dat1[i, ])]),
as.numeric(dat1[i, !is.na(dat1[i, ])])
)
}

mod_bmm1 <- compute_mallows(
data = setup_rank_data(rankings = dat1),
compute_options = set_compute_options(burnin = 500)
compute_options = set_compute_options(nmc = 5000, burnin = 500)
)

expect_equal(
mean(mod1$alpha$value),
mean(mod_bmm1$alpha$value[mod_bmm1$alpha$iteration > 500]),
tolerance = .05
mean(mod_bmm1$alpha$value[mod_bmm1$alpha$iteration > 2500]),
tolerance = .2
)

dat2 <- potato_visual
dat2 <- ifelse(is.na(dat1) & runif(length(dat2)) > .5, NA_real_, dat2)

mod2 <- update_mallows(
model = mod1,
new_data = setup_rank_data(rankings = dat2, user_ids = user_ids),
compute_options = set_compute_options(aug_method = "pseudo"),
smc_options = set_smc_options(n_particles = 1000, mcmc_steps = 20)
new_data = setup_rank_data(rankings = dat2, user_ids = user_ids)
)

for (i in 1:12) {
expect_equal(
as.numeric(mod2$augmented_rankings[, i, 1][!is.na(dat2[i, ])]),
as.numeric(dat2[i, !is.na(dat2[i, ])])
)
}

mod_bmm <- compute_mallows(
data = setup_rank_data(rankings = dat2),
compute_options = set_compute_options(nmc = 50000)
data = setup_rank_data(rankings = dat2)
)

expect_equal(
mean(mod2$alpha$value),
mean(mod_bmm$alpha$value[mod_bmm$alpha$iteration > 5000]),
mean(mod_bmm$alpha$value[mod_bmm$alpha$iteration > 1000]),
tolerance = .1
)
})
Expand Down
44 changes: 44 additions & 0 deletions work-docs/bughunting.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
set.seed(1)
user_ids <- 1:12
dat0 <- potato_visual
dat0[] <- ifelse(runif(length(dat0)) > .2, NA_real_, dat0)

mod0 <- compute_mallows(
data = setup_rank_data(rankings = dat0),
compute_options = set_compute_options(burnin = 1000)
)

dat1 <- potato_visual
dat1 <- ifelse(is.na(dat0) & runif(length(dat1)) > .5, NA_real_, dat1)

mod1 <- update_mallows(
model = mod0,
new_data = setup_rank_data(rankings = dat1, user_ids = user_ids),
compute_options = set_compute_options(aug_method = "pseudo")
)
for(i in 1:12) {
expect_equal(
as.numeric(mod1$augmented_rankings[, i, 1][!is.na(dat1[i, ])]),
as.numeric(dat1[i, !is.na(dat1[i, ])])
)
}

mod_bmm1 <- compute_mallows(
data = setup_rank_data(rankings = dat1),
compute_options = set_compute_options(burnin = 500)
)

dat2 <- potato_visual
dat2 <- ifelse(is.na(dat1) & runif(length(dat2)) > .2, NA_real_, dat2)

mod2 <- update_mallows(
model = mod1,
new_data = setup_rank_data(rankings = dat2, user_ids = user_ids)
)

for(i in 1:12) {
expect_equal(
as.numeric(mod2$augmented_rankings[, i, 1][!is.na(dat2[i, ])]),
as.numeric(dat2[i, !is.na(dat2[i, ])])
)
}

0 comments on commit f09d9bf

Please sign in to comment.