#ifndef __onnx_util_H__
#define __onnx_util_H__

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <onnx/onnx_pb.h>

#include "string_util.h"
#include "spdlog/spdlog.h"

// Ordered sequence of names (used below for graph input/output name lists).
using StringVector = std::vector<std::string>;
// Sorted collection of unique names.
using StringSet = std::set<std::string>;
// Ordered map from a string key (presumably the value's name) to its ONNX ValueInfoProto.
using OnnxValueInfoProtoMap = std::map<std::string, onnx::ValueInfoProto>;

// Extract a sub-graph from an ONNX model.
//
// C++ counterpart of the Python extractor in:
// https://github.com/onnx/onnx/blob/1cbd549193661c7e9c870c4e9595144d0c6971c5/onnx/utils.py
//
// NOTE(review): the method bodies are defined outside this header; the
// per-method notes below are inferred from the names and from the Python
// reference above — confirm against the implementation file.
class Extractor {

public:
    // Stores its own copy of `model`. The by-value argument is moved into the
    // member, so callers that pass an rvalue pay for no extra deep copy.
    // NOTE(review): not `explicit` — kept as-is so existing implicit
    // conversions from onnx::ModelProto keep compiling.
    Extractor(onnx::ModelProto model): mModel(std::move(model)) {}

    // Presumably populates `graph`, `wmap` and `vimap` from `mModel` before any
    // extraction call — TODO confirm call-order requirements.
    void prepare();

    // Presumably builds the new input/output set for the sub-graph from
    // `originalIO` and the requested `ioNamesToExtract` (cf. Python
    // `_collect_new_io_core`) — TODO confirm.
    OnnxValueInfoProtoMap CollectNewIOCore(OnnxValueInfoProtoMap originalIO, StringVector ioNamesToExtract);

    // Presumably walks producers backwards from `nodeOutputName`, stopping at
    // `graphInputNames`, appending visited nodes to `reachableNodes`
    // (cf. Python `_dfs_search_reachable_nodes`) — TODO confirm.
    void DFSSearchReachableNodes(std::string nodeOutputName, StringVector graphInputNames, std::vector<onnx::NodeProto>& reachableNodes);

    // Presumably returns the nodes needed to compute `outputNames` from
    // `inputNames` (cf. Python `_collect_reachable_nodes`) — TODO confirm.
    std::vector<onnx::NodeProto> CollectReachableNodes(StringVector inputNames, StringVector outputNames);

    // Presumably returns the initializers referenced by `nodes` — TODO confirm.
    std::vector<onnx::TensorProto> CollectReachableInitializer(std::vector<onnx::NodeProto>& nodes);

    // Presumably returns the value_info entries referenced by `nodes` — TODO confirm.
    std::vector<onnx::ValueInfoProto> CollectReachableValueInfo(std::vector<onnx::NodeProto>& nodes);

    // Builds and returns a new model whose graph spans `inputNames` to
    // `outputNames` (cf. Python `extract_model`).
    onnx::ModelProto ExtractModel(StringVector inputNames, StringVector outputNames);

private:
    // The extractor's own copy of the source model.
    onnx::ModelProto mModel;
    // Non-owning; presumably set by prepare() to point into `mModel`'s graph —
    // TODO confirm. Null-initialized so a use before prepare() is detectable
    // rather than reading an indeterminate pointer.
    // NOTE(review): if this points into `mModel`, the implicitly generated
    // copy/move operations leave a copy's `graph` aimed at the source object.
    onnx::GraphProto* graph = nullptr;
    // Presumably initializer tensors keyed by name — TODO confirm.
    std::map<std::string, onnx::TensorProto> wmap;
    // Presumably value infos (inputs/outputs/value_info) keyed by name — TODO confirm.
    OnnxValueInfoProtoMap vimap;
};



#endif // !__onnx_util_H__