/*
 * Decompiled with CFR 0.152.
 */
package com.hiservice.translate_v5;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import android.util.Log;
import android.util.Pair;
import androidx.collection.ArraySet;
import com.hiservice.tools.TensorUtils;
import com.hiservice.tools.Utils;
import com.hiservice.translate_v5.CustomSPMEncoder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public abstract class RunOnnxReRAMWithBeam {
    /** Runs of sentence-final punctuation (ASCII and full-width) used to split input in {@link #translate}. */
    private static final Pattern SENTENCE_END = Pattern.compile("([.!?\uff01\uff1f\u3002]+)");
    /** Log tag used by every log line emitted from this class. */
    private static final String TAG = "LH";

    // NOTE(review): not referenced in this class; presumably the batch size that every tensor
    // below hard-codes as 1 — confirm or remove.
    private final int EMPTY_BATCH_SIZE = 1;
    private OrtEnvironment onnxEnv;
    private OrtSession encoderSession;
    private OrtSession decoderSession;
    private OrtSession cacheInitSession;
    private OrtSession embedAndLmheadSession;
    private CustomSPMEncoder tokenizer;
    // NOTE(review): never read or written in this class.
    private long currentResultID = 0L;
    /** True between a successful {@link #load()} and the next {@link #closeModel()}. */
    private boolean isLoad = false;
    int d_model = 512;
    int attention_heads = 8;
    int decoder_layers = 12;
    int encoder_layers = 6;
    // Computed once at construction from the defaults above; reassigning d_model or
    // attention_heads afterwards does not update it.
    int head_dim = this.d_model / this.attention_heads;
    /** Maximum number of beam-search steps per sentence. */
    int max_length = 200;

    /**
     * Obtains the runtime objects from {@link #create()} and marks the model as loaded.
     */
    public void load() {
        List<Object> handles = this.create();
        this.onnxEnv = (OrtEnvironment) handles.get(0);
        this.encoderSession = (OrtSession) handles.get(1);
        this.decoderSession = (OrtSession) handles.get(2);
        this.cacheInitSession = (OrtSession) handles.get(3);
        this.embedAndLmheadSession = (OrtSession) handles.get(4);
        this.tokenizer = (CustomSPMEncoder) handles.get(5);
        // Previously never set, so getLoadStatus() always reported false.
        this.isLoad = true;
    }

    /**
     * Builds the runtime objects consumed by {@link #load()}. The list must contain, in order:
     * the {@link OrtEnvironment}, the encoder session, the decoder session, the cache-init
     * session, the embedding/lm-head session and the {@link CustomSPMEncoder} tokenizer.
     *
     * @return the six objects listed above, in that order
     */
    public abstract List<Object> create();

    /**
     * Releases every session, the environment and the tokenizer. Safe to call when nothing
     * (or only part) was loaded, and safe to call twice. Every resource is closed even if an
     * earlier one fails; the first session failure is rethrown with later ones suppressed.
     *
     * @throws OrtException if closing a session or the environment fails
     */
    public void closeModel() throws OrtException {
        this.isLoad = false;
        try {
            closeSessions(this.embedAndLmheadSession, this.cacheInitSession, this.decoderSession, this.encoderSession);
        } finally {
            this.embedAndLmheadSession = null;
            this.cacheInitSession = null;
            this.decoderSession = null;
            this.encoderSession = null;
            try {
                if (this.onnxEnv != null) {
                    this.onnxEnv.close();
                }
            } finally {
                this.onnxEnv = null;
                CustomSPMEncoder tok = this.tokenizer;
                this.tokenizer = null;
                if (tok != null) {
                    tok.close();
                }
            }
        }
    }

    /**
     * @return whether {@link #load()} has completed and {@link #closeModel()} has not been called since
     */
    public boolean getLoadStatus() {
        return this.isLoad;
    }

    /**
     * Translates one segment with beam search.
     *
     * <p>The source is embedded, encoded and turned into a cross-attention cache once. The decoder
     * is then primed with token 2. Beams start from the target-language id and are expanded for at
     * most {@link #max_length} steps, keeping the {@code numBeam} best children per step. The loop
     * stops as soon as any beam ends in the tokenizer's EOS id. The highest-scoring EOS-terminated
     * beam is decoded, or the highest-scoring beam overall if none ended in EOS.
     *
     * <p>Every native tensor and {@link OrtSession.Result} created here is closed before returning,
     * including when an exception is thrown.
     *
     * <p>NOTE(review): scores accumulate the raw values {@code Utils.argTopK} returns for the
     * lm-head "logits"; unless the exported graph already applies log-softmax these are not
     * log-probabilities — confirm.
     *
     * @param inputText source text
     * @param srcLang   source language code passed to the tokenizer
     * @param tgtLang   target language code whose id starts every beam
     * @param numBeam   beam width, must be at least 1
     * @return the decoded translation
     * @throws IllegalArgumentException if {@code numBeam < 1}
     * @throws RuntimeException         wrapping any {@link OrtException}
     */
    private String translateSingleText(String inputText, String srcLang, String tgtLang, int numBeam) {
        if (numBeam < 1) {
            throw new IllegalArgumentException("numBeam must be at least 1, got " + numBeam);
        }
        // Tensors/results that live for the whole call; released in `finally`.
        List<OnnxTensor> ownedTensors = new ArrayList<OnnxTensor>();
        List<OrtSession.Result> ownedResults = new ArrayList<OrtSession.Result>();
        // Decoder results whose output tensors back the beams' self-attention KV caches:
        // `previousStep` is what the current beams point at, `currentStep` collects the
        // results their children will point at.
        List<OrtSession.Result> previousStep = new ArrayList<OrtSession.Result>();
        List<OrtSession.Result> currentStep = new ArrayList<OrtSession.Result>();
        try {
            List<Integer> inputsIds = this.tokenizer.encode(inputText, srcLang);
            int[][] inputsIdsArray = new int[1][inputsIds.size()];
            for (int i = 0; i < inputsIds.size(); ++i) {
                inputsIdsArray[0][i] = inputsIds.get(i);
            }
            int[][] attentionMask = new int[1][inputsIds.size()];
            Arrays.fill(attentionMask[0], 1);
            Log.i(TAG, "inputsIdsArray " + Arrays.deepToString(inputsIdsArray));
            Log.i(TAG, "attentionMask " + Arrays.deepToString(attentionMask));
            OnnxTensor srcIdsTensor = track(ownedTensors, Utils.createTensor(this.onnxEnv, inputsIdsArray));
            OnnxTensor srcMaskTensor = track(ownedTensors, Utils.createTensor(this.onnxEnv, attentionMask));

            // Shared embed-model inputs: only "input_ids" changes between calls.
            Map<String, OnnxTensor> embedInput = new HashMap<String, OnnxTensor>();
            embedInput.put("pre_logits", track(ownedTensors,
                    TensorUtils.createFloatTensorWithSingleValue(this.onnxEnv, 0.0f, new long[]{1L, 1L, this.d_model})));
            embedInput.put("use_lm_head", track(ownedTensors, TensorUtils.convertBooleanToTensor(this.onnxEnv, false)));
            Set<String> embRequestedOutputs = new ArraySet<String>();
            embRequestedOutputs.add("embed_matrix");

            OrtSession.Result cacheInitResult = this.encodeSource(embedInput, embRequestedOutputs, srcIdsTensor, srcMaskTensor);
            ownedResults.add(cacheInitResult);

            // Shared lm-head inputs: only "pre_logits" changes between calls.
            Map<String, OnnxTensor> lmHeadInput = new HashMap<String, OnnxTensor>();
            lmHeadInput.put("input_ids", track(ownedTensors,
                    TensorUtils.createInt64TensorWithSingleValue(this.onnxEnv, 0L, new long[]{1L, 1L})));
            lmHeadInput.put("use_lm_head", track(ownedTensors, TensorUtils.convertBooleanToTensor(this.onnxEnv, true)));
            Set<String> lmHeadRequestedOutputs = new ArraySet<String>();
            lmHeadRequestedOutputs.add("logits");

            // NOTE(review): the decoder's "encoder_attention_mask" is a 1x1 all-ones mask rather
            // than srcMaskTensor; this relies on the exported graph broadcasting it — confirm.
            Map<String, OnnxTensor> decoderInput = new HashMap<String, OnnxTensor>();
            decoderInput.put("encoder_attention_mask", track(ownedTensors, Utils.createTensor(this.onnxEnv, new int[][]{{1}})));
            for (int layer = 0; layer < this.decoder_layers; ++layer) {
                // Empty (length-0) self-attention cache for the priming step.
                decoderInput.put(kvName(layer, "decoder.key"), track(ownedTensors,
                        Utils.createTensor(this.onnxEnv, new long[]{1L, this.attention_heads, 0L, this.head_dim}, new float[0])));
                decoderInput.put(kvName(layer, "decoder.value"), track(ownedTensors,
                        Utils.createTensor(this.onnxEnv, new long[]{1L, this.attention_heads, 0L, this.head_dim}, new float[0])));
                // Cross-attention cache is computed once and stays in decoderInput for every step.
                decoderInput.put(kvName(layer, "encoder.key"), (OnnxTensor) cacheInitResult.get(layer * 2));
                decoderInput.put(kvName(layer, "encoder.value"), (OnnxTensor) cacheInitResult.get(layer * 2 + 1));
            }

            // Priming step. NOTE(review): token id 2 is hard-coded, presumably the model's
            // decoder-start id — confirm against the tokenizer.
            int startTokenId = 2;
            OnnxTensor startIdsTensor = track(ownedTensors, Utils.createTensor(this.onnxEnv, new int[][]{{startTokenId}}));
            OrtSession.Result initDecoderResult = this.runDecoder(startIdsTensor, embedInput, embRequestedOutputs, decoderInput);
            previousStep.add(initDecoderResult);

            ArrayList<Integer> startSeq = new ArrayList<Integer>();
            startSeq.add(this.tokenizer.getLangId(tgtLang));
            List<BeamSearchCandidate> candidates = new ArrayList<BeamSearchCandidate>();
            candidates.add(new BeamSearchCandidate(startSeq, 0.0, this.decoderKvFrom(initDecoderResult)));
            int eosId = this.tokenizer.get_eos_id();
            // Min-heap on score: the head is the weakest of the children kept so far.
            PriorityQueue<BeamSearchCandidate> nextBeams = new PriorityQueue<BeamSearchCandidate>(
                    numBeam, Comparator.comparingDouble(BeamSearchCandidate::getScore));

            for (int step = 1; step <= this.max_length; ++step) {
                Log.i(TAG, "step: " + step);
                for (BeamSearchCandidate candidate : candidates) {
                    OnnxTensor tokenTensor = Utils.createTensor(this.onnxEnv, new int[][]{{candidate.getSeqLastTokenId()}});
                    OrtSession.Result decoderResult;
                    try {
                        decoderInput.putAll(candidate.getPastKeyValues());
                        decoderResult = this.runDecoder(tokenTensor, embedInput, embRequestedOutputs, decoderInput);
                    } finally {
                        // Previously leaked once per beam per step.
                        tokenTensor.close();
                    }
                    currentStep.add(decoderResult);

                    lmHeadInput.put("pre_logits", (OnnxTensor) decoderResult.get(0));
                    float[] stepLogits;
                    try (OrtSession.Result lmHeadResult = this.embedAndLmheadSession.run(lmHeadInput, lmHeadRequestedOutputs)) {
                        float[][][] logitsValue = (float[][][]) ((OnnxTensor) lmHeadResult.get(0)).getValue();
                        stepLogits = logitsValue[0][0];
                    }

                    List<Pair<Integer, Float>> topK = Utils.argTopK(stepLogits, numBeam);
                    Log.i(TAG, "topKResult: " + topK.size());
                    for (Pair<Integer, Float> idxScore : topK) {
                        Log.i(TAG, "nextTokenId: " + idxScore.first + ", score: " + idxScore.second);
                        double nextScore = candidate.getScore() + idxScore.second;
                        ArrayList<Integer> nextSeq = new ArrayList<Integer>(candidate.getSequence());
                        nextSeq.add(idxScore.first);
                        // Each child gets its own map because BeamSearchCandidate.close() clears it.
                        BeamSearchCandidate child = new BeamSearchCandidate(nextSeq, nextScore, this.decoderKvFrom(decoderResult));
                        if (nextBeams.size() < numBeam) {
                            nextBeams.offer(child);
                        } else if (child.getScore() > nextBeams.peek().getScore()) {
                            nextBeams.poll().close();
                            nextBeams.offer(child);
                        }
                    }
                    candidate.close();
                }
                // Surviving beams now reference only this step's outputs, so the previous
                // generation can be released; then rotate the two lists.
                closeResults(previousStep);
                List<OrtSession.Result> recycled = previousStep;
                previousStep = currentStep;
                currentStep = recycled;

                candidates.clear();
                candidates.addAll(nextBeams);
                nextBeams.clear();
                for (BeamSearchCandidate candidate : candidates) {
                    Log.i(TAG, candidate.getSequence().toString() + " " + candidate.getScore());
                }
                if (this.containsEosId(candidates, eosId)) {
                    break;
                }
            }

            List<BeamSearchCandidate> finished = new ArrayList<BeamSearchCandidate>();
            for (BeamSearchCandidate candidate : candidates) {
                if (candidate.getSeqLastTokenId() == eosId) {
                    finished.add(candidate);
                }
            }
            BeamSearchCandidate best = this.findMaxScoreCandidate(finished.isEmpty() ? candidates : finished);
            Log.i(TAG, "decode: " + best.getSequence().toString() + "  " + best.getScore());
            String outputText = this.tokenizer.decode(best.getSequence());
            for (BeamSearchCandidate candidate : candidates) {
                candidate.close();
            }
            return outputText;
        } catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            // Previously the last decoder generation and every input tensor were never closed.
            closeResults(previousStep);
            closeResults(currentStep);
            closeResults(ownedResults);
            for (OnnxTensor tensor : ownedTensors) {
                tensor.close();
            }
        }
    }

    /**
     * Embeds and encodes the source, then runs the cache-init model on the encoder output.
     * The intermediate embed/encoder results are closed here; the caller owns (and must close)
     * the returned cache-init result, whose outputs 2*layer and 2*layer+1 are used as the
     * layer's "encoder.key"/"encoder.value".
     */
    private OrtSession.Result encodeSource(Map<String, OnnxTensor> embedInput, Set<String> embOutputs,
                                           OnnxTensor srcIds, OnnxTensor srcMask) throws OrtException {
        embedInput.put("input_ids", srcIds);
        try (OrtSession.Result embedResult = this.embedAndLmheadSession.run(embedInput, embOutputs)) {
            OnnxTensor embedMatrix = (OnnxTensor) embedResult.get(0);
            // Log the shape only: getValue() would copy the whole tensor just to print a type@hash.
            Log.i(TAG, "embed_matrix " + Arrays.toString(embedMatrix.getInfo().getShape()));
            Map<String, OnnxTensor> encoderInput = new HashMap<String, OnnxTensor>();
            encoderInput.put("input_ids", srcIds);
            encoderInput.put("attention_mask", srcMask);
            encoderInput.put("embed_matrix", embedMatrix);
            try (OrtSession.Result encoderResult = this.encoderSession.run(encoderInput)) {
                OnnxTensor lastHiddenState = (OnnxTensor) encoderResult.get(0);
                Log.i(TAG, "last_hidden_state " + Arrays.toString(lastHiddenState.getInfo().getShape()));
                Map<String, OnnxTensor> cacheInitInput = new HashMap<String, OnnxTensor>();
                cacheInitInput.put("encoder_hidden_states", lastHiddenState);
                return this.cacheInitSession.run(cacheInitInput);
            }
        }
    }

    /**
     * Embeds {@code idsTensor} and runs one decoder step with the KV entries already in
     * {@code decoderInput}. The embedding result is closed here; the caller owns the returned
     * decoder result. Does not close {@code idsTensor}.
     */
    private OrtSession.Result runDecoder(OnnxTensor idsTensor, Map<String, OnnxTensor> embedInput,
                                         Set<String> embOutputs, Map<String, OnnxTensor> decoderInput) throws OrtException {
        embedInput.put("input_ids", idsTensor);
        try (OrtSession.Result embedResult = this.embedAndLmheadSession.run(embedInput, embOutputs)) {
            decoderInput.put("input_ids", idsTensor);
            decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0));
            return this.decoderSession.run(decoderInput);
        }
    }

    /**
     * Maps a decoder result's self-attention cache outputs (indices 2*layer+1 and 2*layer+2)
     * to the "past_key_values.&lt;layer&gt;.decoder.key/value" input names. The tensors stay owned
     * by {@code decoderResult}.
     */
    private Map<String, OnnxTensor> decoderKvFrom(OrtSession.Result decoderResult) {
        Map<String, OnnxTensor> kv = new HashMap<String, OnnxTensor>();
        for (int layer = 0; layer < this.decoder_layers; ++layer) {
            kv.put(kvName(layer, "decoder.key"), (OnnxTensor) decoderResult.get(layer * 2 + 1));
            kv.put(kvName(layer, "decoder.value"), (OnnxTensor) decoderResult.get(layer * 2 + 2));
        }
        return kv;
    }

    /** Builds the input name "past_key_values.&lt;layer&gt;.&lt;kind&gt;". */
    private static String kvName(int layer, String kind) {
        return "past_key_values." + layer + "." + kind;
    }

    /** Adds {@code tensor} to {@code owned} (for later release) and returns it. */
    private static OnnxTensor track(List<OnnxTensor> owned, OnnxTensor tensor) {
        owned.add(tensor);
        return tensor;
    }

    /** Closes every result in {@code results} and empties the list. */
    private static void closeResults(List<OrtSession.Result> results) {
        for (OrtSession.Result result : results) {
            result.close();
        }
        results.clear();
    }

    /**
     * Closes every non-null session, attempting all of them even if some fail; rethrows the
     * first failure with any later ones attached as suppressed exceptions.
     */
    private static void closeSessions(OrtSession... sessions) throws OrtException {
        OrtException first = null;
        for (OrtSession session : sessions) {
            if (session == null) {
                continue;
            }
            try {
                session.close();
            } catch (OrtException e) {
                if (first == null) {
                    first = e;
                } else {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw first;
        }
    }

    /**
     * @return the candidate with the highest score (the earliest one on ties), or {@code null}
     *         when {@code candidates} is null or empty
     */
    private BeamSearchCandidate findMaxScoreCandidate(List<BeamSearchCandidate> candidates) {
        if (candidates == null || candidates.isEmpty()) {
            return null;
        }
        BeamSearchCandidate maxCandidate = candidates.get(0);
        for (BeamSearchCandidate candidate : candidates) {
            if (candidate.getScore() > maxCandidate.getScore()) {
                maxCandidate = candidate;
            }
        }
        return maxCandidate;
    }

    /** @return whether any beam's last token equals {@code eosId} */
    private boolean containsEosId(List<BeamSearchCandidate> beamCandidates, int eosId) {
        for (BeamSearchCandidate candidate : beamCandidates) {
            if (candidate.getSeqLastTokenId() == eosId) {
                return true;
            }
        }
        return false;
    }

    /**
     * Splits {@code inputText} after each run of sentence-final punctuation
     * ({@code . ! ?} and full-width {@code ！ ？ 。}) and translates each non-blank segment
     * independently. Results are joined with a single space.
     *
     * @throws RuntimeException wrapping any failure during splitting or translation
     */
    protected String translate(String inputText, String srcLang, String tgtLang, int numBeam) throws OrtException {
        try {
            List<String> segments = splitSentences(inputText);
            Log.i(TAG, "Segments: " + segments);
            List<String> res = new ArrayList<String>();
            for (String segment : segments) {
                // e.g. the trailing space in "Hello. " would otherwise go through a full beam search.
                if (segment.trim().isEmpty()) {
                    continue;
                }
                res.add(this.translateSingleText(segment, srcLang, tgtLang, numBeam));
            }
            return String.join(" ", res);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Splits text into segments that each end with a punctuation run; a trailing remainder is kept. */
    private static List<String> splitSentences(String text) {
        List<String> segments = new ArrayList<String>();
        Matcher m = SENTENCE_END.matcher(text);
        int lastIndex = 0;
        while (m.find()) {
            segments.add(text.substring(lastIndex, m.end()));
            lastIndex = m.end();
        }
        if (lastIndex < text.length()) {
            segments.add(text.substring(lastIndex));
        }
        return segments;
    }

    /** One beam: a token sequence, its cumulative score and its decoder self-attention cache. */
    private static class BeamSearchCandidate {
        private List<Integer> sequence;
        private double score;
        private Map<String, OnnxTensor> pastKeyValues;

        public BeamSearchCandidate(List<Integer> sequence, double score, Map<String, OnnxTensor> nextPastKeyValues) {
            this.sequence = sequence;
            this.score = score;
            this.pastKeyValues = nextPastKeyValues;
        }

        public int getSeqLastTokenId() {
            return this.sequence.get(this.sequence.size() - 1);
        }

        public Map<String, OnnxTensor> getPastKeyValues() {
            return this.pastKeyValues;
        }

        public double getScore() {
            return this.score;
        }

        public List<Integer> getSequence() {
            return this.sequence;
        }

        /**
         * Empties this beam's sequence and cache map. The OnnxTensors themselves are not closed;
         * they belong to the decoder result that produced them.
         */
        public void close() {
            this.sequence.clear();
            this.pastKeyValues.clear();
        }
    }
}

