diff --git a/DESCRIPTION b/DESCRIPTION index d4014153..3ac57631 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: BayesMallows Type: Package Title: Bayesian Preference Learning with the Mallows Rank Model -Version: 2.1.1.9001 +Version: 2.1.1.9002 Authors@R: c(person("Oystein", "Sorensen", email = "oystein.sorensen.1985@gmail.com", role = c("aut", "cre"), diff --git a/NEWS.md b/NEWS.md index 84e82a7a..eb291e0f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # BayesMallows (development versions) +* Fixed a bug which caused inconsistent partial rank data to be retained from + previous timepoints when existing users update their preferences. * Arguments random and random_limit to setup_rank_data() have been removed. A new argument max_topological_sorts has been added instead, which captures all previous use cases, but also allows the user to specify the number of diff --git a/src/missing_data.h b/src/missing_data.h index 9098c4a5..1a37218b 100644 --- a/src/missing_data.h +++ b/src/missing_data.h @@ -23,3 +23,6 @@ std::pair make_new_augmentation( arma::mat initialize_missing_ranks( arma::mat rankings, const arma::umat& missing_indicator); + +arma::vec initialize_missing_ranks_vec( + arma::vec rankings, const arma::uvec& missing_indicator); diff --git a/src/missing_data_functions.cpp b/src/missing_data_functions.cpp index 87f6fa2b..9a207be6 100644 --- a/src/missing_data_functions.cpp +++ b/src/missing_data_functions.cpp @@ -6,20 +6,24 @@ using namespace arma; +vec initialize_missing_ranks_vec(vec rankings, const uvec& missing_indicator) { + vec present_ranks = rankings(find(missing_indicator == 0)); + uvec missing_inds = find(missing_indicator == 1); + vec a = setdiff( + regspace(1, rankings.n_elem), present_ranks); + + ivec inds = Rcpp::sample(a.size(), a.size()) - 1; + vec new_ranks = a.elem(conv_to::from(inds)); + rankings(missing_inds) = new_ranks; + return rankings; +} + mat initialize_missing_ranks(mat rankings, const umat& missing_indicator) { int n_assessors = rankings.n_cols; for(int i = 0; i < n_assessors; ++i){ - vec rank_vector = rankings.col(i); - vec present_ranks = rank_vector(find(missing_indicator.col(i) == 0)); - uvec missing_inds = find(missing_indicator.col(i) == 1); - vec a = setdiff( - regspace(1, rank_vector.n_elem), present_ranks); - - ivec inds = Rcpp::sample(a.size(), a.size()) - 1; - vec new_ranks = a.elem(conv_to::from(inds)); - rank_vector(missing_inds) = new_ranks; - rankings.col(i) = rank_vector; + rankings.col(i) = initialize_missing_ranks_vec( + rankings.col(i), missing_indicator.col(i)); } return rankings; } diff --git a/src/particles.cpp b/src/particles.cpp index a2db9e55..d3462a2b 100644 --- a/src/particles.cpp +++ b/src/particles.cpp @@ -81,9 +81,12 @@ std::vector augment_particles( vec to_compare = dat.rankings.col(index); uvec comparison_inds = find(to_compare > 0); vec augmented = pvec[i].augmented_data(span::all, span(index)); - - pvec[i].consistent(index) = - all(to_compare(comparison_inds) == augmented(comparison_inds)); + bool check = all(to_compare(comparison_inds) == augmented(comparison_inds)); + pvec[i].consistent(index) = check; + if(!check) { + pvec[i].augmented_data.col(index) = + initialize_missing_ranks_vec(to_compare, dat.missing_indicator.col(index)); + } } if(dat.num_new_obs > 0) { diff --git a/src/smc_data_class.cpp b/src/smc_data_class.cpp index f09eb557..c4ae325d 100644 --- a/src/smc_data_class.cpp +++ b/src/smc_data_class.cpp @@ -46,6 +46,7 @@ void SMCData::update( for(auto index : updated_indices) { rankings.col(updated_match[index]) = new_dat.rankings.col(updated_new[index]); + missing_indicator.col(updated_match[index]) = new_dat.missing_indicator.col(updated_new[index]); if(augpair) { items_above[updated_match[index]] = new_dat.items_above[updated_new[index]]; items_below[updated_match[index]] = new_dat.items_below[updated_new[index]]; diff --git a/tests/testthat/test-smc_update_correctness.R b/tests/testthat/test-smc_update_correctness.R index 80e875aa..176784ac 100644 --- a/tests/testthat/test-smc_update_correctness.R +++ b/tests/testthat/test-smc_update_correctness.R @@ -169,31 +169,39 @@ test_that("update_mallows is correct for updated partial rankings", { set.seed(1) user_ids <- 1:12 dat0 <- potato_visual - dat0[] <- ifelse(runif(length(dat0)) > .8, NA_real_, dat0) + dat0[] <- ifelse(runif(length(dat0)) > .5, NA_real_, dat0) mod0 <- compute_mallows( data = setup_rank_data(rankings = dat0), - compute_options = set_compute_options(burnin = 1000) + compute_options = set_compute_options(nmc = 5000, burnin = 2500) ) dat1 <- potato_visual - dat1 <- ifelse(is.na(dat0) & runif(length(dat1)) > .8, NA_real_, dat1) + dat1 <- ifelse(is.na(dat0) & runif(length(dat1)) > .5, NA_real_, dat1) mod1 <- update_mallows( model = mod0, new_data = setup_rank_data(rankings = dat1, user_ids = user_ids), - compute_options = set_compute_options(aug_method = "pseudo") + compute_options = set_compute_options(aug_method = "pseudo"), + smc_options = set_smc_options(n_particles = 3000, mcmc_steps = 10) ) + for (i in 1:12) { + expect_equal( + as.numeric(mod1$augmented_rankings[, i, 1][!is.na(dat1[i, ])]), + as.numeric(dat1[i, !is.na(dat1[i, ])]) + ) + } + mod_bmm1 <- compute_mallows( data = setup_rank_data(rankings = dat1), - compute_options = set_compute_options(burnin = 500) + compute_options = set_compute_options(nmc = 5000, burnin = 500) ) expect_equal( mean(mod1$alpha$value), - mean(mod_bmm1$alpha$value[mod_bmm1$alpha$iteration > 500]), - tolerance = .05 + mean(mod_bmm1$alpha$value[mod_bmm1$alpha$iteration > 2500]), + tolerance = .2 ) dat2 <- potato_visual @@ -201,19 +209,23 @@ test_that("update_mallows is correct for updated partial rankings", { mod2 <- update_mallows( model = mod1, - new_data = setup_rank_data(rankings = dat2, user_ids = user_ids), - compute_options = set_compute_options(aug_method = "pseudo"), - smc_options = set_smc_options(n_particles = 1000, mcmc_steps = 20) + new_data = setup_rank_data(rankings = dat2, user_ids = user_ids) ) + for (i in 1:12) { + expect_equal( + as.numeric(mod2$augmented_rankings[, i, 1][!is.na(dat2[i, ])]), + as.numeric(dat2[i, !is.na(dat2[i, ])]) + ) + } + mod_bmm <- compute_mallows( - data = setup_rank_data(rankings = dat2), - compute_options = set_compute_options(nmc = 50000) + data = setup_rank_data(rankings = dat2) ) expect_equal( mean(mod2$alpha$value), - mean(mod_bmm$alpha$value[mod_bmm$alpha$iteration > 5000]), + mean(mod_bmm$alpha$value[mod_bmm$alpha$iteration > 1000]), tolerance = .1 ) }) diff --git a/work-docs/bughunting.R b/work-docs/bughunting.R new file mode 100644 index 00000000..11d4b437 --- /dev/null +++ b/work-docs/bughunting.R @@ -0,0 +1,44 @@ +set.seed(1) +user_ids <- 1:12 +dat0 <- potato_visual +dat0[] <- ifelse(runif(length(dat0)) > .2, NA_real_, dat0) + +mod0 <- compute_mallows( + data = setup_rank_data(rankings = dat0), + compute_options = set_compute_options(burnin = 1000) +) + +dat1 <- potato_visual +dat1 <- ifelse(is.na(dat0) & runif(length(dat1)) > .5, NA_real_, dat1) + +mod1 <- update_mallows( + model = mod0, + new_data = setup_rank_data(rankings = dat1, user_ids = user_ids), + compute_options = set_compute_options(aug_method = "pseudo") +) +for(i in 1:12) { + expect_equal( + as.numeric(mod1$augmented_rankings[, i, 1][!is.na(dat1[i, ])]), + as.numeric(dat1[i, !is.na(dat1[i, ])]) + ) +} + +mod_bmm1 <- compute_mallows( + data = setup_rank_data(rankings = dat1), + compute_options = set_compute_options(burnin = 500) +) + +dat2 <- potato_visual +dat2 <- ifelse(is.na(dat1) & runif(length(dat2)) > .2, NA_real_, dat2) + +mod2 <- update_mallows( + model = mod1, + new_data = setup_rank_data(rankings = dat2, user_ids = user_ids) +) + +for(i in 1:12) { + expect_equal( + as.numeric(mod2$augmented_rankings[, i, 1][!is.na(dat2[i, ])]), + as.numeric(dat2[i, !is.na(dat2[i, ])]) + ) +}