package edu.cmu.ml.rtw.pra.features;

import com.google.common.annotations.VisibleForTesting;
import edu.cmu.ml.rtw.users.matt.util.Dictionary;
import edu.cmu.ml.rtw.users.matt.util.Vector;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/features/VectorPathTypeFactory.class */
/**
 * A {@link BaseEdgeSequencePathTypeFactory} whose path types match edges by embedding similarity instead of
 * requiring the exact edge type at every hop.
 *
 * <p>Each edge-type index may have a {@link Vector} embedding. At a hop, every edge type {@code c} present at
 * the current vertex (in the hop's direction) gets the weight {@code exp(spikiness * dot(e_path, e_c))}.
 * Here {@code e_path} is the embedding of the edge type the path specifies for that hop. Candidates without an
 * embedding get weight 0. A constant reset weight {@code exp(spikiness * resetScore)} is added to the sum of
 * the candidate weights. When a draw lands in that extra mass, {@code getNextEdgeType} returns {@code -1}.
 * NOTE(review): -1 is presumably treated as "stop/reset" by {@link BaseEdgeSequencePathType}; confirm there.
 */
public class VectorPathTypeFactory extends BaseEdgeSequencePathTypeFactory {
    // Shared zero-hop path type. It is built with the private no-arg constructor, which reads none of the
    // fields below, so it is safe to create before the constructor body runs.
    private final VectorPathType emptyPathType = new VectorPathType();

    private final Dictionary edgeDict;
    private final Map<Integer, Vector> embeddingsMap;
    private final double spikiness;
    // Stored already exponentiated: exp(spikiness * resetScore).
    private final double resetWeight;
    // Indexed by edge type. An entry is null when that edge type has no embedding.
    // initializeEmbeddings() replaces this table.
    private Vector[] embeddings;

    /**
     * A path type that, at each hop, samples an edge type from those present at the current vertex.
     * Each candidate is weighted by the similarity of its embedding to the embedding of the edge type the
     * path specifies for that hop.
     */
    @VisibleForTesting
    protected class VectorPathType extends BaseEdgeSequencePathType {
        // These values are snapshotted from the factory at construction time. initializeEmbeddings() installs
        // a brand-new table, so a later call does not change path types that already exist.
        private final double pathSpikiness;
        private final double pathResetWeight;
        private final Vector[] pathEmbeddings;

        /**
         * Creates a path type over the given edge sequence, using the factory's current sampling parameters.
         *
         * @param edgeTypes edge-type index for each hop
         * @param reverse per hop, {@code true} to look at the vertex's in-edge types, {@code false} for out-edge types
         */
        public VectorPathType(int[] edgeTypes, boolean[] reverse) {
            super(edgeTypes, reverse);
            VectorPathTypeFactory owner = VectorPathTypeFactory.this;
            this.pathSpikiness = owner.spikiness;
            this.pathResetWeight = owner.resetWeight;
            this.pathEmbeddings = owner.embeddings;
        }

        /** Builds the zero-hop path type. Its sampling parameters are left at zero / {@code null}. */
        private VectorPathType() {
            super(new int[0], new boolean[0]);
            this.pathSpikiness = 0.0;
            this.pathResetWeight = 0.0;
            this.pathEmbeddings = null;
        }

        /**
         * Precomputes the edge-type distribution used at hop {@code hopNum} when standing on {@code vertex}.
         *
         * @return {@code null} once {@code hopNum} reaches the number of hops. Otherwise returns a delta cache on
         *     the hop's edge type when that type has no embedding, or when it is the only edge type available at
         *     the vertex. In all other cases returns a weighted cache whose total includes the reset weight.
         */
        @Override
        public PathTypeVertexCache cacheVertexInformation(Vertex vertex, int hopNum) {
            if (hopNum >= this.numHops) {
                return null;
            }
            final int pathEdgeType = this.edgeTypes[hopNum];
            final Vector pathEdgeEmbedding = this.pathEmbeddings[pathEdgeType];
            if (pathEdgeEmbedding == null) {
                // Nothing to compare against: follow the exact edge type.
                return new VectorPathTypeVertexCache(pathEdgeType);
            }

            final int[] candidateEdgeTypes = this.reverse[hopNum] ? vertex.getInEdgeTypes() : vertex.getOutEdgeTypes();
            final boolean onlyExactMatch = candidateEdgeTypes.length == 1 && candidateEdgeTypes[0] == pathEdgeType;
            if (onlyExactMatch) {
                return new VectorPathTypeVertexCache(pathEdgeType);
            }

            final double[] weights = new double[candidateEdgeTypes.length];
            double candidateTotal = 0.0d;
            int slot = 0;
            for (int candidate : candidateEdgeTypes) {
                Vector candidateEmbedding = this.pathEmbeddings[candidate];
                double weight = (candidateEmbedding == null)
                        ? 0.0d
                        : Math.exp(this.pathSpikiness * pathEdgeEmbedding.dotProduct(candidateEmbedding));
                weights[slot++] = weight;
                candidateTotal += weight;
            }
            // The reset weight is extra probability mass that belongs to no candidate.
            return new VectorPathTypeVertexCache(candidateEdgeTypes, weights, candidateTotal + this.pathResetWeight);
        }

        /**
         * Draws the next edge type from the distribution precomputed by {@link #cacheVertexInformation}.
         * The {@code hopNum} and {@code vertex} arguments are not consulted.
         *
         * @return the delta edge type for delta caches. Otherwise returns a candidate chosen in proportion to its
         *     weight, or {@code -1} when the draw falls into the reset mass.
         */
        @Override
        protected int getNextEdgeType(int hopNum, Vertex vertex, Random random, PathTypeVertexCache cache) {
            VectorPathTypeVertexCache vectorCache = (VectorPathTypeVertexCache) cache;
            if (vectorCache.isDeltaDistribution()) {
                return vectorCache.deltaEdgeType;
            }
            double remaining = random.nextDouble() * vectorCache.totalWeight;
            double[] weights = vectorCache.weights;
            for (int k = 0; k < weights.length; k++) {
                if (remaining < weights[k]) {
                    return vectorCache.edgeTypes[k];
                }
                remaining -= weights[k];
            }
            // The draw landed past every candidate, i.e. in the reset weight.
            return -1;
        }
    }

    /**
     * Creates the factory and builds the edge-type-indexed embedding table.
     *
     * @param edgeDict edge-type dictionary; its {@code getNextIndex()} sizes the embedding table
     * @param embeddingsMap embedding for each edge-type index that has one
     * @param spikiness multiplier applied to dot products (and to {@code resetScore}) before exponentiation
     * @param resetScore raw score of the reset option; stored as {@code exp(spikiness * resetScore)}
     */
    public VectorPathTypeFactory(Dictionary edgeDict, Map<Integer, Vector> embeddingsMap, double spikiness,
            double resetScore) {
        this.edgeDict = edgeDict;
        this.embeddingsMap = embeddingsMap;
        this.spikiness = spikiness;
        // Put the reset option on the same exp(spikiness * score) scale as the candidate weights.
        this.resetWeight = Math.exp(spikiness * resetScore);
        initializeEmbeddings();
    }

    /** Returns the edge-type dictionary this factory was built with. */
    public Dictionary getEdgeDict() {
        return this.edgeDict;
    }

    /** Returns the edge-type to embedding map this factory was built with (not a copy). */
    public Map<Integer, Vector> getEmbeddingsMap() {
        return this.embeddingsMap;
    }

    /**
     * Installs a new embedding table of size {@code edgeDict.getNextIndex() + 1} and fills it from
     * {@code embeddingsMap}. Path types created before this call keep the table they were built with.
     */
    public void initializeEmbeddings() {
        final int tableSize = this.edgeDict.getNextIndex() + 1;
        this.embeddings = new Vector[tableSize];
        for (Map.Entry<Integer, Vector> mapping : this.embeddingsMap.entrySet()) {
            int edgeType = mapping.getKey();
            this.embeddings[edgeType] = mapping.getValue();
        }
    }

    /** Creates a {@link VectorPathType} for the given edge sequence and per-hop directions. */
    @Override
    public BaseEdgeSequencePathType newInstance(int[] edgeTypes, boolean[] reverse) {
        return new VectorPathType(edgeTypes, reverse);
    }

    /** Returns the shared zero-hop path type. */
    @Override
    public PathType emptyPathType() {
        return this.emptyPathType;
    }

    /** Returns the embedding for {@code edgeType}, or {@code null} if it has none. */
    public Vector getEmbedding(int edgeType) {
        return this.embeddings[edgeType];
    }
}
