package com.robrua.nlp.bert;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URL;
import java.nio.IntBuffer;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/* loaded from: input_file:com/robrua/nlp/bert/Bert.class */
public class Bert implements AutoCloseable {
    /** Size in bytes (1 MiB) of the buffer used when copying zip entries to disk. */
    private static final int FILE_COPY_BUFFER_BYTES = 1048576;
    /** File name, inside the model's {@code assets} directory, of the JSON model description. */
    private static final String MODEL_DETAILS = "model.json";
    /** Token whose id is appended after the tokens of every sequence. */
    private static final String SEPARATOR_TOKEN = "[SEP]";
    /** Token whose id is placed at the start of every sequence. */
    private static final String START_TOKEN = "[CLS]";
    /** File name, inside the model's {@code assets} directory, of the tokenizer vocabulary. */
    private static final String VOCAB_FILE = "vocab.txt";
    /** The loaded TensorFlow saved model; closed by {@link #close()}. */
    private final SavedModelBundle bundle;
    /** Tensor names and settings deserialized from {@code model.json}. */
    private final ModelDetails model;
    /** Vocabulary id of {@link #SEPARATOR_TOKEN}, resolved once in the constructor. */
    private final int separatorTokenId;
    /** Vocabulary id of {@link #START_TOKEN}, resolved once in the constructor. */
    private final int startTokenId;
    /** Tokenizer used to split text into tokens and map tokens to vocabulary ids. */
    private final FullTokenizer tokenizer;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/robrua/nlp/bert/Bert$Inputs.class */
    /**
     * The three integer input tensors fed to the model for one batch. Each tensor has
     * shape {@code [batchSize, maxSequenceLength]}; closing this object closes all three.
     */
    public class Inputs implements AutoCloseable {
        private final Tensor<Integer> inputIds;
        private final Tensor<Integer> inputMask;
        private final Tensor<Integer> segmentIds;

        /**
         * Wraps the given buffers in tensors.
         *
         * @param ids token ids for every row of the batch
         * @param mask attention-mask values for every row of the batch
         * @param segments segment ids for every row of the batch
         * @param batchSize number of rows (sequences) in the batch
         */
        public Inputs(IntBuffer ids, IntBuffer mask, IntBuffer segments, int batchSize) {
            final long sequenceLength = Bert.this.model.maxSequenceLength;
            // A fresh shape array per tensor, exactly as many allocations as before.
            this.inputIds = Tensor.create(new long[] {batchSize, sequenceLength}, ids);
            this.inputMask = Tensor.create(new long[] {batchSize, sequenceLength}, mask);
            this.segmentIds = Tensor.create(new long[] {batchSize, sequenceLength}, segments);
        }

        /** Closes the three tensors in declaration order. */
        @Override // java.lang.AutoCloseable
        public void close() {
            inputIds.close();
            inputMask.close();
            segmentIds.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/robrua/nlp/bert/Bert$ModelDetails.class */
    /**
     * Description of a saved model, deserialized by Jackson from the model's
     * {@code assets/model.json} file. The field names are the JSON binding contract.
     */
    public static class ModelDetails {
        /** Passed to the tokenizer constructor as its lower-casing flag. */
        public boolean doLowerCase;
        /** Name under which the token-id tensor is fed to the session runner. */
        public String inputIds;
        /** Name under which the input-mask tensor is fed to the session runner. */
        public String inputMask;
        /** Name under which the segment-id tensor is fed to the session runner. */
        public String segmentIds;
        /** Name of the output fetched by the sequence-embedding methods. */
        public String pooledOutput;
        /** Name of the output fetched by the token-embedding methods. */
        public String sequenceOutput;
        /** Second dimension of every input tensor; each sequence is truncated or zero-padded to it. */
        public int maxSequenceLength;

        // Private: no code in this file instantiates the class directly; instances come from Jackson.
        private ModelDetails() {
        }
    }

    /**
     * Loads a model from a saved-model directory given as a {@link File}.
     *
     * @param file the model directory
     * @return the loaded model
     */
    public static Bert load(File file) {
        final Path modelDirectory = Paths.get(file.toURI());
        return load(modelDirectory);
    }

    /**
     * Loads a model from a saved-model directory.
     *
     * <p>The directory must contain the TensorFlow saved model (loaded with the
     * {@code "serve"} tag) plus {@code assets/model.json} and {@code assets/vocab.txt}.
     *
     * @param path the model directory; resolved to an absolute path before use
     * @return the loaded model; the caller is responsible for closing it
     * @throws RuntimeException wrapping the {@link IOException} if {@code model.json} cannot be read
     */
    public static Bert load(Path path) {
        Path modelDirectory = path.toAbsolutePath();
        Path assets = modelDirectory.resolve("assets");

        // Read the cheap JSON description first so a bad or missing model.json
        // cannot leave an already-loaded native bundle open.
        ModelDetails details;
        try {
            details = new ObjectMapper().readValue(assets.resolve(MODEL_DETAILS).toFile(), ModelDetails.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        SavedModelBundle savedModel = SavedModelBundle.load(modelDirectory.toString(), new String[]{"serve"});
        try {
            return new Bert(savedModel, details, assets.resolve(VOCAB_FILE));
        } catch (RuntimeException | Error e) {
            // Construction failed (e.g. the vocabulary could not be loaded): release the bundle.
            try {
                savedModel.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Loads a model packaged as a zip archive on the classpath.
     *
     * <p>The archive is extracted into a fresh temporary directory, the model is loaded
     * from there via {@link #load(Path)}, and the temporary directory is deleted again
     * before this method returns, whether loading succeeded or failed.
     *
     * @param str classpath resource name of the zip archive
     * @return the loaded model; the caller is responsible for closing it
     * @throws IllegalArgumentException if the resource cannot be found
     * @throws RuntimeException wrapping the {@link IOException} if extraction or cleanup fails
     */
    public static Bert load(String str) {
        URL resource = Resources.getResource(str);
        Path tempDirectory = null;
        try {
            tempDirectory = Files.createTempDirectory("easy-bert-");
            extractArchive(resource, tempDirectory);
            return load(tempDirectory);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            // Runs on success and on failure so extracted files never outlive the call.
            deleteRecursively(tempDirectory);
        }
    }

    /** Unpacks every entry of the zip archive at {@code archive} beneath {@code destination}. */
    private static void extractArchive(URL archive, Path destination) throws IOException {
        Path root = destination.toAbsolutePath().normalize();
        try (ZipInputStream zip = new ZipInputStream(Resources.asByteSource(archive).openBufferedStream())) {
            byte[] buffer = new byte[FILE_COPY_BUFFER_BYTES];
            ZipEntry entry;
            while ((entry = zip.getNextEntry()) != null) {
                Path target = root.resolve(entry.getName()).normalize();
                // Reject names such as "../x" or absolute paths that would write outside the root.
                if (!target.startsWith(root)) {
                    throw new IOException("Zip entry escapes the extraction directory: " + entry.getName());
                }
                if (entry.isDirectory()) {
                    Files.createDirectories(target);
                } else {
                    // Archives are not required to contain explicit entries for parent directories.
                    Path parent = target.getParent();
                    if (parent != null) {
                        Files.createDirectories(parent);
                    }
                    Files.createFile(target);
                    try (OutputStream out = Files.newOutputStream(target)) {
                        int read;
                        while ((read = zip.read(buffer)) > 0) {
                            out.write(buffer, 0, read);
                        }
                    }
                }
                zip.closeEntry();
            }
        }
    }

    /** Deletes {@code root} and everything beneath it; does nothing if it is null or absent. */
    private static void deleteRecursively(Path root) {
        if (root == null || !Files.exists(root)) {
            return;
        }
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> paths = Files.walk(root)) {
            // Reverse order visits children before their parent directories.
            paths.sorted(Comparator.reverseOrder()).forEach(path -> {
                try {
                    Files.delete(path);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            });
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a model wrapper around an already-loaded bundle.
     *
     * @param savedModelBundle the loaded TensorFlow saved model
     * @param modelDetails tensor names and settings read from {@code model.json}
     * @param path location of the vocabulary file handed to the tokenizer
     */
    private Bert(SavedModelBundle savedModelBundle, ModelDetails modelDetails, Path path) {
        this.bundle = savedModelBundle;
        this.model = modelDetails;
        this.tokenizer = new FullTokenizer(path, modelDetails.doLowerCase);

        // Resolve the ids of the two special tokens once; they are reused for every input.
        final String[] specialTokens = {START_TOKEN, SEPARATOR_TOKEN};
        final int[] specialTokenIds = this.tokenizer.convert(specialTokens);
        this.startTokenId = specialTokenIds[0];
        this.separatorTokenId = specialTokenIds[1];
    }

    /** Closes the underlying saved-model bundle. */
    @Override // java.lang.AutoCloseable
    public void close() {
        bundle.close();
    }

    /**
     * Computes the pooled-output embedding of a single sequence.
     *
     * @param str the text to embed
     * @return the pooled output for the sequence; its length is the second dimension
     *     of the fetched tensor
     */
    public float[] embedSequence(String str) {
        // try-with-resources closes the output tensor first, then the inputs, and keeps
        // the primary exception if a close() call also fails.
        try (Inputs inputs = getInputs(str);
                Tensor<?> pooled = this.bundle.session().runner()
                        .feed(this.model.inputIds, inputs.inputIds)
                        .feed(this.model.inputMask, inputs.inputMask)
                        .feed(this.model.segmentIds, inputs.segmentIds)
                        .fetch(this.model.pooledOutput)
                        .run()
                        .get(0)) {
            // The batch dimension is 1 because exactly one sequence was fed.
            float[][] embeddings = new float[1][(int) pooled.shape()[1]];
            pooled.copyTo(embeddings);
            return embeddings[0];
        }
    }

    /**
     * Computes pooled-output embeddings for every sequence of an {@link Iterable}.
     *
     * @param iterable the texts to embed; fully copied into an array before embedding
     * @return one embedding row per input sequence, in iteration order
     */
    public float[][] embedSequences(Iterable<String> iterable) {
        final ArrayList<String> sequences = Lists.newArrayList(iterable);
        final String[] asArray = sequences.toArray(new String[sequences.size()]);
        return embedSequences(asArray);
    }

    /**
     * Computes pooled-output embeddings for every sequence produced by an {@link Iterator}.
     *
     * @param it the texts to embed; the iterator is drained into an array before embedding
     * @return one embedding row per input sequence, in iteration order
     */
    public float[][] embedSequences(Iterator<String> it) {
        final ArrayList<String> sequences = Lists.newArrayList(it);
        final String[] asArray = sequences.toArray(new String[sequences.size()]);
        return embedSequences(asArray);
    }

    /**
     * Computes pooled-output embeddings for a batch of sequences in one model run.
     *
     * @param strArr the texts to embed
     * @return one embedding row per input sequence, in argument order; the row length
     *     is the second dimension of the fetched tensor
     */
    public float[][] embedSequences(String... strArr) {
        // try-with-resources closes the output tensor first, then the inputs, and keeps
        // the primary exception if a close() call also fails.
        try (Inputs inputs = getInputs(strArr);
                Tensor<?> pooled = this.bundle.session().runner()
                        .feed(this.model.inputIds, inputs.inputIds)
                        .feed(this.model.inputMask, inputs.inputMask)
                        .feed(this.model.segmentIds, inputs.segmentIds)
                        .fetch(this.model.pooledOutput)
                        .run()
                        .get(0)) {
            float[][] embeddings = new float[strArr.length][(int) pooled.shape()[1]];
            pooled.copyTo(embeddings);
            return embeddings;
        }
    }

    /**
     * Computes per-token embeddings for every sequence of an {@link Iterable}.
     *
     * @param iterable the texts to embed; fully copied into an array before embedding
     * @return for each input sequence, in iteration order, one embedding row per token position
     */
    public float[][][] embedTokens(Iterable<String> iterable) {
        final ArrayList<String> sequences = Lists.newArrayList(iterable);
        final String[] asArray = sequences.toArray(new String[sequences.size()]);
        return embedTokens(asArray);
    }

    /**
     * Computes per-token embeddings for every sequence produced by an {@link Iterator}.
     *
     * @param it the texts to embed; the iterator is drained into an array before embedding
     * @return for each input sequence, in iteration order, one embedding row per token position
     */
    public float[][][] embedTokens(Iterator<String> it) {
        final ArrayList<String> sequences = Lists.newArrayList(it);
        final String[] asArray = sequences.toArray(new String[sequences.size()]);
        return embedTokens(asArray);
    }

    /**
     * Computes per-token embeddings of a single sequence from the model's sequence output.
     *
     * @param str the text to embed
     * @return one embedding row per token position; the dimensions are the second and
     *     third dimensions of the fetched tensor
     */
    public float[][] embedTokens(String str) {
        // try-with-resources closes the output tensor first, then the inputs, and keeps
        // the primary exception if a close() call also fails.
        try (Inputs inputs = getInputs(str);
                Tensor<?> sequenceOutput = this.bundle.session().runner()
                        .feed(this.model.inputIds, inputs.inputIds)
                        .feed(this.model.inputMask, inputs.inputMask)
                        .feed(this.model.segmentIds, inputs.segmentIds)
                        .fetch(this.model.sequenceOutput)
                        .run()
                        .get(0)) {
            long[] shape = sequenceOutput.shape();
            // The batch dimension is 1 because exactly one sequence was fed.
            float[][][] embeddings = new float[1][(int) shape[1]][(int) shape[2]];
            sequenceOutput.copyTo(embeddings);
            return embeddings[0];
        }
    }

    /**
     * Computes per-token embeddings for a batch of sequences in one model run.
     *
     * @param strArr the texts to embed
     * @return for each input sequence, in argument order, one embedding row per token
     *     position; the inner dimensions are the second and third dimensions of the
     *     fetched tensor
     */
    public float[][][] embedTokens(String... strArr) {
        // try-with-resources closes the output tensor first, then the inputs, and keeps
        // the primary exception if a close() call also fails.
        try (Inputs inputs = getInputs(strArr);
                Tensor<?> sequenceOutput = this.bundle.session().runner()
                        .feed(this.model.inputIds, inputs.inputIds)
                        .feed(this.model.inputMask, inputs.inputMask)
                        .feed(this.model.segmentIds, inputs.segmentIds)
                        .fetch(this.model.sequenceOutput)
                        .run()
                        .get(0)) {
            long[] shape = sequenceOutput.shape();
            float[][][] embeddings = new float[strArr.length][(int) shape[1]][(int) shape[2]];
            sequenceOutput.copyTo(embeddings);
            return embeddings;
        }
    }

    /**
     * Builds the model inputs for a single sequence.
     *
     * <p>The row is laid out as start-token id, up to {@code maxSequenceLength - 2}
     * token ids, separator-token id, then zero padding up to {@code maxSequenceLength}.
     * The mask holds 1 for every non-padding position and 0 for padding; all segment
     * ids are 0.
     *
     * @param str the text to tokenize
     * @return input tensors for a batch of one; the caller must close them
     */
    private Inputs getInputs(String str) {
        final String[] tokens = this.tokenizer.tokenize(str);
        final int rowLength = this.model.maxSequenceLength;
        final IntBuffer ids = IntBuffer.allocate(rowLength);
        final IntBuffer mask = IntBuffer.allocate(rowLength);
        // IntBuffer.allocate zero-fills, and every segment id is 0, so this buffer is never written.
        final IntBuffer segments = IntBuffer.allocate(rowLength);
        final int[] tokenIds = this.tokenizer.convert(tokens);

        // Leave room for the start and separator tokens.
        final int kept = Math.min(tokenIds.length, rowLength - 2);

        ids.put(this.startTokenId);
        mask.put(1);
        for (int t = 0; t < kept; t++) {
            ids.put(tokenIds[t]);
            mask.put(1);
        }
        ids.put(this.separatorTokenId);
        mask.put(1);

        // Pad the remainder of the row with zeros.
        while (ids.position() < rowLength) {
            ids.put(0);
            mask.put(0);
        }

        ids.rewind();
        mask.rewind();
        return new Inputs(ids, mask, segments, 1);
    }

    /**
     * Builds the model inputs for a batch of sequences.
     *
     * <p>Rows are written back to back. Each row is laid out as start-token id, up to
     * {@code maxSequenceLength - 2} token ids, separator-token id, then zero padding up
     * to the end of the row. The mask holds 1 for every non-padding position and 0 for
     * padding; all segment ids are 0.
     *
     * @param strArr the texts to tokenize, one per batch row
     * @return input tensors with {@code strArr.length} rows; the caller must close them
     */
    private Inputs getInputs(String[] strArr) {
        final String[][] tokenized = this.tokenizer.tokenize(strArr);
        final int rowLength = this.model.maxSequenceLength;
        final int capacity = strArr.length * rowLength;
        final IntBuffer ids = IntBuffer.allocate(capacity);
        final IntBuffer mask = IntBuffer.allocate(capacity);
        // IntBuffer.allocate zero-fills, and every segment id is 0, so this buffer is never written.
        final IntBuffer segments = IntBuffer.allocate(capacity);

        int rowEnd = 0;
        for (String[] rowTokens : tokenized) {
            rowEnd += rowLength;
            final int[] tokenIds = this.tokenizer.convert(rowTokens);
            // Leave room for the start and separator tokens.
            final int kept = Math.min(tokenIds.length, rowLength - 2);

            ids.put(this.startTokenId);
            mask.put(1);
            for (int t = 0; t < kept; t++) {
                ids.put(tokenIds[t]);
                mask.put(1);
            }
            ids.put(this.separatorTokenId);
            mask.put(1);

            // Pad with zeros up to the end of this row.
            while (ids.position() < rowEnd) {
                ids.put(0);
                mask.put(0);
            }
        }

        ids.rewind();
        mask.rewind();
        return new Inputs(ids, mask, segments, strArr.length);
    }
}
