/*
 * Decompiled with CFR 0.152.
 */
package org.renjin.primitives.matrix;

import com.github.fommil.netlib.BLAS;
import org.renjin.eval.EvalException;
import org.renjin.sexp.AtomicVector;
import org.renjin.sexp.AttributeMap;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.DoubleVector;
import org.renjin.sexp.ListVector;
import org.renjin.sexp.Null;
import org.renjin.sexp.SEXP;
import org.renjin.sexp.Symbols;
import org.renjin.sexp.Vector;

/**
 * Dense matrix products backing the R operations {@code x %*% y}, {@code crossprod(x, y)}
 * ({@code t(x) %*% y}) and {@code tcrossprod(x, y)} ({@code x %*% t(y)}), computed with
 * netlib-java {@link BLAS}.
 *
 * <p>The constructor resolves the effective shape of both operands. Operands whose
 * {@code dim} attribute does not have length 2 are treated as row or column vectors,
 * whichever makes them conformable. It throws {@link EvalException} when the shapes do
 * not conform. One of {@link #matprod()}, {@link #crossprod()} or {@link #tcrossprod()}
 * then computes the result as a double matrix, carrying over the applicable dimnames.
 */
class MatrixProduct {
    /** Operation code for {@code x %*% y}. */
    public static final int PROD = 0;
    /** Operation code for {@code crossprod(x, y)}, i.e. {@code t(x) %*% y}. */
    public static final int CROSSPROD = 1;
    /** Operation code for {@code tcrossprod(x, y)}, i.e. {@code x %*% t(y)}. */
    public static final int TCROSSPROD = 2;

    /** First operand. */
    private final AtomicVector x;
    /** Second operand; this is {@link #x} itself for one-argument cross products. */
    private final AtomicVector y;
    /** Effective row/column counts of x and y after vector promotion. */
    private int nrx = 0;
    private int ncx = 0;
    private int nry = 0;
    private int ncy = 0;
    /** One of {@link #PROD}, {@link #CROSSPROD} or {@link #TCROSSPROD}. */
    private final int primop;
    /** The {@code dim} attributes of x and y. */
    private AtomicVector xdims;
    private AtomicVector ydims;
    /** Lengths of the {@code dim} attributes; 2 means the operand is a matrix. */
    private int ldx;
    private int ldy;
    /** True when y was passed as {@code NULL}; cross products then use {@code dsyrk}. */
    private final boolean sym;
    /** Dimnames of the result; reset at the start of every public product method. */
    private ListVector.Builder dimnames = new ListVector.Builder(2);

    /**
     * Prepares a product of {@code x} and {@code y} and checks that they conform.
     *
     * @param primop one of {@link #PROD}, {@link #CROSSPROD} or {@link #TCROSSPROD}
     * @param x the first operand
     * @param y the second operand, or {@code Null.INSTANCE}; for {@link #CROSSPROD} and
     *     {@link #TCROSSPROD} a {@code NULL} y means "use x"
     * @throws EvalException with "non-conformable arguments" if the shapes do not conform
     */
    public MatrixProduct(int primop, AtomicVector x, AtomicVector y) {
        this.primop = primop;
        this.x = x;
        this.sym = y == Null.INSTANCE;
        // crossprod(x) / tcrossprod(x): the second operand is x itself.
        this.y = (this.sym && primop > PROD) ? x : y;
        this.computeMatrixDims();
    }

    /**
     * Derives nrx/ncx/nry/ncy from the operands' {@code dim} attributes, promoting
     * non-matrix operands to row or column vectors, then validates conformability.
     */
    private void computeMatrixDims() {
        this.xdims = (AtomicVector) this.x.getAttribute(Symbols.DIM);
        this.ydims = (AtomicVector) this.y.getAttribute(Symbols.DIM);
        this.ldx = this.xdims.length();
        this.ldy = this.ydims.length();

        if (this.ldx != 2 && this.ldy != 2) {
            // Neither operand is a matrix.
            if (this.primop == PROD) {
                // x as a row vector
                this.nrx = 1;
                this.ncx = this.x.length();
            } else {
                // x as a column vector
                this.nrx = this.x.length();
                this.ncx = 1;
            }
            // y as a column vector
            this.nry = this.y.length();
            this.ncy = 1;
        } else if (this.ldx != 2) {
            // Only y is a matrix: choose an orientation for x that conforms, if any.
            this.nry = this.ydims.getElementAsInt(0);
            this.ncy = this.ydims.getElementAsInt(1);
            this.nrx = 0;
            this.ncx = 0;
            int lenx = this.x.length();
            if (this.primop == PROD) {
                if (lenx == this.nry) {
                    this.nrx = 1;
                    this.ncx = this.nry;
                } else if (this.nry == 1) {
                    this.nrx = lenx;
                    this.ncx = 1;
                }
            } else if (this.primop == CROSSPROD) {
                if (lenx == this.nry) {
                    this.nrx = this.nry;
                    this.ncx = 1;
                }
            } else if (lenx == this.ncy) {
                this.nrx = 1;
                this.ncx = this.ncy;
            } else if (this.ncy == 1) {
                this.nrx = lenx;
                this.ncx = 1;
            }
        } else if (this.ldy != 2) {
            // Only x is a matrix: choose an orientation for y that conforms, if any.
            this.nrx = this.xdims.getElementAsInt(0);
            this.ncx = this.xdims.getElementAsInt(1);
            this.nry = 0;
            this.ncy = 0;
            int leny = this.y.length();
            if (this.primop == PROD) {
                if (leny == this.ncx) {
                    this.nry = this.ncx;
                    this.ncy = 1;
                } else if (this.ncx == 1) {
                    this.nry = 1;
                    this.ncy = leny;
                }
            } else if (this.primop == CROSSPROD) {
                if (leny == this.nrx) {
                    this.nry = this.nrx;
                    this.ncy = 1;
                }
            } else {
                this.nry = leny;
                this.ncy = 1;
            }
        } else {
            // Both operands are matrices.
            this.nrx = this.xdims.getElementAsInt(0);
            this.ncx = this.xdims.getElementAsInt(1);
            this.nry = this.ydims.getElementAsInt(0);
            this.ncy = this.ydims.getElementAsInt(1);
        }

        if ((this.primop == PROD && this.ncx != this.nry)
                || (this.primop == CROSSPROD && this.nrx != this.nry)
                || (this.primop == TCROSSPROD && this.ncx != this.ncy)) {
            throw new EvalException("non-conformable arguments", new Object[0]);
        }
    }

    /**
     * Computes {@code x %*% y}.
     *
     * <p>Row names come from x (when x is a matrix, or was promoted to a single column).
     * Column names come from y (see {@link #ydimsEtcetera()}).
     *
     * @return an {@code nrx x ncy} double matrix
     */
    public Vector matprod() {
        this.dimnames = new ListVector.Builder(2);
        double[] ans = new double[this.nrx * this.ncy];
        this.matprod(this.getXArray(), this.nrx, this.ncx, this.getYArray(), this.nry, this.ncy, ans);
        Vector xdimnames = dimnamesOf(this.x);
        if (xdimnames != Null.INSTANCE && (this.ldx == 2 || this.ncx == 1)) {
            this.dimnames.set(0, (SEXP) xdimnames.getElementAsSEXP(0));
        }
        this.ydimsEtcetera();
        return this.makeMatrix(ans, this.nrx, this.ncy);
    }

    /** Wraps column-major {@code values} as an {@code nr x nc} matrix with the built dimnames. */
    private DoubleVector makeMatrix(double[] values, int nr, int nc) {
        AttributeMap.Builder attributes = AttributeMap.builder();
        attributes.setDim(nr, nc);
        attributes.set(Symbols.DIMNAMES, (SEXP) this.buildDimnames());
        return new DoubleArrayVector(values, attributes.build());
    }

    /** Returns the collected dimnames, or {@code NULL} when both entries are {@code NULL}. */
    private Vector buildDimnames() {
        ListVector names = this.dimnames.build();
        if (names.getElementAsSEXP(0) != Null.INSTANCE || names.getElementAsSEXP(1) != Null.INSTANCE) {
            return names;
        }
        return Null.INSTANCE;
    }

    /**
     * Computes {@code crossprod(x, y)}, i.e. {@code t(x) %*% y}. When y was {@code NULL},
     * the symmetric routine {@code dsyrk} is used instead.
     *
     * <p>The result's rows correspond to the columns of x and its columns to the columns
     * of y. The column names of matrix operands are therefore carried over.
     *
     * @return an {@code ncx x ncy} double matrix
     */
    public DoubleVector crossprod() {
        this.dimnames = new ListVector.Builder(2);
        double[] ans = new double[this.ncx * this.ncy];
        if (this.sym) {
            this.symcrossprod(this.getXArray(), this.nrx, this.ncx, ans);
        } else {
            this.crossprod(this.getXArray(), this.nrx, this.ncx, this.getYArray(), this.nry, this.ncy, ans);
        }
        Vector xdimnames = dimnamesOf(this.x);
        if (xdimnames != Null.INSTANCE && this.ldx == 2) {
            this.dimnames.set(0, (SEXP) xdimnames.getElementAsSEXP(1));
        }
        Vector ydimnames = dimnamesOf(this.y);
        if (ydimnames != Null.INSTANCE && this.ldy == 2) {
            this.dimnames.set(1, (SEXP) ydimnames.getElementAsSEXP(1));
        }
        return this.makeMatrix(ans, this.ncx, this.ncy);
    }

    /**
     * Sets the result's column names from y. A matrix y supplies its column names. A
     * non-matrix y that was promoted to a single row supplies its first dimnames entry.
     */
    private void ydimsEtcetera() {
        Vector ydimnames = dimnamesOf(this.y);
        if (ydimnames != Null.INSTANCE) {
            if (this.ldy == 2) {
                this.dimnames.set(1, (SEXP) ydimnames.getElementAsSEXP(1));
            } else if (this.nry == 1) {
                this.dimnames.set(1, (SEXP) ydimnames.getElementAsSEXP(0));
            }
        }
    }

    /**
     * Computes {@code tcrossprod(x, y)}, i.e. {@code x %*% t(y)}. When y was {@code NULL},
     * the symmetric routine {@code dsyrk} is used instead.
     *
     * <p>The result's rows correspond to the rows of x and its columns to the rows of y.
     * The row names of matrix operands are therefore carried over.
     *
     * @return an {@code nrx x nry} double matrix
     */
    public DoubleVector tcrossprod() {
        this.dimnames = new ListVector.Builder(2);
        double[] ans = new double[this.nrx * this.nry];
        if (this.sym) {
            this.symtcrossprod(this.getXArray(), this.nrx, this.ncx, ans);
        } else {
            this.tcrossprod(this.getXArray(), this.nrx, this.ncx, this.getYArray(), this.nry, this.ncy, ans);
        }
        Vector xdimnames = dimnamesOf(this.x);
        if (xdimnames != Null.INSTANCE && this.ldx == 2) {
            this.dimnames.set(0, (SEXP) xdimnames.getElementAsSEXP(0));
        }
        Vector ydimnames = dimnamesOf(this.y);
        if (ydimnames != Null.INSTANCE && this.ldy == 2) {
            this.dimnames.set(1, (SEXP) ydimnames.getElementAsSEXP(0));
        }
        return this.makeMatrix(ans, this.nrx, this.nry);
    }

    /** Returns the {@code dimnames} attribute of {@code v}. */
    private static Vector dimnamesOf(AtomicVector v) {
        return (Vector) v.getAttribute(Symbols.DIMNAMES);
    }

    /**
     * Writes {@code t(x) %*% x} (nc x nc) into {@code z}, where {@code x} is an
     * {@code nr x nc} column-major matrix. Uses {@code dsyrk} to fill the upper triangle
     * and then mirrors it into the lower triangle.
     */
    private void symcrossprod(double[] x, int nr, int nc, double[] z) {
        if (nr > 0 && nc > 0) {
            BLAS.getInstance().dsyrk("U", "T", nc, nr, 1.0, x, nr, 0.0, z, nc);
            for (int i = 1; i < nc; ++i) {
                for (int j = 0; j < i; ++j) {
                    z[i + nc * j] = z[j + nc * i];
                }
            }
        } else {
            zeroFill(z, nc * nc);
        }
    }

    private double[] getXArray() {
        return this.x.toDoubleArray();
    }

    private double[] getYArray() {
        return this.y.toDoubleArray();
    }

    /**
     * Writes {@code x %*% y} (nrx x ncy, column-major) into {@code z}.
     *
     * <p>If either operand contains NaN (which includes NA), a plain triple loop is used
     * instead of {@code dgemm}. Presumably this is because optimized BLAS kernels are not
     * trusted to propagate NaN/NA; confirm before removing.
     */
    private void matprod(double[] x, int nrx, int ncx, double[] y, int nry, int ncy, double[] z) {
        if (nrx > 0 && ncx > 0 && nry > 0 && ncy > 0) {
            if (containsNaN(x, nrx * ncx) || containsNaN(y, nry * ncy)) {
                for (int i = 0; i < nrx; ++i) {
                    for (int k = 0; k < ncy; ++k) {
                        double sum = 0.0;
                        for (int j = 0; j < ncx; ++j) {
                            sum += x[i + j * nrx] * y[j + k * nry];
                        }
                        z[i + k * nrx] = sum;
                    }
                }
            } else {
                BLAS.getInstance().dgemm("N", "N", nrx, ncy, ncx, 1.0, x, nrx, y, nry, 0.0, z, nrx);
            }
        } else {
            zeroFill(z, nrx * ncy);
        }
    }

    /** Returns true if any of the first {@code n} entries of {@code values} is NaN. */
    private static boolean containsNaN(double[] values, int n) {
        for (int i = 0; i < n; ++i) {
            if (Double.isNaN(values[i])) {
                return true;
            }
        }
        return false;
    }

    /** Sets the first {@code n} entries of {@code z} to 0.0. */
    private static void zeroFill(double[] z, int n) {
        for (int i = 0; i < n; ++i) {
            z[i] = 0.0;
        }
    }

    /**
     * Writes {@code x %*% t(x)} (nr x nr) into {@code z}, where {@code x} is an
     * {@code nr x nc} column-major matrix. Uses {@code dsyrk} to fill the upper triangle
     * and then mirrors it into the lower triangle.
     */
    private void symtcrossprod(double[] x, int nr, int nc, double[] z) {
        if (nr > 0 && nc > 0) {
            BLAS.getInstance().dsyrk("U", "N", nr, nc, 1.0, x, nr, 0.0, z, nr);
            for (int i = 1; i < nr; ++i) {
                for (int j = 0; j < i; ++j) {
                    z[i + nr * j] = z[j + nr * i];
                }
            }
        } else {
            zeroFill(z, nr * nr);
        }
    }

    /** Writes {@code x %*% t(y)} (nrx x nry, column-major) into {@code z} via {@code dgemm}. */
    private void tcrossprod(double[] x, int nrx, int ncx, double[] y, int nry, int ncy, double[] z) {
        if (nrx > 0 && ncx > 0 && nry > 0 && ncy > 0) {
            BLAS.getInstance().dgemm("N", "T", nrx, nry, ncx, 1.0, x, nrx, y, nry, 0.0, z, nrx);
        } else {
            zeroFill(z, nrx * nry);
        }
    }

    /** Writes {@code t(x) %*% y} (ncx x ncy, column-major) into {@code z} via {@code dgemm}. */
    private void crossprod(double[] x, int nrx, int ncx, double[] y, int nry, int ncy, double[] z) {
        if (nrx > 0 && ncx > 0 && nry > 0 && ncy > 0) {
            BLAS.getInstance().dgemm("T", "N", ncx, ncy, nrx, 1.0, x, nrx, y, nry, 0.0, z, ncx);
        } else {
            zeroFill(z, ncx * ncy);
        }
    }
}

