Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model stable test #782

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/main/java/ml/shifu/shifu/ShifuCLI.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
import ml.shifu.shifu.util.Environment;
import ml.shifu.shifu.util.IndependentTreeModelUtils;

import static ml.shifu.shifu.util.Constants.CHAOS_COLUMNS;
import static ml.shifu.shifu.util.Constants.CHAOS_TYPE;

/**
* ShifuCLI is the MAIN class of the whole project. It reads and analyzes
* the parameters from the command line and executes the corresponding functions.
Expand Down Expand Up @@ -126,6 +129,8 @@ public class ShifuCLI {
private static final String SCORE = "score";
private static final String CONFMAT = "confmat";
private static final String PERF = "perf";
private static final String STABILITY = "stab";

private static final String NORM = "norm";
private static final String NOSORT = "nosort";
private static final String REF = "ref";
Expand Down Expand Up @@ -374,7 +379,17 @@ public static void main(String[] args) {
// run perfermance
runEvalPerf(cmd.getOptionValue(PERF));
log.info("Finish run performance maxtrix with eval set {}.", cmd.getOptionValue(PERF));
} else if (cmd.hasOption(LIST)) {
} else if(cmd.hasOption(STABILITY)) {
// run stability evaluation (collect the optional chaos type / chaos columns first)
if(cmd.hasOption(CHAOS_TYPE)) {
params.put(CHAOS_TYPE, cmd.getOptionValue(CHAOS_TYPE));
}
if(cmd.hasOption(CHAOS_COLUMNS)) {
params.put(CHAOS_COLUMNS, cmd.getOptionValue(CHAOS_COLUMNS));
}
runEvalStability(cmd.getOptionValue(STABILITY), params);
log.info("Finish run performance maxtrix with eval set {}.", cmd.getOptionValue(PERF));
} else if(cmd.hasOption(LIST)) {
// list all evaluation sets
listEvalSet();
} else if (cmd.hasOption(DELETE)) {
Expand Down Expand Up @@ -625,6 +640,11 @@ private static int runEvalPerf(String evalSetNames) throws Exception {
return p.run();
}

/**
 * Builds an {@link EvalModelProcessor} for the {@code STAB} step and runs it.
 *
 * @param evalSetNames
 *            value passed through to the processor as the eval set name(s)
 * @param params
 *            extra parameters handed to the processor unchanged
 * @return the status code returned by {@link EvalModelProcessor#run()}
 * @throws Exception
 *             whatever the processor's {@code run()} throws
 */
private static int runEvalStability(String evalSetNames, Map<String, Object> params) throws Exception {
    return new EvalModelProcessor(EvalStep.STAB, evalSetNames, params).run();
}

private static int runEvalNorm(String evalSetNames, Map<String, Object> params) throws Exception {
EvalModelProcessor p = new EvalModelProcessor(EvalStep.NORM, evalSetNames, params);
return p.run();
Expand Down Expand Up @@ -792,6 +812,9 @@ private static Options buildModelSetOptions() {
Option opt_score = OptionBuilder.hasOptionalArg().create(SCORE);
Option opt_confmat = OptionBuilder.hasArg().create(CONFMAT);
Option opt_perf = OptionBuilder.hasArg().create(PERF);
Option opt_stab = OptionBuilder.hasArg(false).create(STABILITY);
Option opt_chaos_type = OptionBuilder.hasArg().create(CHAOS_TYPE);
Option opt_chaos_cols = OptionBuilder.hasArg().create(CHAOS_COLUMNS);
Option opt_norm = OptionBuilder.hasArg().create(NORM);
Option opt_eval = OptionBuilder.hasArg(false).create(EVAL_CMD);
Option opt_init = OptionBuilder.hasArg(false).create(INIT_CMD);
Expand Down Expand Up @@ -831,6 +854,9 @@ private static Options buildModelSetOptions() {
opts.addOption(opt_type);
opts.addOption(opt_run);
opts.addOption(opt_perf);
opts.addOption(opt_stab);
opts.addOption(opt_chaos_type);
opts.addOption(opt_chaos_cols);
opts.addOption(opt_norm);
opts.addOption(opt_model);
opts.addOption(opt_concise);
Expand Down Expand Up @@ -931,6 +957,8 @@ private static void printUsage() {
System.out.println("\teval -confmat <EvalSetName> Compute the TP/FP/TN/FN based on scoring.");
System.out
.println("\teval -perf <EvalSetName> Calculate the model performance based on confmat.");
System.out.println("\teval -stab <EvalSetName> -type <chaosType> -cols <column names list join with ','> " +
"Score evaluation dataset with injection on specific columns.");
System.out.println("\teval -audit [-n <#numofrecords>] Score eval data and generate audit dataset.");
System.out.println("\texport [-t pmml|columnstats|woemapping|bagging|baggingpmml|corr|woe|ume|baggingume|normume]");
System.out.println("\t [-c] [-vars var1,var1] [-ivr <ratio>] [-bic <bic>] [-name <modelName>] [-postfix <postfix>] [-strategy <max|min|mean>] [-mapping <variable_mapping.conf>]");
Expand Down
145 changes: 134 additions & 11 deletions src/main/java/ml/shifu/shifu/core/processor/EvalModelProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,11 @@
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;

import ml.shifu.shifu.core.stability.ChaosType;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.io.FileUtils;
Expand Down Expand Up @@ -70,6 +63,9 @@
import ml.shifu.shifu.util.HdfsPartFile;
import ml.shifu.shifu.util.ModelSpecLoaderUtils;

import static ml.shifu.shifu.util.Constants.CHAOS_COLUMNS;
import static ml.shifu.shifu.util.Constants.CHAOS_TYPE;

/**
* EvalModelProcessor class
*/
Expand All @@ -84,7 +80,7 @@ public class EvalModelProcessor extends BasicModelProcessor implements Processor
* Step for evaluation
*/
public enum EvalStep {
LIST, NEW, DELETE, RUN, PERF, SCORE, AUDIT, CONFMAT, NORM, GAINCHART;
LIST, NEW, DELETE, RUN, PERF, STAB, SCORE, AUDIT, CONFMAT, NORM, GAINCHART;
}

public static final String NOSORT = "NOSORT";
Expand Down Expand Up @@ -184,6 +180,9 @@ public int run() throws Exception {
case SCORE:
runScore(getEvalConfigListFromInput());
break;
case STAB:
runStability(getEvalConfigListFromInput());
break;
case AUDIT:
runGenAudit(getEvalConfigListFromInput());
break;
Expand Down Expand Up @@ -463,6 +462,13 @@ private ScoreStatus runDistScore(EvalConfig evalConfig, int index) throws IOExce
|| (isNoSort() && (EvalStep.SCORE.equals(this.evalStep) || EvalStep.AUDIT.equals(this.evalStep)))) {
pigScript = "scripts/EvalScore.pig";
}
if(EvalStep.STAB.equals(this.evalStep) && this.params.containsKey(CHAOS_TYPE) && this.params.containsKey(CHAOS_COLUMNS)) {
paramsMap.put(CHAOS_TYPE, this.params.get(CHAOS_TYPE).toString());
paramsMap.put(CHAOS_COLUMNS, this.params.get(CHAOS_COLUMNS).toString());
pigScript = "scripts/EvalChaosScore.pig";
}
LOG.info("run dist score with pigScript {}, parameters: {}", pigScript, paramsMap.toString());

try {
PigExecutor.getExecutor().submitJob(modelConfig, pathFinder.getScriptPath(pigScript), paramsMap,
evalConfig.getDataSet().getSource(), confMap, super.pathFinder);
Expand Down Expand Up @@ -766,6 +772,123 @@ public void run() {
}
}

/**
 * Runs the stability evaluation for the given eval sets.
 *
 * <p>
 * Steps: validate the chaos parameters held in {@code this.params}, sync the eval data to HDFS once, check that
 * no eval set declares more than 5 score meta columns (otherwise log the reason and return without running),
 * then run every eval set through {@code runEval}. Eval sets run in parallel batches when
 * {@code Constants.SHIFU_EVAL_PARALLEL} is enabled, the model runs in map-reduce mode and there is more than one
 * eval set; otherwise they run one after another.
 *
 * @param evalSetList
 *            eval configurations to run
 * @throws IOException
 *             if syncing data or a sequential {@code runEval} fails
 * @throws IllegalArgumentException
 *             if the chaos parameters are invalid or {@code Constants.SHIFU_EVAL_PARALLEL_NUM} is outside (0, 100]
 */
private void runStability(List<EvalConfig> evalSetList) throws IOException {
    // validate the stability config
    validateStabilityConfig(this.params);

    // do it only once
    syncDataToHdfs(evalSetList);

    // validation for score column
    for(EvalConfig evalConfig: evalSetList) {
        List<String> scoreMetaColumns = evalConfig.getScoreMetaColumns(modelConfig);
        if(scoreMetaColumns.size() > 5) {
            LOG.error(
                    "Starting from 0.10.x, 'scoreMetaColumns' is used for benchmark score columns and limited to at most 5.");
            LOG.error(
                    "If meta columns are set in file of 'scoreMetaColumns', please move meta column config to 'eval#dataSet#metaColumnNameFile' part.");
            LOG.error(
                    "If 'eval#dataSet#metaColumnNameFile' is duplicated with training 'metaColumnNameFile', you can rename it to another file with different name.");
            return;
        }
    }

    if(Environment.getBoolean(Constants.SHIFU_EVAL_PARALLEL, true) && modelConfig.isMapReduceRunMode()
            && evalSetList.size() > 1) {
        // run in parallel
        int parallelNum = Environment.getInt(Constants.SHIFU_EVAL_PARALLEL_NUM, 5);
        if(parallelNum <= 0 || parallelNum > 100) {
            throw new IllegalArgumentException(Constants.SHIFU_EVAL_PARALLEL_NUM
                    + " in shifuconfig should be in (0, 100], by default it is 5.");
        }

        int evalSize = evalSetList.size();
        int mod = evalSize % parallelNum;
        int batch = evalSize / parallelNum;
        batch = (mod == 0 ? batch : (batch + 1));

        for(int i = 0; i < batch; i++) {
            // the last batch only holds the remainder when evalSize is not a multiple of parallelNum
            int batchSize = (mod != 0 && i == (batch - 1)) ? mod : parallelNum;
            // launch current batch
            LOG.info("Starting to run eval score in {}/{} round", (i + 1), batch);
            final CountDownLatch cdl = new CountDownLatch(batchSize);
            for(int j = 0; j < batchSize; j++) {
                int currentIndex = i * parallelNum + j;
                final EvalConfig config = evalSetList.get(currentIndex);
                // thread is named after the eval set so that each one is easy to find in the logs
                Thread evalRunThread = new Thread(() -> {
                    try {
                        runEval(config);
                    } catch (Exception e) {
                        LOG.error("Exception in eval:", e);
                    } finally {
                        // always release the latch, even if runEval dies with an Error;
                        // otherwise cdl.await() below would block forever
                        cdl.countDown();
                    }
                }, config.getName());
                evalRunThread.start();

                // each one sleep 3s to avoid conflict in initialization
                try {
                    Thread.sleep(3000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }

            LOG.info("Starting to wait eval in {}/{} round", (i + 1), batch);
            // await all threads done
            try {
                cdl.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            LOG.info("Finish eval in {}/{} round", (i + 1), batch);
        }
        LOG.info("Finish all eval parallel running with eval size {}.", evalSize);
    } else {
        // for old sequential runs
        for(EvalConfig evalConfig: evalSetList) {
            runEval(evalConfig);
        }
    }
}

/**
 * Validates the chaos parameters needed by the stability evaluation and normalises the chaos column list.
 *
 * <p>
 * Both {@code CHAOS_TYPE} and {@code CHAOS_COLUMNS} must be present with non-null values, the chaos type must be
 * known to {@link ChaosType#fromName}, and at least one of the listed columns must match a configured column
 * name. On success the {@code CHAOS_COLUMNS} entry of {@code this.params} is replaced by the comma-joined,
 * lower-cased list of valid column names.
 *
 * @param params
 *            parameters to validate
 * @throws IllegalArgumentException
 *             if a required parameter is missing, the chaos type is unknown or no column name is valid
 */
private void validateStabilityConfig(Map<String, Object> params) {
    Object chaosTypeValue = (params == null) ? null : params.get(CHAOS_TYPE);
    Object chaosColumnsValue = (params == null) ? null : params.get(CHAOS_COLUMNS);
    // a key mapped to null is as unusable as a missing key, so both are reported the same way
    if(chaosTypeValue == null || chaosColumnsValue == null) {
        LOG.error("chaos type (-type) and chaos columns (-cols) are required");
        throw new IllegalArgumentException("chaos type (-type) and chaos columns (-cols) are required");
    }
    if(ChaosType.fromName(chaosTypeValue.toString()) == null) {
        LOG.error("Chaos type {} not supported yet", chaosTypeValue);
        throw new IllegalArgumentException("Chaos type " + chaosTypeValue + " is not supported yet");
    }
    String validColumnNames = getValidColumnNames(chaosColumnsValue.toString());
    if(validColumnNames.isEmpty()) {
        LOG.error("Chaos column {} do not has a valid column name", chaosColumnsValue);
        throw new IllegalArgumentException(
                "Chaos columns " + chaosColumnsValue + " do not contain any valid column name");
    }
    // update the params with the valid columnNames in lower case
    this.params.put(CHAOS_COLUMNS, validColumnNames);
}

/**
 * Keeps only the names that match (ignoring case) a column in {@code this.columnConfigList}.
 *
 * <p>
 * The input is split on {@code ','} and each piece is trimmed. Pieces without a matching column are logged with
 * a warning and dropped; the remaining ones are lower-cased and joined with {@code ','} in their original order.
 *
 * @param columnNamesSeparateByComma
 *            comma separated column names
 * @return comma separated, lower-cased valid column names; empty string if none is valid
 */
private String getValidColumnNames(String columnNamesSeparateByComma) {
    List<String> accepted = new ArrayList<>();
    for(String piece: columnNamesSeparateByComma.split(",")) {
        String candidate = piece.trim();
        boolean known = false;
        for(ColumnConfig columnConfig: this.columnConfigList) {
            if(columnConfig.getColumnName().equalsIgnoreCase(candidate)) {
                known = true;
                break;
            }
        }
        if(known) {
            accepted.add(candidate.toLowerCase());
        } else {
            LOG.warn("Chaos column value {} is not a valid column name, will be ignored", candidate);
        }
    }
    return String.join(",", accepted);
}

@SuppressWarnings("deprecation")
private void validateEvalColumnConfig(EvalConfig evalConfig) throws IOException {
if(this.columnConfigList == null) {
Expand Down
76 changes: 76 additions & 0 deletions src/main/java/ml/shifu/shifu/core/stability/ChaosFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright [2013-2021] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.stability;

import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.core.stability.algorithm.BaseChaosAlgorithm;
import ml.shifu.shifu.util.Constants;
import ml.shifu.shifu.util.Environment;
import org.reflections.Reflections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
 * Singleton that holds the chaos injection settings: the chaos type and the set of columns to inject chaos into.
 * Both are read from {@link Environment} properties ({@code Constants.CHAOS_TYPE} and
 * {@code Constants.CHAOS_COLUMNS}) when the singleton is first requested.
 *
 * @author Wu Devin (haifwu@paypal.com)
 */
public class ChaosFactory {
    private static final Logger LOG = LoggerFactory.getLogger(ChaosFactory.class);

    /**
     * Chaos algorithm classes found under {@code ml.shifu.shifu.core.stability.algorithm}, keyed by lower-cased
     * fully qualified class name. NOTE(review): not read anywhere inside this class yet.
     */
    private static final Map<String, Class<? extends BaseChaosAlgorithm>> algorithmMap = loadAlgorithms();

    private final ChaosType chaosType;
    private final Set<String> chaosColumnSet;

    private ChaosFactory() {
        String typeName = Environment.getProperty(Constants.CHAOS_TYPE);
        String columnNames = Environment.getProperty(Constants.CHAOS_COLUMNS);

        // fail with a readable message instead of a bare NullPointerException inside the singleton initialiser
        this.chaosType = (typeName == null) ? null : ChaosType.fromName(typeName);
        if(this.chaosType == null) {
            throw new IllegalStateException(
                    "Property " + Constants.CHAOS_TYPE + " is missing or not a supported chaos type: " + typeName);
        }
        if(columnNames == null) {
            throw new IllegalStateException("Property " + Constants.CHAOS_COLUMNS + " is missing");
        }
        this.chaosColumnSet = parseColumns(columnNames);

        LOG.info("ChaosFactory init with chaos type {}: {}", this.chaosType.getName(), this.chaosType.getDescription());
        LOG.info("Inject chaos type on columns: {}", columnNames);
    }

    /** Splits on ',', trims and lower-cases each name (lookups use lower-cased names) and drops blank entries. */
    private static Set<String> parseColumns(String columnNames) {
        Set<String> columns = new HashSet<>();
        for(String rawName: columnNames.split(",")) {
            String name = rawName.trim();
            if(!name.isEmpty()) {
                columns.add(name.toLowerCase());
            }
        }
        return Collections.unmodifiableSet(columns);
    }

    /** Scans the algorithm package for {@link BaseChaosAlgorithm} subtypes. */
    private static Map<String, Class<? extends BaseChaosAlgorithm>> loadAlgorithms() {
        Map<String, Class<? extends BaseChaosAlgorithm>> algorithms = new HashMap<>();
        Reflections reflections = new Reflections("ml.shifu.shifu.core.stability.algorithm");
        Set<Class<? extends BaseChaosAlgorithm>> classes = reflections.getSubTypesOf(BaseChaosAlgorithm.class);
        for(Class<? extends BaseChaosAlgorithm> algorithm: classes) {
            algorithms.put(algorithm.getName().toLowerCase(), algorithm);
        }
        return algorithms;
    }

    /** Holder class so that the instance is only created on the first {@link #getInstance()} call. */
    private static class SingletonHolder {
        private static final ChaosFactory instance = new ChaosFactory();
    }

    /**
     * Public method to get instance
     *
     * @return
     *         The singleton instance.
     */
    public static ChaosFactory getInstance() {
        return SingletonHolder.instance;
    }

    /**
     * @return the chaos type resolved from the {@code Constants.CHAOS_TYPE} property
     */
    public ChaosType getChaosType() {
        return this.chaosType;
    }

    /**
     * Tells whether chaos should be injected into the given column.
     *
     * @param config
     *            column to check
     * @return {@code true} if the lower-cased column name is one of the configured chaos columns; {@code false}
     *         otherwise, including when {@code config} or its name is {@code null}
     */
    public boolean needInjectChaos(ColumnConfig config) {
        if(config == null || config.getColumnName() == null) {
            return false;
        }
        boolean inject = this.chaosColumnSet.contains(config.getColumnName().toLowerCase());
        if(inject) {
            // debug level and only on a match: presumably called per column for every record - TODO confirm
            LOG.debug("Inject chaos for {} use type {}", config.getColumnName(), this.chaosType.getName());
        }
        return inject;
    }
}
Loading