Source code for sdss_brain.helpers.io

# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Filename: io.py
# Project: helpers
# Author: Brian Cherinka
# Created: Wednesday, 7th October 2020 10:54:12 am
# License: BSD 3-clause "New" or "Revised" License
# Copyright (c) 2020 Brian Cherinka
# Last Modified: Wednesday, 7th October 2020 10:54:12 am
# Modified By: Brian Cherinka


from __future__ import print_function, division, absolute_import
import httpx
import pathlib
import gzip
from io import BytesIO
from typing import Union
from astropy.io import fits
from sdss_brain import log
from sdss_brain.config import config
from sdss_brain.exceptions import BrainError

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


def get_mapped_version(name: str, release: str = None, key: str = None) -> Union[dict, str]:
    ''' Get a version id mapped to a release number

    For a given named category, looks up the "mapped_versions" attribute
    from the configuration yaml file and returns a version number that has
    been mapped to a specific release.  For example, for manga, DR16 maps to
    drpver='v2_4_3' and dapver='2.2.1'.  This can be useful when needing to
    specify certain versions when defining paths to files.

    Parameters
    ----------
    name : str
        The name of the set of versions to access
    release : str
        The SDSS release.  Default is config.release.
    key : str
        Optional name of dictionary key to access specific value

    Returns
    -------
    version : dict|str|None
        A version number corresponding to a given release, or None when
        the release is a "work" release.

    Raises
    ------
    AssertionError
        when no "mapped_versions" entry is defined in the custom config
    TypeError
        when "mapped_versions", or the entry for ``name``, is not a dictionary
    ValueError
        when ``name`` is not in "mapped_versions", or no version is mapped
        to the requested release

    Example
    -------
    >>> # access the MaNGA versions for release DR16
    >>> get_mapped_version('manga', release='DR16')
    {'drpver': 'v2_4_3', 'dapver': '2.2.1'}

    >>> # access specific key
    >>> get_mapped_version('manga', release='DR16', key='drpver')
    'v2_4_3'
    '''
    # resolve the default release first; calling .lower() on the raw
    # argument failed with AttributeError whenever release was left as None
    release = release or config.release

    # if release is a work release, return nothing
    if release.lower() == 'work':
        return None

    # get the mapped_versions attribute from the configuration
    mapped_versions = config._custom_config.get('mapped_versions', None)
    if not mapped_versions:
        # raised explicitly (instead of a bare assert) so the check survives
        # "python -O"; the exception type is kept for backward compatibility
        raise AssertionError('mapped_versions must be defined')
    if not isinstance(mapped_versions, dict):
        raise TypeError('mapped_versions must be a dictionary')

    # ensure that the name is a valid entry
    if name not in mapped_versions:
        raise ValueError(f'{name} not found in mapped_versions dictionary')

    versions = mapped_versions[name]
    if not isinstance(versions, dict):
        raise TypeError(f'release versions for {name} must be a dictionary')

    # ensure that the release is a valid entry
    version = versions.get(release, None)
    if not version:
        raise ValueError(f'no mapped_version found for release {release} in {name}. '
                         'Check the sdss_brain.yml config file.')

    # check for a specific key in the version dictionary
    if key and isinstance(version, dict):
        version = version.get(key, None)

    return version
def load_fits_file(filename: str) -> fits.HDUList:
    ''' Load a FITS file

    Opens and loads a FITS file with astropy.io.fits.

    Parameters
    ----------
    filename : str
        A FITS file to open

    Returns
    -------
    hdulist : `~astropy.io.fits.HDUList`
        an Astropy HDUList

    Raises
    ------
    FileNotFoundError
        when the path does not exist or is not a regular file
    AssertionError
        when the filename does not carry a ".fits" suffix
    BrainError
        when astropy fails to open the file
    '''
    path = pathlib.Path(filename)

    # is_file() is False for missing paths as well as for directories.  The
    # previous test, "not path.exists() and path.is_file()", could never be
    # true, so this guard was silently skipped.
    if not path.is_file():
        raise FileNotFoundError('input filename must exist and be a file')

    # path.suffixes also matches compressed names such as "x.fits.gz".
    # Raised explicitly (instead of a bare assert) so the check survives
    # "python -O"; the exception type is kept for backward compatibility.
    if '.fits' not in path.suffixes:
        raise AssertionError('filename is not a valid FITS file')

    try:
        hdulist = fits.open(path)
    except (IOError, OSError) as err:
        log.error(f'Cannot open FITS file {filename}: {err}')
        raise BrainError(f'Failed to open FITS files {filename}: {err}') from err
    else:
        return hdulist
def load_from_url(url: str, no_progress: bool = None) -> fits.HDUList:
    ''' Load a file from a remote url using a get request

    Streams url content with httpx.stream and pipes the response
    contents into an Astropy FITS file.

    Parameters
    ----------
    url : str
        A url path to a filename
    no_progress : bool
        If True, turns off the tqdm progress bar

    Returns
    -------
    hdulist : `~astropy.io.fits.HDUList`
        an Astropy HDUList

    Raises
    ------
    httpx.HTTPStatusError
        when the response has an error status (via raise_for_status)
    '''
    b = BytesIO()
    with httpx.stream("GET", url) as r:
        r.raise_for_status()

        if not tqdm or no_progress:
            # no progress bar requested, or tqdm is not installed
            for data in r.iter_bytes():
                b.write(data)
        else:
            # Content-Length is optional (e.g. chunked transfer encoding);
            # fall back to an unknown total rather than failing with KeyError
            content_length = r.headers.get("Content-Length")
            if content_length and content_length.isdigit():
                total = int(content_length)
            else:
                total = None

            with tqdm(total=total, unit_scale=True, unit_divisor=1024, unit="B") as progress:
                # advance the bar by the raw bytes received since the last chunk
                num_bytes_downloaded = r.num_bytes_downloaded
                for data in r.iter_bytes():
                    b.write(data)
                    progress.update(r.num_bytes_downloaded - num_bytes_downloaded)
                    num_bytes_downloaded = r.num_bytes_downloaded

    # rewind the buffer so astropy reads from the start
    b.seek(0)

    if url.endswith('.gz'):
        return fits.open(gzip.open(b, 'rb'))
    else:
        return fits.open(b)