package us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.groundContactForce;

import java.util.List;
import java.util.Map;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import us.ihmc.commonWalkingControlModules.controllerCore.command.inverseDynamics.MomentumRateCommand;
import us.ihmc.commonWalkingControlModules.controllerCore.command.inverseDynamics.PlaneContactStateCommand;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.ControllerCoreOptimizationSettings;
import us.ihmc.commonWalkingControlModules.visualizer.BasisVectorVisualizer;
import us.ihmc.commonWalkingControlModules.wrenchDistribution.WrenchMatrixCalculator;
import us.ihmc.commons.PrintTools;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.graphicsDescription.yoGraphics.YoGraphicsListRegistry;
import us.ihmc.humanoidRobotics.bipedSupportPolygons.ContactablePlaneBody;
import us.ihmc.robotics.math.frames.YoFrameVector;
import us.ihmc.robotics.math.frames.YoMatrix;
import us.ihmc.robotics.screwTheory.RigidBody;
import us.ihmc.robotics.screwTheory.Wrench;
import us.ihmc.tools.exceptions.NoConvergenceException;
import us.ihmc.yoVariables.registry.YoVariableRegistry;
import us.ihmc.yoVariables.variable.YoBoolean;
import us.ihmc.yoVariables.variable.YoDouble;
import us.ihmc.yoVariables.variable.YoInteger;

/* loaded from: input_file:us/ihmc/commonWalkingControlModules/momentumBasedController/optimization/groundContactForce/GroundContactForceOptimizationControlModule.class */
/**
 * Computes contact wrenches for a set of contactable plane bodies by solving a QP
 * ({@link GroundContactForceQPSolver}) over the contact basis-vector magnitudes ("rhos"). The QP
 * combines rho regularization, rho-rate / CoP tasks supplied by the {@link WrenchMatrixCalculator},
 * and a selected momentum-rate task built in {@link #processMomentumRateCommand(DenseMatrix64F)}.
 * The rho solution is converted back into one {@link Wrench} per rigid body.
 */
public class GroundContactForceOptimizationControlModule {
    private static final boolean DEBUG = true;
    private static final boolean VISUALIZE_RHO_BASIS_VECTORS = false;
    private static final boolean SETUP_RHO_TASKS = true;
    private final WrenchMatrixCalculator wrenchMatrixCalculator;
    private final List<? extends ContactablePlaneBody> contactablePlaneBodies;
    private final BasisVectorVisualizer basisVectorVisualizer;
    private final GroundContactForceQPSolver qpSolver;
    private final YoFrameVector desiredLinearMomentumRate;
    private final YoFrameVector desiredAngularMomentumRate;
    private Map<RigidBody, Wrench> solutionWrenches;
    private final YoVariableRegistry registry = new YoVariableRegistry(getClass().getSimpleName());
    private final YoDouble rhoMin = new YoDouble("rhoMinGCFOptimization", this.registry);
    private final YoBoolean hasNotConvergedInPast = new YoBoolean("hasNotConvergedInPast", this.registry);
    private final YoInteger hasNotConvergedCounts = new YoInteger("hasNotConvergedCounts", this.registry);
    private final MomentumRateCommand momentumRateCommand = new MomentumRateCommand();
    private final YoMatrix yoMomentumSelectionMatrix = new YoMatrix("VMCMomentumSelectionMatrix", 6, 6, this.registry);
    private final YoMatrix yoMomentumObjective = new YoMatrix("VMCMomentumObjectiveMatrix", 6, 1, this.registry);
    private final YoMatrix yoMomentumWeight = new YoMatrix("VMCMomentumWeightMatrix", 6, 6, this.registry);
    // The following matrices are reshaped on use; their initial sizes are placeholders.
    private final DenseMatrix64F momentumSelectionMatrix = new DenseMatrix64F(6, 1);
    private final DenseMatrix64F momentumObjective = new DenseMatrix64F(6, 1);
    private final DenseMatrix64F momentumJacobian = new DenseMatrix64F(6, 1);
    private final DenseMatrix64F momentumWeight = new DenseMatrix64F(6, 1);
    private final DenseMatrix64F tempTaskWeight = new DenseMatrix64F(6, 6);
    private final DenseMatrix64F tempTaskWeightSubspace = new DenseMatrix64F(6, 6);
    private final DenseMatrix64F fullMomentumObjective = new DenseMatrix64F(6, 1);

    /**
     * Creates the module and attaches its YoVariable registry as a child of {@code yoVariableRegistry}.
     *
     * @param wrenchMatrixCalculator supplies the rho jacobians, task matrices, and the rho-to-wrench conversion
     * @param list the contactable plane bodies whose wrenches are written by {@link #compute(Map)}
     * @param controllerCoreOptimizationSettings provides the rho size, minimum rho, and the active-set QP solver
     * @param yoVariableRegistry parent registry this module's registry is added to
     * @param yoGraphicsListRegistry currently unused (rho basis-vector visualization is disabled)
     */
    public GroundContactForceOptimizationControlModule(WrenchMatrixCalculator wrenchMatrixCalculator, List<? extends ContactablePlaneBody> list, ControllerCoreOptimizationSettings controllerCoreOptimizationSettings, YoVariableRegistry yoVariableRegistry, YoGraphicsListRegistry yoGraphicsListRegistry) {
        this.wrenchMatrixCalculator = wrenchMatrixCalculator;
        this.contactablePlaneBodies = list;
        int rhoSize = controllerCoreOptimizationSettings.getRhoSize();
        // Visualization of the rho basis vectors is disabled (VISUALIZE_RHO_BASIS_VECTORS == false).
        this.basisVectorVisualizer = null;
        this.desiredLinearMomentumRate = new YoFrameVector("desiredLinearMomentumRateToQP", (ReferenceFrame) null, this.registry);
        this.desiredAngularMomentumRate = new YoFrameVector("desiredAngularMomentumRateToQP", (ReferenceFrame) null, this.registry);
        this.rhoMin.set(controllerCoreOptimizationSettings.getRhoMin());
        this.qpSolver = new GroundContactForceQPSolver(controllerCoreOptimizationSettings.getActiveSetQPSolver(), rhoSize, this.registry);
        this.qpSolver.setMinRho(controllerCoreOptimizationSettings.getRhoMin());
        yoVariableRegistry.addChild(this.registry);
    }

    /** Resets the underlying QP solver. */
    public void initialize() {
        this.qpSolver.reset();
    }

    /**
     * Builds and solves the QP, then writes the resulting wrench of every contactable plane body into
     * {@code map}. An existing entry is updated in place via {@link Wrench#set}; a missing entry is
     * added with the solution wrench object itself.
     *
     * <p>Call {@link #processMomentumRateCommand(DenseMatrix64F)} beforehand so the momentum task
     * matrices are up to date.
     *
     * @param map output map from rigid body to contact wrench
     * @throws NoConvergenceException if the QP solver fails to converge; {@code map} is left untouched
     */
    public void compute(Map<RigidBody, Wrench> map) throws NoConvergenceException {
        this.qpSolver.setRhoRegularizationWeight(this.wrenchMatrixCalculator.getRhoWeightMatrix());
        this.qpSolver.addRegularization();
        if (SETUP_RHO_TASKS) {
            setupRhoTasks();
        }
        this.qpSolver.setMinRho(this.rhoMin.getDoubleValue());
        this.qpSolver.addMomentumTask(this.momentumJacobian, this.momentumObjective, this.momentumWeight);

        // Hold on to a solver failure so it is logged and counted before being rethrown.
        NoConvergenceException noConvergenceException = null;
        try {
            this.qpSolver.solve();
        } catch (NoConvergenceException e) {
            if (!this.hasNotConvergedInPast.getBooleanValue()) {
                // Only the first failure prints a stack trace to avoid flooding the console.
                e.printStackTrace();
                PrintTools.warn(this, "Only showing the stack trace of the first " + e.getClass().getSimpleName() + ". This may be happening more than once.");
            }
            this.hasNotConvergedInPast.set(true);
            this.hasNotConvergedCounts.increment();
            noConvergenceException = e;
        }
        DenseMatrix64F rhos = this.qpSolver.getRhos();
        if (noConvergenceException != null) {
            throw noConvergenceException;
        }

        this.solutionWrenches = this.wrenchMatrixCalculator.computeWrenchesFromRho(rhos);
        for (int i = 0; i < this.contactablePlaneBodies.size(); i++) {
            RigidBody rigidBody = this.contactablePlaneBodies.get(i).getRigidBody();
            Wrench wrench = this.solutionWrenches.get(rigidBody);
            if (map.containsKey(rigidBody)) {
                map.get(rigidBody).set(wrench);
            } else {
                map.put(rigidBody, wrench);
            }
        }
    }

    /**
     * Adds the rho-rate task, the CoP-rate task, and the desired-CoP task (all supplied by the wrench
     * matrix calculator) to the QP.
     */
    private void setupRhoTasks() {
        this.qpSolver.addRhoTask(this.wrenchMatrixCalculator.getRhoPreviousMatrix(), this.wrenchMatrixCalculator.getRhoRateWeightMatrix());
        DenseMatrix64F copJacobianMatrix = this.wrenchMatrixCalculator.getCopJacobianMatrix();
        this.qpSolver.addRhoTask(copJacobianMatrix, this.wrenchMatrixCalculator.getPreviousCoPMatrix(), this.wrenchMatrixCalculator.getCopRateWeightMatrix());
        this.qpSolver.addRhoTask(copJacobianMatrix, this.wrenchMatrixCalculator.getDesiredCoPMatrix(), this.wrenchMatrixCalculator.getDesiredCoPWeightMatrix());
    }

    /**
     * Stores a copy of the given momentum rate command, including its weight vector.
     *
     * @param momentumRateCommand the command to copy
     */
    public void submitMomentumRateCommand(MomentumRateCommand momentumRateCommand) {
        this.momentumRateCommand.set(momentumRateCommand);
        this.momentumRateCommand.setWeights(momentumRateCommand.getWeightVector());
    }

    /**
     * Forwards a plane contact state command to the wrench matrix calculator.
     *
     * @param planeContactStateCommand the contact state command to forward
     */
    public void submitPlaneContactStateCommand(PlaneContactStateCommand planeContactStateCommand) {
        this.wrenchMatrixCalculator.submitPlaneContactStateCommand(planeContactStateCommand);
    }

    /**
     * Copies the momentum selection matrix into the internal matrix and into its YoMatrix mirror.
     *
     * @param denseMatrix64F selection matrix whose rows pick the controlled momentum components
     *        (presumably k x 6 with k <= 6 — confirm with caller)
     */
    public void submitMomentumSelectionMatrix(DenseMatrix64F denseMatrix64F) {
        this.momentumSelectionMatrix.set(denseMatrix64F);
        this.yoMomentumSelectionMatrix.set(denseMatrix64F);
    }

    /**
     * Recomputes the wrench matrices and builds the selected momentum task used by {@link #compute(Map)}.
     * With selection matrix S, command weight W, and rho jacobian J this sets:
     * <ul>
     *   <li>weight = S W S<sup>T</sup></li>
     *   <li>jacobian = S J</li>
     *   <li>objective = S (commandedMomentumRate - {@code denseMatrix64F})</li>
     * </ul>
     * It also stores S<sup>T</sup> objective in {@link #getMomentumObjective()}. Rows 0-2 of that
     * vector are published as the desired angular momentum rate and rows 3-5 as the desired linear
     * momentum rate.
     *
     * @param denseMatrix64F vector subtracted from the commanded momentum rate before selection
     *        (presumably the momentum-rate contribution of additional external wrenches — confirm with caller)
     */
    public void processMomentumRateCommand(DenseMatrix64F denseMatrix64F) {
        this.wrenchMatrixCalculator.computeMatrices();
        int numRows = this.momentumSelectionMatrix.getNumRows();
        this.momentumObjective.reshape(numRows, 1);
        this.momentumWeight.reshape(numRows, numRows);
        // Reshape the jacobian before the early return so that compute() never pairs a stale jacobian
        // with a freshly reshaped objective/weight when nothing is selected.
        DenseMatrix64F rhoJacobianMatrix = this.wrenchMatrixCalculator.getRhoJacobianMatrix();
        this.momentumJacobian.reshape(numRows, rhoJacobianMatrix.numCols);
        if (numRows == 0) {
            return;
        }

        this.tempTaskWeight.reshape(6, 6);
        this.tempTaskWeightSubspace.reshape(numRows, 6);
        this.momentumRateCommand.getWeightMatrix(this.tempTaskWeight);
        CommonOps.mult(this.momentumSelectionMatrix, this.tempTaskWeight, this.tempTaskWeightSubspace);
        CommonOps.multTransB(this.tempTaskWeightSubspace, this.momentumSelectionMatrix, this.momentumWeight);
        this.yoMomentumWeight.set(this.momentumWeight);

        CommonOps.mult(this.momentumSelectionMatrix, rhoJacobianMatrix, this.momentumJacobian);

        CommonOps.subtract(this.momentumRateCommand.getMomentumRate(), denseMatrix64F, this.fullMomentumObjective);
        CommonOps.mult(this.momentumSelectionMatrix, this.fullMomentumObjective, this.momentumObjective);
        this.yoMomentumObjective.set(this.momentumObjective);

        // Project the selected objective back to full size for logging/reporting.
        CommonOps.multTransA(this.momentumSelectionMatrix, this.momentumObjective, this.fullMomentumObjective);
        this.desiredLinearMomentumRate.set(this.fullMomentumObjective.get(3), this.fullMomentumObjective.get(4), this.fullMomentumObjective.get(5));
        this.desiredAngularMomentumRate.set(this.fullMomentumObjective.get(0), this.fullMomentumObjective.get(1), this.fullMomentumObjective.get(2));
    }

    /**
     * Returns the internal (not copied) momentum objective last computed by
     * {@link #processMomentumRateCommand(DenseMatrix64F)}.
     *
     * @return the internal momentum objective matrix
     */
    public DenseMatrix64F getMomentumObjective() {
        return this.fullMomentumObjective;
    }
}
