/*
 * Decompiled with CFR 0.152.
 */
package com.hiservice.translate_v5;

import com.hiservice.translate_v5.SentencePieceProcessor;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Encodes text into, and decodes text from, the token-ID layout of a fairseq-style model whose
 * subword vocabulary comes from a SentencePiece model.
 *
 * <p>ID layout used by this class:
 * <ul>
 *   <li>{@code <s>}=0, {@code <pad>}=1, {@code </s>}=2, {@code <unk>}=3;</li>
 *   <li>every other SentencePiece piece ID is shifted up by {@code fairseqOffset} (1);</li>
 *   <li>language tokens get consecutive IDs {@code getPieceSize() + 1}, {@code getPieceSize() + 2},
 *       ... in the order the languages were supplied.</li>
 * </ul>
 */
public class CustomSPMEncoder {
    /** Underlying SentencePiece model used for subword tokenization. */
    private final SentencePieceProcessor spModel;
    /** Fairseq special-token text -> fixed fairseq ID. */
    private final Map<String, Integer> specialTokens;
    /** Reverse of {@link #specialTokens}. */
    private final Map<Integer, String> idToSpecialToken;
    /** Language name -> language-token ID (placed after the SentencePiece vocabulary). */
    private final Map<String, Integer> languageTokens;
    /** Every allocated language-token ID -> language name. */
    private final Map<Integer, String> idToLanguageTokens;
    /** Shift applied to ordinary SentencePiece piece IDs to obtain fairseq IDs. */
    private final int fairseqOffset;

    /**
     * Loads {@code sentencepiece.model} from the given directory and builds an encoder over it.
     *
     * @param mode directory that contains the {@code sentencepiece.model} file
     * @param languages language names, in the order their token IDs should be allocated
     * @return a new encoder backed by the loaded model
     */
    public static CustomSPMEncoder create(File mode, List<String> languages) {
        SentencePieceProcessor spModel = new SentencePieceProcessor();
        spModel.load(mode.getAbsolutePath() + File.separator + "sentencepiece.model");
        return new CustomSPMEncoder(spModel, languages);
    }

    /**
     * Builds an encoder over an already-loaded SentencePiece model.
     *
     * @param spModel loaded SentencePiece model; its piece count fixes where language IDs start
     * @param languages language names; the i-th (0-based) gets ID {@code getPieceSize() + 1 + i}
     */
    public CustomSPMEncoder(SentencePieceProcessor spModel, List<String> languages) {
        this.spModel = spModel;
        this.fairseqOffset = 1;

        this.specialTokens = new HashMap<>();
        this.specialTokens.put("<s>", 0);
        this.specialTokens.put("<pad>", 1);
        this.specialTokens.put("</s>", 2);
        this.specialTokens.put("<unk>", 3);
        this.idToSpecialToken = new HashMap<>();
        for (Map.Entry<String, Integer> entry : this.specialTokens.entrySet()) {
            this.idToSpecialToken.put(entry.getValue(), entry.getKey());
        }

        this.languageTokens = new HashMap<>();
        this.idToLanguageTokens = new HashMap<>();
        int nextId = spModel.getPieceSize();
        for (String language : languages) {
            ++nextId;
            this.languageTokens.put(language, nextId);
            // Record the reverse mapping as each ID is allocated (instead of inverting
            // languageTokens afterwards) so that a duplicated language name cannot leave an
            // allocated ID unrecognized by decode().
            this.idToLanguageTokens.put(nextId, language);
        }
    }

    /** Returns the fairseq ID of {@code <s>}. */
    public int get_bos_id() {
        return this.specialTokens.get("<s>");
    }

    /** Returns the fairseq ID of {@code </s>}. */
    public int get_eos_id() {
        return this.specialTokens.get("</s>");
    }

    /**
     * Tokenizes {@code text} and converts it to fairseq IDs.
     *
     * <p>The result is {@code [</s>, <srcLang>, pieces..., </s>]}. SentencePiece's own
     * unk/bos/eos/pad IDs are remapped to the corresponding fairseq special IDs, and all other
     * piece IDs are shifted by {@code fairseqOffset}.
     *
     * @param text text to encode
     * @param srcLang language name that was registered at construction time
     * @return a new mutable list of fairseq token IDs
     * @throws IllegalArgumentException if {@code srcLang} was not registered
     */
    public List<Integer> encode(String text, String srcLang) {
        int srcLangId = requireLanguageId(srcLang);
        int[] spmIds = this.spModel.encodeAsIds(text);

        // SentencePiece control-symbol IDs; they do not change per piece, so query them once.
        int spmUnk = this.spModel.unkId();
        int spmBos = this.spModel.bosId();
        int spmEos = this.spModel.eosId();
        int spmPad = this.spModel.padId();

        ArrayList<Integer> adjustedIds = new ArrayList<>(spmIds.length + 3);
        adjustedIds.add(this.get_eos_id());
        adjustedIds.add(srcLangId);
        for (int spmId : spmIds) {
            // Check order (unk, bos, eos, pad) matters only if the model reuses an ID.
            if (spmId == spmUnk) {
                adjustedIds.add(this.specialTokens.get("<unk>"));
            } else if (spmId == spmBos) {
                adjustedIds.add(this.specialTokens.get("<s>"));
            } else if (spmId == spmEos) {
                adjustedIds.add(this.specialTokens.get("</s>"));
            } else if (spmId == spmPad) {
                adjustedIds.add(this.specialTokens.get("<pad>"));
            } else {
                adjustedIds.add(spmId + this.fairseqOffset);
            }
        }
        adjustedIds.add(this.specialTokens.get("</s>"));
        return adjustedIds;
    }

    /**
     * Converts fairseq IDs back to text.
     *
     * <p>Special-token and language-token IDs are dropped. Each run of ordinary IDs between them
     * is shifted back by {@code fairseqOffset}, decoded by SentencePiece on its own, and the
     * decoded runs are concatenated in order.
     *
     * @param ids fairseq token IDs (elements must be non-null)
     * @return the decoded text; empty if {@code ids} contains no ordinary IDs
     */
    public String decode(List<Integer> ids) {
        StringBuilder result = new StringBuilder();
        List<Integer> segment = new ArrayList<>();
        for (int id : ids) {
            if (isControlId(id)) {
                flushSegment(segment, result);
            } else {
                segment.add(id - this.fairseqOffset);
            }
        }
        flushSegment(segment, result);
        return result.toString();
    }

    /**
     * Returns the token ID allocated to {@code lang}.
     *
     * @throws IllegalArgumentException if {@code lang} was not registered
     */
    public int getLangId(String lang) {
        return requireLanguageId(lang);
    }

    /**
     * Closes the underlying SentencePiece model and clears all lookup tables. The instance is not
     * meant to be used after this call.
     */
    public void close() {
        this.spModel.close();
        this.specialTokens.clear();
        this.idToSpecialToken.clear();
        this.languageTokens.clear();
        this.idToLanguageTokens.clear();
    }

    /** Looks up a language-token ID, failing with a descriptive message for unknown languages. */
    private int requireLanguageId(String lang) {
        Integer id = this.languageTokens.get(lang);
        if (id == null) {
            throw new IllegalArgumentException("Unsupported language: " + lang);
        }
        return id;
    }

    /** Returns true if {@code id} is a special token or a language token (not a SentencePiece piece). */
    private boolean isControlId(int id) {
        return this.idToSpecialToken.containsKey(id) || this.idToLanguageTokens.containsKey(id);
    }

    /** Decodes the pending SentencePiece IDs into {@code out} and empties {@code segment}. */
    private void flushSegment(List<Integer> segment, StringBuilder out) {
        if (segment.isEmpty()) {
            return;
        }
        int[] pieceIds = new int[segment.size()];
        for (int i = 0; i < pieceIds.length; ++i) {
            pieceIds[i] = segment.get(i);
        }
        out.append(this.spModel.decodeIds(pieceIds));
        segment.clear();
    }
}

