/usr/src/tensorrt$ cat bin/download-digits-model.py

 

 

/usr/src/tensorrt$ cat bin/download-digits-model.py
#!/usr/bin/env python

import sys
import os.path
import argparse
import requests
import json

class ModelDownloader(object):
    """
    Interactively downloads a trained model snapshot from a DIGITS server.

    The server is reached over plain HTTP at http://<hostname>:<port>. The
    user picks a job, then a snapshot epoch, from numbered menus on stdin.
    Written to run unchanged on both Python 2 and Python 3.
    """

    # Archive formats advertised by the command-line help of this script.
    _ARCHIVE_EXTENSIONS = ('.tar.gz', '.tar.bz2', '.zip', '.tar')

    def __init__(self, hostname, port):
        # hostname: host name or IP address of the DIGITS server.
        # port: TCP port the DIGITS web server listens on.
        self.hostname = hostname
        self.port = port

    def _base_url(self):
        """
        Return the 'http://host:port' prefix shared by every request.
        """
        return 'http://%s:%s' % (self.hostname, self.port)

    @staticmethod
    def _read_line(prompt='>>> '):
        """
        Read one line of user input on both Python 2 and Python 3.
        """
        try:
            reader = raw_input  # Python 2 (its input() would eval the text)
        except NameError:
            reader = input  # Python 3 renamed raw_input to input
        return reader(prompt)

    def _select_index(self, message, count, default=None):
        """
        Prompt until the user enters a menu number between 1 and count.

        message -- header printed before every prompt attempt
        count   -- number of menu entries
        default -- optional 1-based menu number chosen by a blank answer

        Returns the selected 0-based index. Invalid input is reported and the
        prompt is repeated.
        """
        while True:
            print(message)
            answer = self._read_line('>>> ')
            if default is not None and not answer.strip():
                return default - 1
            try:
                index = int(answer) - 1
            except ValueError as e:
                print(e)
                continue
            if 0 <= index < count:
                return index
            print('Out of range')

    def get_url(self, url):
        """
        Fetch url and return the raw response body (bytes).

        Any request failure (connection error or HTTP error status) is
        reported on stdout and terminates the process with exit status 1,
        rather than being raised to the caller.
        """
        try:
            response = requests.get(url)
            response.raise_for_status()
            return response.content
        except requests.exceptions.ConnectionError:
            print('Failed to connect to server at %s:%s' % (self.hostname, self.port))
        except requests.exceptions.RequestException as e:
            print('Error loading "%s"' % url)
            # str(e) rather than e.message: .message is deprecated since
            # Python 2.6 and no longer exists in Python 3.
            print('\t%s' % e)
        sys.exit(1)

    def get_job_id(self):
        """
        Present the user with a list of models on the server and return the id of the selection

        Raises Exception if the server reports no models.
        """
        url = '%s/index.json' % self._base_url()
        # Decode explicitly: json.loads only accepts bytes from Python 3.6 on.
        models = json.loads(self.get_url(url).decode('utf-8'))['models']
        if not models:
            raise Exception('No models exist on this server!')
        fmt = '[%3s] %-20s %-10s %-20s'
        print(fmt % ('Num', 'Job ID', 'Status', 'Name'))
        print('-' * len(fmt % ('a', 'a', 'a', 'a')))
        for number, model in enumerate(models, 1):
            print(fmt % (number, model['id'], model['status'], model['name']))

        selected = self._select_index('Select a job', len(models))
        print('')
        return models[selected]['id']

    def get_snapshot_epoch(self, job_id):
        """
        Present the user with a list of snapshots and return the epoch of the selection

        A blank answer selects the last snapshot in the server's list.
        Raises Exception if the job has no snapshots.
        """
        url = '%s/models/%s.json' % (self._base_url(), job_id)
        snapshots = json.loads(self.get_url(url).decode('utf-8'))['snapshots']
        if not snapshots:
            raise Exception('No snapshots exist for this job!')
        fmt = '[%3s] %-10s'
        print(fmt % ('Num', 'Epoch'))
        print('-' * len(fmt % ('a', 'a')))
        for number, epoch in enumerate(snapshots, 1):
            print(fmt % (number, epoch))

        default = len(snapshots)
        selected = self._select_index(
            'Select a snapshot (leave blank for default=%s)' % default,
            len(snapshots), default=default)
        print('')
        return snapshots[selected]

    @classmethod
    def _archive_extension(cls, output_file):
        """
        Return the archive suffix (e.g. '.tar.gz') requested from the server.

        Known archive suffixes are matched case-insensitively against the end
        of the file name, so dots elsewhere in the name ('v1.2.tar.gz') do not
        leak into the suffix. For unrecognised names the historical behaviour
        is kept: everything after the first '.' of the base name, or '' when
        the name has no dot.
        """
        name = os.path.basename(output_file)
        lowered = name.lower()
        for extension in cls._ARCHIVE_EXTENSIONS:
            if lowered.endswith(extension):
                return extension
        return '.'.join([''] + name.split('.')[1:])

    def download_model(self, output_file, job_id, snapshot_epoch):
        """
        Download a tarfile of a specific snapshot

        The archive format is chosen from output_file's extension. The whole
        body is fetched before output_file is opened, so a failed request
        (which exits the process) leaves no partial file behind.
        """
        extension = self._archive_extension(output_file)
        url = '%s/models/%s/download%s?epoch=%s' % (
            self._base_url(), job_id, extension, snapshot_epoch)
        content = self.get_url(url)

        print('Saving to %s' % output_file)
        with open(output_file, 'wb') as outfile:
            outfile.write(content)

if __name__ == '__main__':
    # Command-line entry point: parse arguments, let the user pick a job and
    # snapshot interactively, then save the archive to output_file.
    parser = argparse.ArgumentParser(description='Download a DIGITS model tarball')

    ### Positional arguments

    parser.add_argument('output_file',
            help='output file (should end with .zip, .tar, .tar.gz or .tar.bz2)')

    ### Optional arguments

    parser.add_argument('-n', '--hostname',
            default='127.0.0.1',
            help='hostname for the DIGITS server [default=127.0.0.1]')

    parser.add_argument('-p', '--port',
            type=int,
            default=80,
            help='port for the DIGITS server [default=80]')

    args = parser.parse_args()

    downloader = ModelDownloader(args.hostname, args.port)
    try:
        job_id = downloader.get_job_id()
        snapshot_epoch = downloader.get_snapshot_epoch(job_id)
        downloader.download_model(args.output_file, job_id, snapshot_epoch)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C, or Ctrl-D at an interactive prompt: abort quietly instead of
        # dumping a traceback. Exit status 1 matches the downloader's failures.
        print('')
        sys.exit(1)

    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print('Done.')

 

posted on 2018-03-13 10:46  cdekelon  阅读(99)  评论(0)    收藏  举报

导航