package cn.edu.bjut.chapter6; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Random; public class Grouper { private List data; private int nfold; private int[] starts = null; public Grouper(List data, int nfold, long seed) { this.data = data; this.nfold = nfold; Collections.shuffle(data, new Random(seed)); this.starts = new int[nfold + 1]; for (int v = 0; v <= nfold; v++) { starts[v] = Math.round(data.size() * (v / (float) nfold)); } } public List getGroup(final int split) { if (split >= nfold || split < 0) { return null; } List group = new ArrayList(); for (int m = starts[split]; m < starts[split + 1]; m++) { group.add(data.get(m)); } return group; } }