2019年8月4日日曜日

SONY NNC:cli.py(バージョン1.50)

【バージョン1.50】
import argparse
import sys
import warnings

# Silence the "numpy.dtype size changed" warning before anything that may
# trigger it gets imported.
warnings.filterwarnings("ignore", message="numpy.dtype size changed")


def _nnabla_version():
    """Return a one-line description of the installed NNabla.

    The string always contains the package version and build number. The
    callback version is appended only when
    ``nnabla.utils.callback.get_callback_version()`` returns something other
    than ``None``.

    Returns:
        str: ``'Version:<v>, Build:<b>'``, optionally followed by
        ``', Callback:<c>'``.
    """
    # Imported lazily so that loading this module does not itself import nnabla.
    import nnabla
    import nnabla.utils.callback as callback

    summary = 'Version:{}, Build:{}'.format(nnabla.__version__,
                                            nnabla.__build_number__)
    callback_version = callback.get_callback_version()
    if callback_version is None:
        return summary
    return summary + ', Callback:{}'.format(callback_version)


def version_command(args):
    """Handle the ``version`` sub-command by printing the version string.

    Args:
        args: Parsed command-line namespace. It is not used here; the
            parameter exists because ``cli_main`` calls every handler as
            ``args.func(args)``.

    Returns:
        bool: Always ``True``. ``cli_main`` stores the handler's result in
        the module-level ``return_value`` and ``main`` exits with -1 when
        that value is falsy, so returning nothing would make a successful
        ``version`` call report failure to the shell.
    """
    print(_nnabla_version())
    return True


# Result of the most recent CLI run. cli_main() writes it (False on entry,
# then whatever the selected sub-command handler returns) and main() reads it
# after the worker thread has finished; a falsy value makes main() exit -1.
return_value = None


def main():
    """Run ``cli_main`` on a dedicated thread and turn its result into an exit code.

    The thread stack size and the interpreter recursion limit are both raised
    before the worker thread is created. Once the worker has finished, the
    module-level ``return_value`` is inspected: a falsy value terminates the
    process via ``sys.exit(-1)``, a truthy one lets the function return
    normally.
    """
    global return_value
    import threading
    import six.moves._thread as thread

    # stack_size() only affects threads created afterwards, so it has to be
    # called before the worker is constructed.
    # NOTE(review): the large limits are presumably there to survive very deep
    # recursion inside sub-commands -- confirm.
    thread.stack_size(128 * 1024 * 1024)
    sys.setrecursionlimit(0x3fffffff)

    worker = threading.Thread(target=cli_main)
    worker.start()
    worker.join()

    if return_value:
        return
    sys.exit(-1)


def cli_main():
    """Build the NNabla argument parser, parse ``sys.argv`` and run the chosen sub-command.

    The outcome is reported through the module-level ``return_value``: it is
    set to ``False`` on entry and replaced by the handler's return value when
    the handler finishes without raising. ``main`` reads it after this
    function's thread has ended.

    Behaviour on the different paths:

    * ``-h/--help`` and argparse usage errors: argparse raises ``SystemExit``.
      ``return_value`` is set to ``True`` only for a zero/None exit status,
      and the exception is re-raised to stop this thread.
    * No sub-command given: the help text is written to stderr and the thread
      is stopped with ``sys.exit(-1)``.
    * ``-m/--mpi``: a communicator is created before the handler runs; if the
      handler raises, the traceback is printed and ``comm.abort()`` is called.
    * Otherwise: if the handler raises, the traceback is printed and the
      thread is stopped with ``sys.exit(-1)``.
    """
    global return_value
    return_value = False

    parser = argparse.ArgumentParser(description='Command line interface ' +
                                     'for NNabla({})'.format(_nnabla_version()))
    parser.add_argument(
        '-m', '--mpi', help='exec with mpi.', action='store_true')

    subparsers = parser.add_subparsers()

    # Each sub-command lives in its own module and registers itself on the
    # shared sub-parser collection. The imports are done here rather than at
    # module level, so nothing from nnabla is loaded until the CLI runs.
    import nnabla
    from nnabla.utils.cli.train import add_train_command
    add_train_command(subparsers)

    from nnabla.utils.cli.forward import add_infer_command, add_forward_command
    add_infer_command(subparsers)
    add_forward_command(subparsers)

    from nnabla.utils.cli.encode_decode_param import add_decode_param_command, add_encode_param_command
    add_encode_param_command(subparsers)
    add_decode_param_command(subparsers)

    from nnabla.utils.cli.profile import add_profile_command
    add_profile_command(subparsers)

    from nnabla.utils.cli.conv_dataset import add_conv_dataset_command
    add_conv_dataset_command(subparsers)

    from nnabla.utils.cli.compare_with_cpu import add_compare_with_cpu_command
    add_compare_with_cpu_command(subparsers)

    from nnabla.utils.cli.create_image_classification_dataset import add_create_image_classification_dataset_command
    add_create_image_classification_dataset_command(subparsers)

    from nnabla.utils.cli.uploader import add_upload_command
    add_upload_command(subparsers)

    from nnabla.utils.cli.uploader import add_create_tar_command
    add_create_tar_command(subparsers)

    from nnabla.utils.cli.convert import add_convert_command
    add_convert_command(subparsers)

    from nnabla.utils.cli.func_info import add_function_info_command
    add_function_info_command(subparsers)

    from nnabla.utils.cli.plot import (
        add_plot_series_command, add_plot_timer_command)
    add_plot_series_command(subparsers)
    add_plot_timer_command(subparsers)

    from nnabla.utils.cli.draw_graph import add_draw_graph_command
    add_draw_graph_command(subparsers)

    # Version
    subparser = subparsers.add_parser(
        'version', help='Print version and build number.')
    subparser.set_defaults(func=version_command)

    print('NNabla command line interface ({})'.format(_nnabla_version()))

    try:
        args = parser.parse_args()
    except SystemExit as exit_request:
        # argparse leaves via SystemExit: status 0 after printing --help,
        # non-zero after a usage error. When this function runs on a worker
        # thread the exception only ends that thread, so the outcome has to be
        # recorded here; otherwise a successful --help would be reported as a
        # failure by main().
        return_value = exit_request.code in (None, 0)
        raise

    # Only sub-parsers set 'func', so its absence means no sub-command was given.
    if 'func' not in args:
        parser.print_help(sys.stderr)
        sys.exit(-1)

    if args.mpi:
        from nnabla.utils.communicator_util import create_communicator
        comm = create_communicator()

    try:
        return_value = args.func(args)
    except BaseException:
        # Deliberately as broad as the bare ``except:`` it replaces: under MPI
        # the communicator is aborted for any failure, including a handler
        # that calls sys.exit().
        import traceback
        print(traceback.format_exc())
        return_value = False
        if args.mpi:
            comm.abort()
        else:
            sys.exit(-1)


if __name__ == '__main__':
    main()


0 件のコメント:

コメントを投稿