Created
February 25, 2021 09:48
-
-
Save argman/f378b5020e2ec7e2b621dcd1bf2f61f7 to your computer and use it in GitHub Desktop.
ONNX C++ code to extract a sub-graph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "onnx_util.h" | |
void Extractor::prepare() { | |
for(int i=0; i < mModel.graph().initializer_size(); ++i) { | |
auto initializerI = mModel.graph().initializer(i); | |
spdlog::info("Initializer {} : name={}", i, initializerI.name()); | |
wmap[initializerI.name()] = initializerI; | |
} | |
for(int i=0; i < mModel.graph().value_info_size(); ++i) { | |
auto valueInfoI = mModel.graph().value_info(i); | |
vimap[valueInfoI.name()] = valueInfoI; | |
} | |
} | |
// Resolves the requested input/output names to their ValueInfoProto entries.
//
// originalIO        - the graph's existing inputs (or outputs), keyed by name.
// ioNamesToExtract  - names requested as the inputs (or outputs) of the sub-graph.
//
// Names that already exist in `originalIO` are taken from there. All other
// names are looked up in `vimap` (filled by prepare()); a name found in
// neither place is logged as an error and left out of the result.
//
// Returns a map from name to ValueInfoProto for every name that was resolved.
OnnxValueInfoProtoMap Extractor::CollectNewIOCore(OnnxValueInfoProtoMap originalIO, StringVector ioNamesToExtract) {
    StringSet originalIONames;
    for (const auto& item : originalIO) {
        originalIONames.insert(item.first);
    }
    const StringSet sIONamesToExtract(ioNamesToExtract.begin(), ioNamesToExtract.end());

    // Requested names that are already graph inputs/outputs.
    // The insert hint must come from the destination container itself.
    StringSet ioNamesToKeep;
    std::set_intersection(sIONamesToExtract.begin(), sIONamesToExtract.end(),
                          originalIONames.begin(), originalIONames.end(),
                          std::inserter(ioNamesToKeep, ioNamesToKeep.begin()));

    // not in graph input or output, maybe activations
    StringSet newIONamesToAdd;
    std::set_difference(sIONamesToExtract.begin(), sIONamesToExtract.end(),
                        originalIONames.begin(), originalIONames.end(),
                        std::inserter(newIONamesToAdd, newIONamesToAdd.begin()));

    OnnxValueInfoProtoMap newIOValueInfo;
    for (const auto& name : ioNamesToKeep) {
        // `name` comes from the intersection, so it is guaranteed to be a key of originalIO.
        newIOValueInfo[name] = originalIO.at(name);
    }
    for (const auto& name : newIONamesToAdd) {
        const auto found = vimap.find(name);
        if (found == vimap.end()) {
            spdlog::error("Cannot find io name {} in activations", name);
            continue;
        }
        newIOValueInfo[name] = found->second;
    }
    return newIOValueInfo;
}
void Extractor::DFSSearchReachableNodes(std::string nodeOutputName, StringVector graphInputNames, std::vector<onnx::NodeProto>& reachableNodes) { | |
if(CT::indexOfStringItems(nodeOutputName, graphInputNames) != -1) { | |
return; | |
} | |
for(int i=0; i < mModel.graph().node_size(); i++) { | |
auto nodeI = mModel.graph().node(i); | |
bool nodeInReachableNodes = false; | |
for(int j=0; j < reachableNodes.size(); ++j) { | |
if(reachableNodes[j].name() == nodeI.name()) { | |
nodeInReachableNodes = true; | |
break; | |
} | |
} | |
if(nodeInReachableNodes) { | |
continue; | |
} | |
bool findInNodeOutput = false; | |
for(int j=0; j < nodeI.output_size(); j++) { | |
if(nodeI.output(j) == nodeOutputName) { | |
findInNodeOutput = true; | |
break; | |
} | |
} | |
if(!findInNodeOutput) { | |
continue; | |
} | |
reachableNodes.push_back(nodeI); | |
for(int j=0; j < nodeI.input_size(); j++) { | |
DFSSearchReachableNodes(nodeI.input(j), graphInputNames, reachableNodes); | |
} | |
} | |
} | |
// Collects the nodes needed to compute `outputNames` from `inputNames`.
//
// Runs DFSSearchReachableNodes from every requested output, then returns the
// collected nodes in the order they appear in the model graph.
// Each graph node appears at most once in the result.
std::vector<onnx::NodeProto> Extractor::CollectReachableNodes(StringVector inputNames, StringVector outputNames) {
    std::vector<onnx::NodeProto> reachableNodes;
    for (const auto& name : outputNames) {
        DFSSearchReachableNodes(name, inputNames, reachableNodes);
    }

    // Node names are optional in ONNX and may be empty or repeated, so a node
    // is identified by its name together with its list of output tensors.
    const auto sameNode = [](const onnx::NodeProto& lhs, const onnx::NodeProto& rhs) {
        return lhs.name() == rhs.name()
            && lhs.output_size() == rhs.output_size()
            && std::equal(lhs.output().begin(), lhs.output().end(), rhs.output().begin());
    };

    // needs to be topology sorted
    // NOTE(review): this relies on the model's node list already being in
    // topological order; the discovery order of the DFS is not.
    std::vector<onnx::NodeProto> nodes;
    nodes.reserve(reachableNodes.size());
    const auto& modelGraph = mModel.graph();
    for (int i = 0; i < modelGraph.node_size(); ++i) {
        const auto& candidate = modelGraph.node(i);
        const bool reached = std::any_of(
            reachableNodes.begin(), reachableNodes.end(),
            [&](const onnx::NodeProto& collected) { return sameNode(collected, candidate); });
        if (reached) {
            nodes.push_back(candidate);
        }
    }
    return nodes;
}
// Returns the initializers referenced by `nodes`.
//
// An initializer from `wmap` (filled by prepare()) is selected when its name
// matches any input or output name of any node in `nodes`. Each initializer is
// returned once, however many nodes use it, in `wmap` key order.
std::vector<onnx::TensorProto> Extractor::CollectReachableInitializer(std::vector<onnx::NodeProto>& nodes) {
    // A set, not a vector: a tensor consumed by several nodes must not cause
    // its initializer to be emitted several times.
    StringSet allTensorNames;
    for (const auto& node : nodes) {
        allTensorNames.insert(node.input().begin(), node.input().end());
        allTensorNames.insert(node.output().begin(), node.output().end());
    }

    std::vector<onnx::TensorProto> initializer;
    // Iterate by reference; each map entry holds a full tensor.
    for (const auto& w : wmap) {
        if (allTensorNames.count(w.first) != 0) {
            initializer.push_back(w.second);
        }
    }
    return initializer;
}
// Returns the value_info entries referenced by `nodes`.
//
// An entry from `vimap` (filled by prepare()) is selected when its name
// matches any input or output name of any node in `nodes`. Each entry is
// returned once, however many nodes reference it, in `vimap` key order.
std::vector<onnx::ValueInfoProto> Extractor::CollectReachableValueInfo(std::vector<onnx::NodeProto>& nodes) {
    // A set, not a vector: a tensor touched by several nodes must not cause
    // its value_info to be emitted several times.
    StringSet allTensorNames;
    for (const auto& node : nodes) {
        allTensorNames.insert(node.input().begin(), node.input().end());
        allTensorNames.insert(node.output().begin(), node.output().end());
    }

    std::vector<onnx::ValueInfoProto> valueInfo;
    for (const auto& v : vimap) {
        if (allTensorNames.count(v.first) != 0) {
            valueInfo.push_back(v.second);
        }
    }
    return valueInfo;
}
// Builds a new model containing only the sub-graph between the given tensors.
//
// inputNames   - tensor names that become the inputs of the extracted graph.
// outputNames  - tensor names that become the outputs of the extracted graph.
//
// The new graph is named "ExtractedGraph" and holds the reachable nodes plus
// the initializers and value_info entries they reference. The new model copies
// ir_version and opset_import from the source model, and its producer_name is
// the source producer_name with "SubModel" appended.
//
// Returns the extracted model. Names that cannot be resolved are reported by
// CollectNewIOCore and are missing from the result's inputs/outputs.
onnx::ModelProto Extractor::ExtractModel(std::vector<std::string> inputNames, std::vector<std::string> outputNames) {
    prepare();

    const auto& sourceGraph = mModel.graph();

    OnnxValueInfoProtoMap originalInput;
    for (int i = 0; i < sourceGraph.input_size(); ++i) {
        originalInput[sourceGraph.input(i).name()] = sourceGraph.input(i);
    }
    OnnxValueInfoProtoMap originalOutput;
    for (int i = 0; i < sourceGraph.output_size(); ++i) {
        originalOutput[sourceGraph.output(i).name()] = sourceGraph.output(i);
    }

    OnnxValueInfoProtoMap newInputs = CollectNewIOCore(originalInput, inputNames);
    spdlog::info("Collect {} new input tensors", newInputs.size());
    OnnxValueInfoProtoMap newOutputs = CollectNewIOCore(originalOutput, outputNames);
    spdlog::info("Collect {} new ouptut tensors", newOutputs.size());
    std::vector<onnx::NodeProto> newNodes = CollectReachableNodes(inputNames, outputNames);
    spdlog::info("Collect {} reachable nodes in graph", newNodes.size());
    std::vector<onnx::TensorProto> newInitializer = CollectReachableInitializer(newNodes);
    spdlog::info("Collect {} reachable initializers in graph", newInitializer.size());
    std::vector<onnx::ValueInfoProto> newValueInfo = CollectReachableValueInfo(newNodes);
    spdlog::info("Collect {} reachable value info in graph", newValueInfo.size());

    onnx::ModelProto newModel;
    newModel.set_ir_version(mModel.ir_version());

    // Populate the model's own graph in place rather than building a
    // temporary GraphProto and copying all of it afterwards.
    onnx::GraphProto* newGraph = newModel.mutable_graph();
    for (const auto& node : newNodes) {
        newGraph->add_node()->CopyFrom(node);
    }
    newGraph->set_name("ExtractedGraph");
    for (const auto& input : newInputs) {
        newGraph->add_input()->CopyFrom(input.second);
    }
    for (const auto& output : newOutputs) {
        newGraph->add_output()->CopyFrom(output.second);
    }
    for (const auto& initializer : newInitializer) {
        newGraph->add_initializer()->CopyFrom(initializer);
    }
    for (const auto& valueInfo : newValueInfo) {
        newGraph->add_value_info()->CopyFrom(valueInfo);
    }

    for (int i = 0; i < mModel.opset_import_size(); ++i) {
        newModel.add_opset_import()->CopyFrom(mModel.opset_import(i));
    }

    // The derived label is a producer *name*; it was previously written into
    // the producer_version field.
    newModel.set_producer_name(mModel.producer_name() + "SubModel");
    return newModel;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef __onnx_util_H__ | |
#define __onnx_util_H__ | |
#include <onnx/onnx_pb.h> | |
#include <vector> | |
#include <map> | |
#include <set> | |
#include <algorithm> | |
#include "string_util.h" | |
#include "spdlog/spdlog.h" | |
typedef std::vector<std::string> StringVector; | |
typedef std::set<std::string> StringSet; | |
typedef std::map<std::string, onnx::ValueInfoProto> OnnxValueInfoProtoMap; | |
// Extract sub-graph from an onnx model | |
// https://github.com/onnx/onnx/blob/1cbd549193661c7e9c870c4e9595144d0c6971c5/onnx/utils.py | |
class Extractor { | |
public: | |
Extractor(onnx::ModelProto model): mModel(model) {} | |
void prepare(); | |
OnnxValueInfoProtoMap CollectNewIOCore(OnnxValueInfoProtoMap originalIO, StringVector ioNamesToExtract); | |
void DFSSearchReachableNodes(std::string nodeOutputName, StringVector graphInputNames, std::vector<onnx::NodeProto>& reachableNodes); | |
std::vector<onnx::NodeProto> CollectReachableNodes(StringVector inputNames, StringVector outputNames); | |
std::vector<onnx::TensorProto> CollectReachableInitializer(std::vector<onnx::NodeProto>& nodes); | |
std::vector<onnx::ValueInfoProto> CollectReachableValueInfo(std::vector<onnx::NodeProto>& nodes); | |
onnx::ModelProto ExtractModel(StringVector inputNames, StringVector outputNames); | |
private: | |
onnx::ModelProto mModel; | |
onnx::GraphProto* graph; | |
std::map<std::string, onnx::TensorProto> wmap; | |
OnnxValueInfoProtoMap vimap; | |
}; | |
#endif // !__onnx_util_H__ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment