/*
 * Decompiled with CFR 0.152.
 */
package fr.lri.tao.apro.ap;

import cern.colt.list.DoubleArrayList;
import cern.colt.list.IntArrayList;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.impl.SparseDoubleMatrix2D;
import fr.lri.tao.apro.ap.AbstractApro;
import fr.lri.tao.apro.ap.Group;
import fr.lri.tao.apro.ap.GroupWorker;
import fr.lri.tao.apro.data.DataProvider;
import fr.lri.tao.apro.util.Logger;
import fr.lri.tao.numa.NUMA;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Multi-threaded driver for affinity-propagation style message passing.
 *
 * <p>The {@code n} rows of the input matrix S (obtained from a {@link DataProvider}) are split into
 * {@code groupCount} contiguous row ranges. Each range is owned by a {@link Group} and processed by
 * its own {@link GroupWorker}. Every iteration of {@link #run(int)} first runs a "responsibilities"
 * phase on all workers, waits for all of them, then runs an "availabilities" phase and waits again.
 *
 * <p>When NUMA support is requested and the NUMA library is available, allocation of each group's
 * data is directed to the node returned by {@link #getNumaNode(int)}.
 */
public class Apro
extends AbstractApro {
    /** One entry per row group, indexed by group id; rebuilt by {@link #init()}. */
    List<Group> groups = new ArrayList<Group>();
    /** Damping factor; stored here and exposed via {@link #getDamping()}. */
    private double damp = 0.5;
    /** Number of rows/points, taken from {@link DataProvider#size()}. */
    private int n;
    private DataProvider provider;
    /** Input matrix S, fetched from the provider in {@link #init()}. */
    private DoubleMatrix2D s;
    /** Dense n x n responsibility matrix (see {@link #getR()}). */
    double[][] r;
    /** Per-column accumulator; reset each iteration and filled via {@link #updateSums(Group)}. */
    double[] sumr;
    /** Effective number of row groups (worker threads), never larger than {@code n}. */
    private final int groupCount;
    /** Wall-clock duration of the last {@link #run(int)} in ms, or {@code null} before any run. */
    private Long runTime = null;
    private final boolean useNuma;
    private int numaStartNode = 0;
    private Integer numaCoresPerNode = null;
    private Integer numaNumNodes = null;
    /** When true, {@link #run(int)} prints progress dots to stdout. */
    private boolean debug = true;

    public Apro(DataProvider provider, int groupCount) {
        this(provider, groupCount, false, null, null, null);
    }

    public Apro(DataProvider provider, int groupCount, boolean useNuma) {
        this(provider, groupCount, useNuma, null, null, null);
    }

    public Apro(DataProvider provider, int groupCount, Integer numaNumNodes, Integer numaCoresPerNode, Integer numaStartNode) {
        this(provider, groupCount, true, numaNumNodes, numaCoresPerNode, numaStartNode);
    }

    /**
     * Creates a driver.
     *
     * @param provider source of the matrix and of the problem size {@code n}
     * @param groupCount requested number of row groups / worker threads; values larger than
     *     {@code n} are clamped to {@code n}
     * @param useNuma whether to place group data on NUMA nodes (silently disabled, with a warning,
     *     when the NUMA library is unavailable)
     * @param numaNumNodes number of NUMA nodes to cycle through, or {@code null} to query NUMA
     * @param numaCoresPerNode groups placed per node before moving on, or {@code null} to query NUMA
     * @param numaStartNode first node to use, or {@code null} for the current node
     * @throws IllegalArgumentException if {@code groupCount < 1}
     */
    public Apro(DataProvider provider, int groupCount, boolean useNuma, Integer numaNumNodes, Integer numaCoresPerNode, Integer numaStartNode) {
        if (groupCount < 1) {
            // Zero or negative group counts would lead to division by zero in the index helpers.
            throw new IllegalArgumentException("groupCount must be at least 1, got " + groupCount);
        }
        if (useNuma && !NUMA.isAvailable()) {
            Logger.warn("[NUMA] NUMA library not available. Switching off.", new Object[0]);
            this.useNuma = false;
        } else {
            this.useNuma = useNuma;
        }
        if (this.useNuma) {
            this.setNumaCoresPerNode(numaCoresPerNode);
            this.setNumaNumNodes(numaNumNodes);
            this.setNumaStartNode(numaStartNode);
            Logger.info("[NUMA] Number of nodes: %d, cores per node: %d, start node: %d", this.numaNumNodes, this.numaCoresPerNode, this.numaStartNode);
        } else {
            Logger.warn("[NUMA] NUMA is off", new Object[0]);
        }
        this.provider = provider;
        this.n = provider.size();
        // Clamp BEFORE assigning the final field. Previously only the parameter was clamped, so
        // n / groupCount became 0 and getGroupId() divided by zero.
        int effectiveGroupCount = groupCount;
        if (groupCount > this.n) {
            Logger.warn("Number of required threads is greater than the number of nodes. Setting groupCount to " + this.n, new Object[0]);
            effectiveGroupCount = this.n;
        }
        this.groupCount = effectiveGroupCount;
    }

    /**
     * Fetches S, allocates r/sumr, builds the groups and copies the non-zero entries of S into each
     * group's per-row sparse arrays (a counting pass followed by a fill pass).
     */
    private void init() {
        long startInit = System.currentTimeMillis();
        this.s = this.provider.getMatrix();
        this.r = new double[this.n][this.n];
        this.sumr = new double[this.n];
        // getGroup() indexes this list by group id, so groups left over from a previous run() must
        // not survive; otherwise repeated runs would mix stale and fresh groups.
        this.groups.clear();
        int minSize = Integer.MAX_VALUE;
        int maxSize = Integer.MIN_VALUE;
        try {
            for (int gid = 0; gid < this.groupCount; ++gid) {
                if (this.useNuma) {
                    NUMA.allocOnNode(this.getNumaNode(gid));
                }
                Group group = new Group(this, gid);
                this.groups.add(group);
                minSize = Math.min(minSize, group.size);
                maxSize = Math.max(maxSize, group.size);
            }
        } finally {
            // Always restore the local allocation policy, even if group construction fails.
            if (this.useNuma) {
                NUMA.localAlloc();
            }
        }
        IntArrayList is = new IntArrayList();
        IntArrayList ks = new IntArrayList();
        DoubleArrayList vs = new DoubleArrayList();
        this.getS().getNonZeros(is, ks, vs);
        // Pass 1: count the non-zero entries of each row.
        for (int j = 0; j < is.size(); ++j) {
            int row = is.get(j);
            Group group = this.getGroup(row);
            group.lengths[row - group.startIndex]++;
        }
        // Allocate per-row arrays sized exactly to those counts.
        try {
            for (Group group : this.groups) {
                if (this.useNuma) {
                    NUMA.allocOnNode(group.getNumaNode());
                }
                for (int i = 0; i < group.size; ++i) {
                    int len = group.lengths[i];
                    group.indices[i] = new int[len];
                    group.s[i] = new double[len];
                    group.a[i] = new double[len];
                }
            }
        } finally {
            if (this.useNuma) {
                NUMA.localAlloc();
            }
        }
        // Pass 2: fill column indices and values; group.t holds the next free slot per row.
        for (int j = 0; j < is.size(); ++j) {
            int row = is.get(j);
            Group group = this.getGroup(row);
            int local = row - group.startIndex;
            int next = group.t[local]++;
            group.indices[local][next] = ks.get(j);
            group.s[local][next] = vs.get(j);
        }
        long initTime = System.currentTimeMillis() - startInit;
        Logger.info("[Init] %d groups created (%d - %d elements each). Init time %d ms", this.groupCount, minSize, maxSize, initTime);
    }

    /** Enables or disables the progress dots printed by {@link #run(int)}. */
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /** Sets the damping factor (default 0.5). */
    public void setDamping(double dampingFactor) {
        this.damp = dampingFactor;
    }

    /** Returns the damping factor. */
    public double getDamping() {
        return this.damp;
    }

    private void setNumaStartNode(Integer node) {
        this.numaStartNode = node == null ? NUMA.getNode() : node;
    }

    private void setNumaCoresPerNode(Integer cpn) {
        this.numaCoresPerNode = cpn == null ? Integer.valueOf(NUMA.getCoresPerNode()) : cpn;
    }

    private void setNumaNumNodes(Integer nn) {
        this.numaNumNodes = nn == null ? Integer.valueOf(NUMA.getNumNodes()) : nn;
    }

    /**
     * Maps a group id to a NUMA node. Consecutive blocks of {@code numaCoresPerNode} groups share a
     * node, offset by {@code numaStartNode} and wrapped modulo {@code numaNumNodes}. Only valid when
     * NUMA is enabled (the NUMA settings are {@code null} otherwise).
     */
    Integer getNumaNode(int gid) {
        int node = gid / this.numaCoresPerNode;
        node = (node + this.numaStartNode) % this.numaNumNodes;
        return node;
    }

    /** First row (inclusive) owned by group {@code gid}. */
    int startIndex(int gid) {
        return this.n / this.groupCount * gid;
    }

    /** Last row (exclusive) owned by group {@code gid}; the last group absorbs the remainder. */
    int endIndex(int gid) {
        int groupSize = this.n / this.groupCount;
        int start = groupSize * gid;
        if (this.groupCount == gid + 1) {
            return this.n;
        }
        if (start + groupSize > this.n) {
            return this.n;
        }
        return start + groupSize;
    }

    /** Returns the id of the group owning row {@code nodeId}, consistent with {@link #startIndex(int)}. */
    public final int getGroupId(int nodeId) {
        int size = this.n / this.groupCount;
        int gid = nodeId / size;
        // Rows in the remainder belong to the last group.
        if (gid >= this.groupCount) {
            gid = this.groupCount - 1;
        }
        return gid;
    }

    /** Returns the group owning row {@code nodeId}. */
    public Group getGroup(int nodeId) {
        return this.groups.get(this.getGroupId(nodeId));
    }

    /**
     * Adds a group's {@code sums} into the shared {@link #sumr} accumulator. Synchronized because
     * (presumably) several workers call this concurrently.
     */
    synchronized void updateSums(Group group) {
        for (int k = 0; k < this.n; ++k) {
            this.sumr[k] += group.sums[k];
        }
    }

    /**
     * Initializes the data, starts one worker per group, runs {@code iters} iterations, then stops
     * and joins the workers. The elapsed time (excluding init) is available from
     * {@link #getRunTime()}.
     *
     * <p>If the calling thread is interrupted while joining the workers, the remaining workers are
     * still signalled, and the interrupt status is restored before returning.
     */
    public void run(int iters) {
        this.init();
        long startTime = System.currentTimeMillis();
        ArrayList<GroupWorker> workers = new ArrayList<GroupWorker>();
        for (Group group : this.groups) {
            GroupWorker worker = new GroupWorker(group, this.useNuma);
            worker.start();
            workers.add(worker);
        }
        Logger.info("[Apro] Working...", new Object[0]);
        // About 10 progress dots in total; at least 1 so that iters < 10 does not divide by zero.
        int progressStep = Math.max(1, iters / 10);
        for (int i = 0; i < iters; ++i) {
            if (this.debug && (i + 1) % progressStep == 0) {
                System.out.print('.');
            }
            Arrays.fill(this.sumr, 0.0);
            // Each phase is dispatched to all workers and then fully awaited before the next one.
            for (GroupWorker worker : workers) {
                worker.responsibilities();
            }
            for (GroupWorker worker : workers) {
                worker.waitTask();
            }
            for (GroupWorker worker : workers) {
                worker.availabilities();
            }
            for (GroupWorker worker : workers) {
                worker.waitTask();
            }
        }
        if (this.debug) {
            System.out.println();
        }
        boolean interrupted = false;
        for (GroupWorker groupWorker : workers) {
            groupWorker.done();
            try {
                groupWorker.join();
            }
            catch (InterruptedException e) {
                // Keep signalling the remaining workers; restore the interrupt status afterwards.
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        long endTime = System.currentTimeMillis();
        this.runTime = endTime - startTime;
        Logger.info("[Result] Groups: %d; Time: %d ms", this.groupCount, this.runTime);
    }

    /** Returns the duration of the last {@link #run(int)} in ms, or {@code null} if never run. */
    public Long getRunTime() {
        return this.runTime;
    }

    @Override
    public int getN() {
        return this.n;
    }

    /** Returns the responsibility matrix {@code r} as a colt dense matrix. */
    @Override
    public DoubleMatrix2D getR() {
        return new DenseDoubleMatrix2D(this.r);
    }

    @Override
    public DoubleMatrix2D getS() {
        return this.s;
    }

    /** Assembles a sparse n x n matrix from every group's per-row {@code a} values. */
    @Override
    public DoubleMatrix2D getA() {
        SparseDoubleMatrix2D a = new SparseDoubleMatrix2D(this.n, this.n);
        for (Group group : this.groups) {
            for (int i = 0; i < group.size; ++i) {
                int len = group.lengths[i];
                for (int j = 0; j < len; ++j) {
                    int k = group.indices[i][j];
                    a.setQuick(i + group.startIndex, k, group.a[i][j]);
                }
            }
        }
        return a;
    }

    /** Debug helper: prints how many diagonal entries of {@code getAR()} are strictly positive. */
    private void identifyExemplars() {
        DoubleMatrix2D ar = this.getAR();
        int c = 0;
        for (int i = 0; i < this.n; ++i) {
            if (ar.get(i, i) > 0.0) {
                ++c;
            }
        }
        System.out.println("AR diagonal > 0: " + c);
    }
}

