/*
 * Decompiled with CFR 0.152.
 */
package ai.djl;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base implementation of {@link Model} holding the state shared by concrete models: the {@link
 * Block}, its {@link NDManager}, data type, cached input description, loaded artifacts and string
 * properties. It also implements reading and writing the DJL ".params" file format (see {@link
 * #save(Path, String)} and {@link #readParameters(InputStream, Map)}).
 *
 * <p>{@link #newTrainer(TrainingConfig)}, {@link #load(InputStream, Map)} and {@link
 * #getArtifactNames()} throw {@link UnsupportedOperationException} unless a subclass overrides
 * them.
 */
public abstract class BaseModel implements Model {

    private static final Logger logger = LoggerFactory.getLogger(BaseModel.class);

    /** Version number written to, and required when reading, the ".params" header. */
    private static final int MODEL_VERSION = 1;

    /** Four-byte ASCII marker that starts every DJL ".params" file. */
    private static final String PARAM_FILE_MAGIC = "DJL@";

    /** Directory the model was loaded from or last saved to; artifacts are resolved against it. */
    protected Path modelDir;

    protected Block block;
    protected String modelName;
    protected NDManager manager;
    protected DataType dataType;
    protected boolean wasLoaded;

    /** Cached input description; filled lazily by describeInput() or from a ".params" header. */
    protected PairList<String, Shape> inputData;

    /** Objects produced by getArtifact(String, Function), keyed by artifact name. */
    protected Map<String, Object> artifacts = new ConcurrentHashMap<>();

    /** Model properties; persisted in, and restored from, the ".params" header. */
    protected Map<String, String> properties = new ConcurrentHashMap<>();

    protected BaseModel(String modelName) {
        this.modelName = modelName;
    }

    /** {@inheritDoc} */
    @Override
    public Block getBlock() {
        return block;
    }

    /**
     * Replaces the model's block and clears {@code wasLoaded}.
     *
     * @param block the new block
     */
    @Override
    public void setBlock(Block block) {
        wasLoaded = false;
        this.block = block;
        // The cached description belonged to the previous block; recompute it on demand.
        inputData = null;
    }

    /** {@inheritDoc} */
    @Override
    public String getName() {
        return modelName;
    }

    /** {@inheritDoc} */
    @Override
    public NDManager getNDManager() {
        return manager;
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** Creates a {@link Predictor} for this model using the given translator and device. */
    @Override
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator, Device device) {
        return new Predictor<>(this, translator, device, false);
    }

    /** {@inheritDoc} */
    @Override
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }

    /** {@inheritDoc} */
    @Override
    public DataType getDataType() {
        return dataType;
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void load(InputStream is, Map<String, ?> options)
            throws IOException, MalformedModelException {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** Closes the model's {@link NDManager}, if one has been set. */
    @Override
    public void close() {
        if (manager != null) {
            manager.close();
        }
    }

    /**
     * Returns the model's input description. It is taken from the block on first use and cached
     * afterwards. It is also replaced when a ".params" file is read.
     */
    @Override
    public PairList<String, Shape> describeInput() {
        if (inputData == null) {
            inputData = block.describeInput();
        }
        return inputData;
    }

    /**
     * Returns the model's output description.
     *
     * <p>For a {@link SymbolBlock} the block's own description is returned. Otherwise a forward
     * pass is run on all-ones inputs shaped per {@link #describeInput()}. The outputs are named
     * {@code output0}, {@code output1}, ... in order.
     */
    @Override
    public PairList<String, Shape> describeOutput() {
        if (block instanceof SymbolBlock) {
            return ((SymbolBlock) block).describeOutput();
        }
        // Allocate the dummy inputs (and anything the forward pass creates from them) in a
        // short-lived sub-manager. Only the shapes are needed, so the arrays are released right
        // away instead of accumulating in the model's manager until the model is closed.
        try (NDManager scratch = manager.newSubManager()) {
            NDList input = new NDList();
            for (Pair<String, Shape> pair : describeInput()) {
                input.add(scratch.ones(pair.getValue()));
            }
            NDList outputs = block.forward(new ParameterStore(scratch, true), input, false);
            List<String> outputNames = new ArrayList<>(outputs.size());
            List<Shape> outputShapes = new ArrayList<>(outputs.size());
            for (int i = 0; i < outputs.size(); ++i) {
                outputNames.add("output" + i);
                outputShapes.add(outputs.get(i).getShape());
            }
            return new PairList<>(outputNames, outputShapes);
        }
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public String[] getArtifactNames() {
        throw new UnsupportedOperationException("Not supported!");
    }

    /**
     * Loads an artifact by applying {@code function} to its stream, caching the result.
     *
     * <p>The result is cached in {@link #artifacts} under {@code name}. Later calls with the same
     * name return the cached object without invoking {@code function}, even if a different
     * function is passed. The stream is closed as soon as {@code function} returns, so the
     * function must fully consume it. A {@code null} result is not cached.
     *
     * @param name the artifact name, resolved by {@link #getArtifactAsStream(String)}
     * @param function converts the artifact stream into the desired object
     * @return the (possibly cached) artifact object
     * @throws IOException if the artifact cannot be opened or read. This is also thrown for any
     *     runtime exception whose cause is an {@code IOException}.
     */
    @Override
    @SuppressWarnings("unchecked") // the cache stores whatever the caller's function produced
    public <T> T getArtifact(String name, Function<InputStream, T> function) throws IOException {
        try {
            Object artifact =
                    artifacts.computeIfAbsent(
                            name,
                            key -> {
                                try (InputStream is = getArtifactAsStream(key)) {
                                    return function.apply(is);
                                } catch (IOException e) {
                                    // computeIfAbsent cannot throw checked exceptions.
                                    throw new UncheckedIOException(e);
                                }
                            });
            return (T) artifact;
        } catch (RuntimeException e) {
            // Surface wrapped I/O failures as the declared checked exception.
            Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            throw e;
        }
    }

    /**
     * Resolves {@code artifactName} against {@link #modelDir} and returns its URL.
     *
     * @throws IllegalArgumentException if {@code artifactName} is null
     * @throws FileNotFoundException if the resolved file does not exist or is not readable
     */
    @Override
    public URL getArtifact(String artifactName) throws IOException {
        if (artifactName == null) {
            throw new IllegalArgumentException("artifactName cannot be null");
        }
        // NOTE(review): the name is resolved without normalization, so names containing ".."
        // can reach files outside modelDir. Validate if names can come from untrusted input.
        Path file = modelDir.resolve(artifactName);
        if (Files.exists(file) && Files.isReadable(file)) {
            return file.toUri().toURL();
        }
        throw new FileNotFoundException("File not found: " + file);
    }

    /** Opens a buffered stream over the artifact located by {@link #getArtifact(String)}. */
    @Override
    public InputStream getArtifactAsStream(String name) throws IOException {
        URL url = getArtifact(name);
        return new BufferedInputStream(url.openStream());
    }

    /**
     * Sets a model property. Properties are written to the ".params" header by {@link
     * #save(Path, String)}.
     *
     * @throws NullPointerException if {@code key} or {@code value} is null (the backing map is a
     *     {@link ConcurrentHashMap})
     */
    @Override
    public void setProperty(String key, String value) {
        properties.put(key, value);
    }

    /** Returns the property value for {@code key}, or {@code null} if it is not set. */
    @Override
    public String getProperty(String key) {
        return properties.get(key);
    }

    /** Returns the live, mutable property map (not a copy). */
    @Override
    public Map<String, String> getProperties() {
        return properties;
    }

    /** Sets {@link #modelDir} to the directory chosen by {@code Utils.getNestedModelDir}. */
    protected void setModelDir(Path modelDir) {
        this.modelDir = Utils.getNestedModelDir(modelDir);
    }

    /**
     * Loads the block's parameters from the ".params" file for {@code prefix}, unless the
     * {@code "hasParameter"} option is {@code "false"}.
     *
     * <p>The boolean result of {@link #readParameters(Path, Map)} is ignored here.
     *
     * @param prefix the parameter file name prefix
     * @param options load options. {@code "hasParameter"} may be a {@link Boolean} or a
     *     {@link String}; {@code "epoch"} is handled by {@link #paramPathResolver(String, Map)}.
     * @throws IOException if no parameter file can be located or read
     */
    protected void loadBlock(String prefix, Map<String, ?> options)
            throws IOException, MalformedModelException {
        boolean hasParameter = true;
        if (options != null) {
            Object paramOption = options.get("hasParameter");
            if (paramOption != null) {
                // toString() accepts Boolean as well as String values (a plain String cast did
                // not), matching how the "epoch" option is read.
                hasParameter = Boolean.parseBoolean(paramOption.toString());
            }
        }
        if (hasParameter) {
            Path paramFile = paramPathResolver(prefix, options);
            if (paramFile == null) {
                throw new IOException(
                        "Parameter file not found in: "
                                + modelDir
                                + ". If you only specified model path, make sure path name match"
                                + " your saved model file name.");
            }
            readParameters(paramFile, options);
        }
    }

    /**
     * Saves the model parameters to {@code <modelPath>/<name>-<epoch>.params}. The epoch is
     * zero-padded to four digits.
     *
     * <p>The epoch is taken from the {@code "Epoch"} property if set. Otherwise it is one more
     * than {@code Utils.getCurrentEpoch(modelPath, name)}.
     *
     * <p>File layout, in order:
     * <ol>
     *   <li>ASCII {@code "DJL@"}</li>
     *   <li>int version</li>
     *   <li>UTF model name</li>
     *   <li>UTF data type name</li>
     *   <li>int input count</li>
     *   <li>per input: UTF name ({@code ""} for null) followed by the encoded shape</li>
     *   <li>int property count</li>
     *   <li>UTF key/value pairs</li>
     *   <li>the block's parameters</li>
     * </ol>
     *
     * <p>On success {@link #modelDir} is set to the absolute {@code modelPath}.
     *
     * @param modelPath target directory, created if it does not exist
     * @param newModelName file name prefix; the current model name is used if null or empty
     * @throws IllegalStateException if the block is missing or uninitialized, or the data type
     *     is unset
     */
    @Override
    public void save(Path modelPath, String newModelName) throws IOException {
        if (newModelName == null || newModelName.isEmpty()) {
            newModelName = modelName;
        }
        // Validate before touching the file system. A failure after the output stream is opened
        // would leave a truncated .params file that could be mistaken for a valid checkpoint.
        if (block == null || !block.isInitialized()) {
            throw new IllegalStateException("Model has not be trained or loaded yet.");
        }
        if (dataType == null) {
            throw new IllegalStateException("Model data type has not been set.");
        }
        if (Files.notExists(modelPath)) {
            Files.createDirectories(modelPath);
        }
        String epochValue = getProperty("Epoch");
        int epoch =
                epochValue == null
                        ? Utils.getCurrentEpoch(modelPath, newModelName) + 1
                        : Integer.parseInt(epochValue);
        Path paramFile = modelPath.resolve(paramFileName(newModelName, epoch));
        try (DataOutputStream dos =
                new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(paramFile)))) {
            dos.writeBytes(PARAM_FILE_MAGIC);
            dos.writeInt(MODEL_VERSION);
            dos.writeUTF(newModelName);
            dos.writeUTF(dataType.name());
            inputData = block.describeInput();
            dos.writeInt(inputData.size());
            for (Pair<String, Shape> pair : inputData) {
                String name = pair.getKey();
                dos.writeUTF(name == null ? "" : name);
                dos.write(pair.getValue().getEncoded());
            }
            // Snapshot the concurrent map so the written count always matches the entries.
            List<Map.Entry<String, String>> props = new ArrayList<>(properties.entrySet());
            dos.writeInt(props.size());
            for (Map.Entry<String, String> entry : props) {
                dos.writeUTF(entry.getKey());
                dos.writeUTF(entry.getValue());
            }
            block.saveParameters(dos);
        }
        modelDir = modelPath.toAbsolutePath();
    }

    /** {@inheritDoc} */
    @Override
    public Path getModelPath() {
        return modelDir;
    }

    /** Returns a multi-line summary: name, location (if known), data type and all properties. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Model (\n\tName: ").append(modelName);
        if (modelDir != null) {
            sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
        }
        sb.append("\n\tData Type: ").append(dataType);
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
        }
        sb.append("\n)");
        return sb.toString();
    }

    /** Safety net: warns and closes the manager if the model was not closed explicitly. */
    @Override
    protected void finalize() throws Throwable {
        if (manager != null && manager.isOpen()) {
            logger.warn("Model: {} was not closed explicitly.", modelName);
            manager.close();
        }
        super.finalize();
    }

    /**
     * Resolves the ".params" file for {@code prefix} inside {@link #modelDir}.
     *
     * @param prefix the file name prefix
     * @param options may hold an {@code "epoch"} value (any type whose {@code toString()} is an
     *     integer). Without it, the epoch from {@code Utils.getCurrentEpoch(modelDir, prefix)} is
     *     used.
     * @return the resolved path, or {@code null} if no epoch is given and none is found (-1)
     * @throws NumberFormatException if the {@code "epoch"} option is not an integer
     */
    protected Path paramPathResolver(String prefix, Map<String, ?> options) throws IOException {
        Object epochOption = options == null ? null : options.get("epoch");
        int epoch;
        if (epochOption == null) {
            epoch = Utils.getCurrentEpoch(modelDir, prefix);
            if (epoch == -1) {
                return null;
            }
        } else {
            epoch = Integer.parseInt(epochOption.toString());
        }
        return modelDir.resolve(paramFileName(prefix, epoch));
    }

    /**
     * Opens {@code paramFile} and delegates to {@link #readParameters(InputStream, Map)}, which
     * closes the stream.
     */
    protected boolean readParameters(Path paramFile, Map<String, ?> options)
            throws IOException, MalformedModelException {
        logger.debug("Try to load model from {}", paramFile);
        return readParameters(Files.newInputStream(paramFile), options);
    }

    /**
     * Reads a DJL ".params" stream (layout as written by {@link #save(Path, String)}).
     *
     * <p>On success, this replaces {@link #dataType} and {@link #inputData}, adds the stored
     * properties, and loads the block's parameters. An input saved with a null name is read back
     * as {@code ""}. The stream is always closed.
     *
     * @return {@code false} if the stream does not start with the DJL magic bytes, otherwise
     *     {@code true}
     * @throws IOException if the version is unsupported or the stream is truncated
     */
    protected boolean readParameters(InputStream paramStream, Map<String, ?> options)
            throws IOException, MalformedModelException {
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(paramStream))) {
            byte[] buf = new byte[4];
            dis.readFully(buf);
            if (!PARAM_FILE_MAGIC.equals(new String(buf, StandardCharsets.US_ASCII))) {
                return false;
            }
            int version = dis.readInt();
            if (version != MODEL_VERSION) {
                throw new IOException("Unsupported model version: " + version);
            }
            String savedModelName = dis.readUTF();
            logger.debug("Loading saved model: {} parameter", savedModelName);
            dataType = DataType.valueOf(dis.readUTF());
            int numberOfInputs = dis.readInt();
            inputData = new PairList<>();
            for (int i = 0; i < numberOfInputs; ++i) {
                String inputName = dis.readUTF();
                Shape shape = Shape.decode(dis);
                inputData.add(inputName, shape);
            }
            int numberOfProperties = dis.readInt();
            for (int i = 0; i < numberOfProperties; ++i) {
                String key = dis.readUTF();
                String value = dis.readUTF();
                properties.put(key, value);
            }
            block.loadParameters(manager, dis);
            logger.debug("DJL model loaded successfully");
        }
        return true;
    }

    /** Builds the parameter file name {@code <prefix>-<epoch zero-padded to 4 digits>.params}. */
    private static String paramFileName(String prefix, int epoch) {
        return String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch);
    }
}

