Created
April 9, 2024 00:40
-
-
Save hotchpotch/64fa52d32886fe61cc1d110066afef38 to your computer and use it in GitHub Desktop.
ONNX model to float16 precision
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script converts an ONNX model to float16 precision using the onnxruntime transformers package.
It takes an input ONNX model file as a mandatory argument. The output file name is optional; if not provided,
the script derives it by dropping the input file's extension and appending "_fp16.onnx".
""" | |
import argparse | |
import onnx | |
from onnxruntime.transformers.float16 import convert_float_to_float16 | |
import os | |
def main(input_file, output_file=None):
    """Convert the ONNX model at ``input_file`` to float16 and save it.

    Args:
        input_file: Path to the source ONNX model file.
        output_file: Destination path for the converted model. When ``None``,
            the name is derived from ``input_file`` by dropping its extension
            and appending ``_fp16.onnx``.

    Returns:
        None. If ``input_file`` does not exist or is not a regular file, an
        error message is printed and the function returns without loading
        or converting anything.
    """
    # Reject missing paths up front.
    if not os.path.exists(input_file):
        print(f"Error: The input file '{input_file}' does not exist.")
        return

    # A directory passes an existence check but cannot be loaded as a model;
    # report it here instead of failing later inside the loader.
    if not os.path.isfile(input_file):
        print(f"Error: The input path '{input_file}' is not a file.")
        return

    # Generate the output file name from the input file name if not specified.
    if output_file is None:
        base_name = os.path.splitext(input_file)[0]  # path without the extension
        output_file = f"{base_name}_fp16.onnx"

    print(f"Loading model from {input_file}...")
    onnx_model = onnx.load(input_file)

    print("Converting model to float16...")
    # NOTE(review): disable_shape_infer=True presumably makes the converter
    # skip ONNX shape inference -- confirm against the onnxruntime docs.
    model_fp16 = convert_float_to_float16(onnx_model, disable_shape_infer=True)

    print(f"Saving converted model to {output_file}...")
    onnx.save(model_fp16, output_file)

    print("Conversion completed successfully.")
if __name__ == "__main__":
    # Command-line entry point: collect the model paths and run the conversion.
    cli = argparse.ArgumentParser(description="Convert an ONNX model to float16.")
    cli.add_argument("-i", "--input", help="Input ONNX model file.", required=True)
    cli.add_argument(
        "-o",
        "--output",
        help=(
            "Optional output file for the converted model. "
            "If not specified, derives the output file name from the input file name."
        ),
    )
    options = cli.parse_args()
    # options.output is None when -o/--output is omitted; main() then derives the name.
    main(options.input, options.output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment