#ifndef __onnx_util_H__ #define __onnx_util_H__ #include <onnx/onnx_pb.h> #include <vector> #include <map> #include <set> #include <algorithm> #include "string_util.h" #include "spdlog/spdlog.h" typedef std::vector<std::string> StringVector; typedef std::set<std::string> StringSet; typedef std::map<std::string, onnx::ValueInfoProto> OnnxValueInfoProtoMap; // Extract sub-graph from an onnx model // https://github.com/onnx/onnx/blob/1cbd549193661c7e9c870c4e9595144d0c6971c5/onnx/utils.py class Extractor { public: Extractor(onnx::ModelProto model): mModel(model) {} void prepare(); OnnxValueInfoProtoMap CollectNewIOCore(OnnxValueInfoProtoMap originalIO, StringVector ioNamesToExtract); void DFSSearchReachableNodes(std::string nodeOutputName, StringVector graphInputNames, std::vector<onnx::NodeProto>& reachableNodes); std::vector<onnx::NodeProto> CollectReachableNodes(StringVector inputNames, StringVector outputNames); std::vector<onnx::TensorProto> CollectReachableInitializer(std::vector<onnx::NodeProto>& nodes); std::vector<onnx::ValueInfoProto> CollectReachableValueInfo(std::vector<onnx::NodeProto>& nodes); onnx::ModelProto ExtractModel(StringVector inputNames, StringVector outputNames); private: onnx::ModelProto mModel; onnx::GraphProto* graph; std::map<std::string, onnx::TensorProto> wmap; OnnxValueInfoProtoMap vimap; }; #endif // !__onnx_util_H__