Source code for niftynet.utilities.download

#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import math
import os
import tarfile
import tempfile
from distutils.version import LooseVersion
from os.path import basename
from shutil import copyfile
from shutil import move
from shutil import rmtree

try:
    import configparser
except ImportError:
    import ConfigParser as configparser
from six.moves.urllib.parse import urlparse
from six.moves.urllib.request import urlopen
from six.moves.urllib.request import urlretrieve

# Used with the min_download_api settings option to determine
# if the downloaded configuration file is compatible with
# this version of NiftyNet downloader code
from niftynet.utilities.niftynet_global_config import NiftyNetGlobalConfig
from niftynet.utilities.util_common import print_progress_bar
from niftynet.utilities.versioning import get_niftynet_version, \
    get_niftynet_version_string

# Version of the download-configuration format understood by this module.
# A remote configuration file may demand a newer one through its
# ``min_download_api`` option (see ConfigStore._check_minimum_versions).
DOWNLOAD_API_VERSION = '1.0'

# File extension of the per-example download configuration files.
CONFIG_FILE_EXT = '.ini'


def download(example_ids, download_if_already_existing=False, verbose=True):
    """
    Downloads standard NiftyNet examples such as data, samples

    :param example_ids: a single identifier, or a list/tuple of identifiers,
        of the examples to download
    :param download_if_already_existing: if True, data are downloaded again
        even when a local copy already exists
    :param verbose: if True, print the URL of the download server
    :return: True if every requested example was found and processed
        without error; False otherwise (also False when no id is given)
    """
    global_config = NiftyNetGlobalConfig()
    config_store = ConfigStore(global_config)

    # Accept a lone identifier as well as a sequence of them
    if not isinstance(example_ids, (tuple, list)):
        example_ids = [example_ids]
    if not example_ids:
        return False

    # Probe the server with a file that is expected to be present
    readme_url = gitlab_raw_file_url(
        global_config.get_download_server_url(), 'README.md')
    server_ok = url_exists(readme_url)
    if verbose:
        print("Accessing: {}".format(global_config.get_download_server_url()))

    failed = False
    for example_id in example_ids:
        if not example_id:
            failed = True
            continue
        if not config_store.exists(example_id):
            failed = True
            # Only blame the id when the server itself is reachable
            if server_ok:
                print(example_id + ': FAIL. ')
                print('No NiftyNet example was found for ' + example_id + ".")
            continue
        update_ok = config_store.update_if_required(
            example_id, download_if_already_existing)
        if not update_ok:
            failed = True

    # If errors occurred and the server is down report a message
    if failed and not server_ok:
        print("The NiftyNetExamples server is not running")
    return not failed
def download_file(url, download_path):
    """
    Download a file from a resource URL to the given location

    The file keeps the base name of the URL path and is first downloaded
    into a private temporary directory, then moved into ``download_path``.
    The temporary directory is removed whether or not the download succeeds.

    :param url: URL of the file to download
    :param download_path: location where the file should be saved (created
        if it does not exist)
    """
    # Extract the filename from the URL
    parsed = urlparse(url)
    filename = basename(parsed.path)

    # Ensure the output directory exists
    if not os.path.exists(download_path):
        os.makedirs(download_path)

    # Use a private temporary directory rather than a fixed path in the
    # shared temp folder: concurrent downloads of the same filename can no
    # longer clobber each other, and partial downloads are always cleaned up.
    temp_dir = tempfile.mkdtemp()
    try:
        downloaded_file = os.path.join(temp_dir, filename)
        urlretrieve(url, downloaded_file, reporthook=progress_bar_wrapper)

        # Move the file to the destination folder
        destination_path = os.path.join(download_path, filename)
        move(downloaded_file, destination_path)
    finally:
        rmtree(temp_dir, ignore_errors=True)
def download_and_decompress(url, download_path):
    """
    Download an archive from a resource URL and decompresses/unarchives to
    the given location

    The archive is downloaded into a private temporary directory, which is
    removed afterwards whether or not the download/extraction succeeds.

    :param url: URL of the compressed file to download
    :param download_path: location where the file should be extracted
        (created if it does not exist)
    :raises ValueError: if an archive member, or the target of a link in
        the archive, would resolve outside ``download_path``
    """
    # Extract the filename from the URL
    parsed = urlparse(url)
    filename = basename(parsed.path)

    # Ensure the output directory exists
    if not os.path.exists(download_path):
        os.makedirs(download_path)

    # A private temporary directory avoids collisions with concurrent
    # downloads of the same archive name and is always cleaned up.
    temp_dir = tempfile.mkdtemp()
    try:
        downloaded_file = os.path.join(temp_dir, filename)
        urlretrieve(url, downloaded_file, reporthook=progress_bar_wrapper)

        # Decompress and extract all files to the specified local path; the
        # context manager closes the archive even if extraction fails.
        with tarfile.open(downloaded_file, "r") as tar:
            _safe_extract_all(tar, download_path)
    finally:
        rmtree(temp_dir, ignore_errors=True)


def _is_within_directory(directory, target):
    """Return True if ``target`` resolves to ``directory`` or a path in it."""
    abs_directory = os.path.realpath(directory)
    abs_target = os.path.realpath(target)
    # Joining with '' appends exactly one trailing separator, so a sibling
    # such as '/data2' is not mistaken for a child of '/data'.
    return abs_target == abs_directory or \
        abs_target.startswith(os.path.join(abs_directory, ''))


def _safe_extract_all(tar, path):
    """Extract all members of ``tar`` into ``path``, rejecting traversal."""
    # The archive comes from a remote server: refuse absolute names, '..'
    # components and links pointing outside the destination before
    # extracting anything.
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not _is_within_directory(path, member_path):
            raise ValueError(
                "Refusing to extract '{}': it lies outside '{}'".format(
                    member.name, path))
        if member.issym():
            # Symbolic-link targets are relative to the link's own folder
            link_target = os.path.join(os.path.dirname(member_path),
                                       member.linkname)
        elif member.islnk():
            # Hard-link targets are relative to the archive root
            link_target = os.path.join(path, member.linkname)
        else:
            continue
        if not _is_within_directory(path, link_target):
            raise ValueError(
                "Refusing to extract link '{}' -> '{}': target lies outside "
                "'{}'".format(member.name, member.linkname, path))
    tar.extractall(path)
class ConfigStore:
    """
    Manages a configuration file store based on a remote repository with
    local caching
    """

    def __init__(self, global_config):
        self._download_folder = global_config.get_niftynet_home_folder()
        self._config_folder = global_config.get_niftynet_config_folder()
        # Copy of the configuration that was last used to download data
        self._local = ConfigStoreCache(
            os.path.join(self._config_folder, '.downloads_local_config_cache'))
        # Latest configuration fetched from the download server
        self._remote = RemoteProxy(self._config_folder,
                                   global_config.get_download_server_url())

    def exists(self, example_id):
        """
        Returns True if a record exists for this example_id, either locally
        or remotely
        """
        return self._local.exists(example_id) or self._remote.exists(example_id)

    def update_if_required(self, example_id,
                           download_if_already_existing=False):
        """
        Downloads data using the configuration file if it is not already up
        to date.

        The remote configuration is refreshed first; a failure there only
        prints a warning and the previously cached remote copy is used.
        Data are (re)downloaded when forced, when there is no local
        configuration, when any expected data folder is missing or empty,
        or when the remote version is newer than the local one.

        :param example_id: identifier of the example
        :param download_if_already_existing: if True, always download
        :return: False if the remote configuration has no download entries,
            True otherwise (whether or not a download took place)
        :raises ValueError: if the example requires a newer NiftyNet or
            download API, or an 'expand' section has no URL
        """
        try:
            self._remote.update(example_id)
            remote_update_failed = False
        except Exception as e:
            # Best effort: fall back to whatever remote copy is cached
            print("Warning: updating the examples file "
                  "from the server caused an error: {}".format(e))
            remote_update_failed = True

        current_config, current_entries = \
            self._local.get_download_params(example_id)
        remote_config, remote_entries = \
            self._remote.get_download_params(example_id)

        if not remote_entries:
            if remote_update_failed:
                print(example_id + ": FAIL.")
                print("Cannot download the examples configuration file. "
                      "Is the server down?")
            else:
                print(example_id + ": FAIL. Nothing to download")
            return False

        # Always download if the local file is empty, or force by arguments
        force_download = download_if_already_existing or \
            (not current_config and not current_entries)
        data_missing = self._are_data_missing(remote_entries, example_id)
        if force_download or data_missing or self._is_update_required(
                current_config, remote_config):
            self._check_minimum_versions(remote_config)
            self._download(remote_entries, example_id)
            self._replace_local_with_remote_config(example_id)
        else:
            print(example_id + ": OK. ")
            print("Already downloaded. "
                  "Use the -r option to download again.")
        return True

    @staticmethod
    def _check_minimum_versions(remote_config):
        """
        Raises ValueError if the remote configuration asks for a newer
        download API (``min_download_api``) or a newer NiftyNet
        (``min_niftynet``) than the running one.
        """
        # Checks whether a minimum download API is specified
        if 'min_download_api' in remote_config:
            min_download_api = remote_config['min_download_api']
            current_download_api_version = DOWNLOAD_API_VERSION
            if LooseVersion(min_download_api) > LooseVersion(
                    current_download_api_version):
                raise ValueError(
                    "This example requires a newer version of NiftyNet.")

        # Checks whether a minimum NiftyNet version is specified
        if 'min_niftynet' in remote_config:
            min_niftynet = remote_config['min_niftynet']
            current_version = get_niftynet_version()
            if LooseVersion(min_niftynet) > LooseVersion(current_version):
                # Format the version into the message (it used to be passed
                # as a second ValueError argument, garbling the message)
                raise ValueError("This example requires NiftyNet "
                                 "version %s or later." % min_niftynet)

    @staticmethod
    def _is_update_required(current_config, remote_config):
        """
        Returns True if the remote configuration is newer than the local one.

        Without remote version information there is nothing to compare
        against, so no update is required. Without local version information
        an update is required whenever the remote specifies a version. The
        case of no local information at all is handled by the caller.
        """
        if 'version' not in remote_config:
            # Previously raised KeyError when only the local copy had a
            # version
            return False
        if 'version' not in current_config:
            return True
        return LooseVersion(current_config['version']) < \
            LooseVersion(remote_config['version'])

    def _download(self, remote_config_sections, example_id):
        """
        Performs the action of every configuration section that has one.

        Only the 'expand' action (download and extract an archive from the
        section's 'url') is supported; any other action prints a failure
        message and is skipped. Sections without an action are ignored.
        """
        for section_name, config_params in remote_config_sections.items():
            if 'action' in config_params:
                action = config_params.get('action').lower()
                if action == 'expand':
                    if 'url' not in config_params:
                        raise ValueError('No URL was found in the download '
                                         'configuration file')
                    local_download_path = self._get_local_download_path(
                        config_params, example_id)
                    download_and_decompress(url=config_params['url'],
                                            download_path=local_download_path)
                    print('{} -- {}: OK.'.format(example_id, section_name))
                    print("Downloaded data to " + local_download_path)
                else:
                    print(example_id + ": FAIL.")
                    print("I do not know the action " + action +
                          ". Perhaps you need to update NiftyNet?")

    def _get_local_download_path(self, remote_config, example_id):
        """
        Returns <download folder>/<destination>/<local_id>, where
        'destination' defaults to 'examples' and 'local_id' to example_id.
        """
        destination = remote_config.get('destination', 'examples')
        local_id = remote_config.get('local_id', example_id)
        return os.path.join(self._download_folder, destination, local_id)

    def _replace_local_with_remote_config(self, example_id):
        """Copies the cached remote configuration over the local one."""
        local_filename = self._local.get_local_path(example_id)
        remote_filename = self._remote.get_local_path(example_id)
        copyfile(remote_filename, local_filename)

    def _are_data_missing(self, remote_config_sections, example_id):
        """
        Returns True if the destination folder of any 'expand' section does
        not exist or contains only hidden (dot-prefixed) entries.
        """
        for section_name, config_params in remote_config_sections.items():
            if 'action' in config_params:
                action = config_params.get('action').lower()
                if action == 'expand':
                    local_download_path = self._get_local_download_path(
                        config_params, example_id)
                    if not os.path.isdir(local_download_path):
                        return True
                    non_system_files = [f for f in
                                        os.listdir(local_download_path)
                                        if not f.startswith('.')]
                    if not non_system_files:
                        return True
        return False
class ConfigStoreCache:
    """
    A local cache for configuration files
    """

    def __init__(self, cache_folder):
        self._cache_folder = cache_folder
        if not os.path.exists(self._cache_folder):
            try:
                os.makedirs(self._cache_folder)
            except OSError:
                # Another process may have created the folder between the
                # check and makedirs; only fail if it still does not exist.
                if not os.path.exists(self._cache_folder):
                    raise

    def exists(self, example_id):
        """
        Returns True if a configuration file for this example_id is present
        in this local cache folder
        """
        return os.path.isfile(self.get_local_path(example_id))

    def get_local_path(self, example_id):
        """
        Returns the full path to the locally cached configuration file
        """
        return os.path.join(self._cache_folder, example_id + CONFIG_FILE_EXT)

    def get_local_cache_folder(self):
        """
        Returns the folder in which the cached files are stored
        """
        return self._cache_folder

    def get_download_params(self, example_id):
        """
        Parses the cached configuration file for this example_id.

        :return: tuple ``(config_section, other_sections)``: the ``[config]``
            section as a dict (empty if absent), and a dict mapping every
            other section name to its options as a dict. A missing file
            yields ``({}, {})`` because ConfigParser.read skips files it
            cannot open.
        """
        config_filename = self.get_local_path(example_id)
        parser = configparser.ConfigParser()
        parser.read(config_filename)
        if parser.has_section('config'):
            config_section = dict(parser.items('config'))
        else:
            config_section = {}
        other_sections = {}
        for section in parser.sections():
            if section != 'config' and section != 'DEFAULT':
                other_sections[section] = dict(parser.items(section))
        return config_section, other_sections
class RemoteProxy:
    """
    Front-end for the remote configuration store: files fetched from the
    server are kept in a local cache, which serves all later reads.
    """

    def __init__(self, parent_store_folder, base_url):
        cache_folder = os.path.join(parent_store_folder,
                                    '.downloads_remote_config_cache')
        self._cache = ConfigStoreCache(cache_folder)
        self._remote = RemoteConfigStore(base_url)

    def exists(self, example_id):
        """
        Returns True if a record for example_id is already cached, otherwise
        whether the server has one
        """
        found_locally = self._cache.exists(example_id)
        if found_locally:
            return found_locally
        return self._remote.exists(example_id)

    def update(self, example_id):
        """
        Fetches the current record for example_id from the server into the
        local cache folder
        """
        remote_url = self._remote.get_url(example_id)
        target_folder = self._cache.get_local_cache_folder()
        download_file(remote_url, target_folder)

    def get_download_params(self, example_id):
        """
        Returns the parsed parameters of the cached copy of the record
        """
        params = self._cache.get_download_params(example_id)
        return params

    def get_local_path(self, example_id):
        """
        Returns the path of the cached copy of the record
        """
        cached_path = self._cache.get_local_path(example_id)
        return cached_path
class RemoteConfigStore:
    """
    A configuration file store hosted on a remote GitLab server
    """

    def __init__(self, base_url):
        self._base_url = base_url

    def exists(self, example_id):
        """
        Returns True if the server answers for this example's record
        """
        record_url = self.get_url(example_id)
        return url_exists(record_url)

    def get_url(self, example_id):
        """
        Builds the raw-file URL of the record for this example_id
        """
        record_name = example_id + CONFIG_FILE_EXT
        return gitlab_raw_file_url(self._base_url, record_name)
def gitlab_raw_file_url(base_url, file_name, branch='master'):
    """
    Returns the url for the raw file on a GitLab server

    :param base_url: URL of the GitLab repository
    :param file_name: path of the file within the repository
    :param branch: branch (or other ref) to read the file from; defaults to
        'master'. Replaces the previously hard-coded, commented-out
        alternative branch.
    :return: ``<base_url>/raw/<branch>/<file_name>``
    """
    return base_url + '/raw/' + branch + '/' + file_name
def url_exists(url):
    """
    Returns True if the specified url can be opened and reports a status
    code below 400

    Any error while opening the URL or reading its status is treated as
    "does not exist" (returns False). urlopen follows HTTP redirects, so the
    status checked is that of the final response. The connection is always
    closed.
    """
    try:
        connection = urlopen(url)
    except Exception:
        return False
    try:
        return connection.getcode() < 400
    except Exception:
        # e.g. non-HTTP responses whose getcode() is None
        return False
    finally:
        # Previously the response was never closed, leaking the socket
        connection.close()
def progress_bar_wrapper(count, block_size, total_size):
    """
    Adapts the common progress bar to the urlretrieve reporthook signature

    :param count: number of blocks transferred so far
    :param block_size: size of a block in bytes
    :param total_size: total size in bytes as reported by urlretrieve
    """
    # Skip the bar for downloads spanning at most five blocks; this also
    # covers a non-positive (unknown) total size.
    if total_size <= block_size * 5:
        return
    total_blocks = math.ceil(float(total_size) / float(block_size))
    size_in_mb = total_size * 1.0 / 1e6
    print_progress_bar(
        iteration=count,
        total=total_blocks,
        prefix="Downloading (total: %.2f M): " % size_in_mb)
def main():
    """
    Command-line entry point: downloads the example(s) named on the command
    line.

    :return: 0 if every example was downloaded successfully, -1 otherwise
    """
    parser = argparse.ArgumentParser(
        description="Download NiftyNet sample data")
    parser.add_argument("-r", "--retry",
                        help="Force data to be downloaded again",
                        required=False,
                        action='store_true')
    parser.add_argument(
        'sample_id',
        nargs='+',
        help="Identifier string(s) for the example(s) to download")
    parser.add_argument("-v", "--version",
                        action='version',
                        version=get_niftynet_version_string())
    args = parser.parse_args()
    succeeded = download(args.sample_id, args.retry)
    return 0 if succeeded else -1
if __name__ == "__main__":
    # Propagate main()'s status so that a failed download yields a non-zero
    # process exit code (its return value used to be discarded).
    import sys
    sys.exit(main())