# Source code for target_selection.cartons.tools

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2021-04-29
# @Filename: tools.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

import os
import tempfile

import numpy
import peewee
from astropy.io import fits
from astropy.table import Table

from sdssdb.utils.ingest import create_model_from_table

from target_selection.cartons import BaseCarton
from target_selection.exceptions import TargetSelectionError
from target_selection.utils import vacuum_table


def get_file_carton(filename):
    """Returns a carton class that creates a carton based on a FITS file.

    The FITS file is located in the ``open_fiber_path`` which is specified in
    ``python/config/target_selection.yml``. The list of FITS files to be
    loaded is specified in the file ``open_fiber_file_list.txt`` which is in
    the directory ``open_fiber_path``.

    Parameters
    ----------
    filename
        Path to the FITS file. The base name of the file (without extension,
        lowercased) must match the single ``cartonname`` value in the file.

    Returns
    -------
    FileCarton
        A ``BaseCarton`` subclass (the class itself, not an instance) bound to
        ``filename``. The file is only read and validated when the class is
        instantiated; validation failures raise ``TargetSelectionError``.

    """

    # Import this here to prevent this module not being importable if the database
    # connection is not ready.
    from sdssdb.peewee.sdss5db.catalogdb import (
        Catalog,
        CatalogToGaia_DR2,
        CatalogToGaia_DR3,
        CatalogToLegacy_Survey_DR8,
        CatalogToLegacy_Survey_DR10,
        CatalogToPanstarrs1,
        CatalogToTwoMassPSC,
        Gaia_DR2,
        Gaia_DR3,
        Legacy_Survey_DR8,
        Legacy_Survey_DR10,
        Panstarrs1,
        TwoMassPSC,
    )

    class FileCarton(BaseCarton):
        """Carton whose targets are read from the FITS file ``filename``."""

        can_offset = None  # Will be set in query.

        def __init__(self, targeting_plan, config_file=None, schema=None, table_name=None):
            self._file_path = filename

            # The HDU list is only needed to inspect the column definitions.
            # Using a context manager releases the handle even if we raise.
            with fits.open(self._file_path, memmap=True) as hdu_list:
                col_list = str(hdu_list[1].columns)

            if "null =" in col_list:
                raise TargetSelectionError(
                    "Error in get_file_carton(): "
                    + filename
                    + " has null specified in the fits file columns."
                )

            self._table = Table.read(self._file_path)
            self._run_sanity_checks()

            super().__init__(
                targeting_plan,
                config_file=config_file,
                schema=schema,
                table_name=table_name,
            )

            self._disable_query_log = True

        def _run_sanity_checks(self):
            """Runs a series of sanity checks on the FITS table.

            Sets ``name``, ``category``, ``program`` and ``mapper`` from the
            corresponding (single-valued) columns of the table.

            Raises
            ------
            TargetSelectionError
                If the table is empty, has masked values, has more than one
                value in a column that must be single-valued, contains a value
                outside the allowed set, or its ``cartonname`` does not match
                the file name.

            """

            error_prefix = "Error in get_file_carton(): " + filename

            # Without this check an empty file would be reported below as
            # containing "more than one cartonname", which is misleading.
            if len(self._table) == 0:
                raise TargetSelectionError(
                    f"Error in get_file_carton(): {self._file_path} is an empty table"
                )

            if self._table.has_masked_values:
                raise TargetSelectionError(
                    error_prefix
                    + " has null specified in the fits file columns."
                    + " Hence the table has masked values."
                )

            unique_cartonname = numpy.unique(self._table["cartonname"])
            if len(unique_cartonname) != 1:
                raise TargetSelectionError(error_prefix + " contains more than one cartonname")
            self.name = unique_cartonname[0].lower()

            unique_can_offset = numpy.unique(self._table["can_offset"])
            if len(unique_can_offset) > 1:
                raise TargetSelectionError(
                    error_prefix
                    + " contains more than one"
                    + " value of can_offset:"
                    + " can_offset values must be "
                    + " all 0 or all 1"
                )
            if (unique_can_offset[0] != 1) and (unique_can_offset[0] != 0):
                raise TargetSelectionError(
                    error_prefix
                    + " can_offset can only be 0 or 1."
                    + " can_offset is "
                    + str(unique_can_offset[0])
                )

            unique_inertial = numpy.unique(self._table["inertial"])
            if len(unique_inertial) > 2:
                raise TargetSelectionError(
                    error_prefix
                    + " contains more than two"
                    + " values of inertial:"
                    + " inertial values must be "
                    + " 0 or 1"
                )
            # At this point there are one or two distinct values; each of them
            # must be 0 or 1.
            for inertial_value in unique_inertial:
                if (inertial_value != 1) and (inertial_value != 0):
                    raise TargetSelectionError(
                        error_prefix
                        + " inertial can only be 0 or 1."
                        + " inertial is "
                        + str(inertial_value)
                    )

            # The valid_program list is from the output of the below command.
            # select distinct(program) from targetdb.carton order by program;
            #
            # mwm_bin is for future mwm binary star cartons
            valid_program = [
                "bhm_aqmes",
                "bhm_csc",
                "bhm_filler",
                "bhm_rm",
                "bhm_spiders",
                "commissioning",
                "mwm_bin",
                "mwm_cb",
                "mwm_dust",
                "mwm_erosita",
                "mwm_filler",
                "mwm_galactic",
                "mwm_gg",
                "mwm_halo",
                "mwm_legacy",
                "mwm_magcloud",
                "mwm_ob",
                "mwm_planet",
                "mwm_rv",
                "mwm_snc",
                "mwm_tessob",
                "mwm_tessrgb",
                "mwm_validation",
                "mwm_wd",
                "mwm_yso",
                "open_fiber",
                "ops",
                "ops_sky",
                "ops_std",
                "SKY",
            ]

            # The valid_category list is from CartonImportTable.pdf
            # NOTE(review): "veto location boss" uses spaces while
            # "veto_location_apogee" uses underscores. This looks like a typo
            # but is kept as-is; confirm against the targetdb category table.
            valid_category = [
                "science",
                "standard_apogee",
                "standard_boss",
                "guide",
                "sky_boss",
                "sky_apogee",
                "standard",
                "sky",
                "veto location boss",
                "veto_location_apogee",
            ]

            # The valid_mapper list is from CartonImportTable.pdf
            valid_mapper = ["", "MWM", "BHM"]

            unique_category = numpy.unique(self._table["category"])
            if len(unique_category) != 1:
                raise TargetSelectionError(error_prefix + " contains more than one category")
            self.category = unique_category[0].lower()
            if self.category not in valid_category:
                raise TargetSelectionError(
                    error_prefix + " contains invalid category = " + self.category
                )

            # The program is lowercased before validation, like the category.
            unique_program = numpy.unique(self._table["program"])
            if len(unique_program) != 1:
                raise TargetSelectionError(error_prefix + " contains more than one program")
            self.program = unique_program[0].lower()
            if self.program not in valid_program:
                raise TargetSelectionError(
                    error_prefix + " contains invalid program = " + self.program
                )

            unique_mapper = numpy.unique(self._table["mapper"])
            if len(unique_mapper) != 1:
                raise TargetSelectionError(error_prefix + " contains more than one mapper")
            # We do not use lower() for mapper because
            # allowed values for mapper are '' or 'MWM' or 'BHM'.
            self.mapper = unique_mapper[0]
            if self.mapper not in valid_mapper:
                raise TargetSelectionError(
                    error_prefix + " contains invalid mapper = " + self.mapper
                )
            if self.mapper == "":
                self.mapper = None

            # The carton name stored in the file must agree with the file name.
            basename = os.path.splitext(os.path.basename(filename))[0]
            carton_name_from_filename = basename.lower()
            if self.name != carton_name_from_filename:
                raise TargetSelectionError(
                    "filename parameter of get_file_carton() and "
                    + "cartonname in FITS file do not match."
                    + "\n"
                    + "carton_name_from_filename = "
                    + carton_name_from_filename
                    + " cartonname = "
                    + self.name
                )

        def copy_data(self, temp_table: str):
            """Copy the input file data to a temporary table.

            The schema of the file carton table is such that we can dump
            to a CSV file without issues.

            Parameters
            ----------
            temp_table
                Name of the (already created) table to copy the rows into.

            """

            # delete=False because the table is written through the file *name*
            # while this handle stays open; we remove the file ourselves below.
            temp_csv = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)

            try:
                self._table.write(temp_csv.name, format="csv", overwrite=True)

                cursor = self.database.cursor()
                temp_csv.seek(0)
                cursor.copy_expert(f"COPY {temp_table} FROM STDIN WITH CSV HEADER", temp_csv)
                self.database.commit()
            finally:
                temp_csv.close()
                try:
                    os.unlink(temp_csv.name)
                except OSError:
                    # Best-effort clean-up; a leftover temporary file must not
                    # mask an error raised while copying the data.
                    pass

        def build_query(self, version_id, query_region=None):
            """Builds the query that matches the file targets to ``Catalog``.

            Parameters
            ----------
            version_id
                Value that ``CatalogToX.version_id`` must match.
            query_region
                Not used by this carton.

            Returns
            -------
            query
                A Peewee query selecting one row per ``catalogid``.

            Raises
            ------
            TargetSelectionError
                If the table is empty or the number of valid identifiers does
                not add up to the number of rows.

            """

            self.log.debug(f"Processing file {self._file_path}.")

            # Calculate number of rows in the table for each parent catalogue identifier
            # and run some sanity checks. This only depends on the file, so do it before
            # touching the database.
            integer_id_columns = (
                "Gaia_DR3_Source_ID",
                "Gaia_DR2_Source_ID",
                "LegacySurvey_DR10_ID",
                "LegacySurvey_DR8_ID",
                "PanSTARRS_DR2_ID",
            )

            len_table = len(self._table)
            n_targets = {
                column: int(numpy.count_nonzero(self._table[column] > 0))
                for column in integer_id_columns
            }

            # TwoMass_ID corresponds to the designation column of
            # the table catalogdb.twomass_psc.
            # Since the designation column is a text column, below
            # we are comparing it to the string 'NA' and not the integer 0.
            n_targets["TwoMASS_ID"] = int(numpy.count_nonzero(self._table["TwoMASS_ID"] != "NA"))

            # Make sure this is not an empty table.
            if len_table == 0:
                raise TargetSelectionError(
                    f"Error in get_file_carton(): {self._file_path} is an empty table"
                )

            # There must be exactly one non-zero id per row else raise an exception.
            # NOTE(review): this only compares totals, so a row with two ids and
            # a row with none would cancel out and pass the check.
            if sum(n_targets.values()) != len_table:
                raise TargetSelectionError(
                    "Error in get_file_carton(): "
                    + "(len_gaia_dr3 + len_gaia_dr2 + "
                    + "len_legacysurvey_dr10 + len_legacysurvey_dr8 +"
                    + "len_panstarrs_dr2 + len_twomass_psc) != "
                    + "len_table"
                )

            # We need to copy the data to a temporary table so that we can
            # join on it. We could use a Peewee ValueList but for large tables
            # that will hit the limit of 1GB in PSQL.

            # Create model for temporary table from FITS table columns.
            # This works fine because we know there are no arrays.
            temp_table = f"inputs_{self.name.lower()}_temp"
            self.database.execute_sql(f"DROP TABLE IF EXISTS {temp_table};")
            temp = create_model_from_table(temp_table, self._table)
            temp._meta.database = self.database
            temp.create_table(temporary=True)

            # Copy the data
            self.copy_data(temp_table)

            # Identifier columns and the SQL literal the file uses for "no
            # identifier". Those placeholders become NULL so that the unique
            # indices created afterwards do not see them as duplicates.
            id_placeholders = (
                ("Gaia_DR3_Source_ID", "0"),
                ("Gaia_DR2_Source_ID", "0"),
                ("LegacySurvey_DR10_ID", "0"),
                ("LegacySurvey_DR8_ID", "0"),
                ("PanSTARRS_DR2_ID", "0"),
                ("TwoMASS_ID", "'NA'"),
            )

            for column, placeholder in id_placeholders:
                self.database.execute_sql(
                    f'UPDATE "{temp_table}" SET "{column}" = NULL WHERE "{column}" = {placeholder}'
                )

            for column, _ in id_placeholders:
                self.database.execute_sql(f'CREATE UNIQUE INDEX ON "{temp_table}" ("{column}")')

            vacuum_table(self.database, temp_table, vacuum=False, analyze=True)

            # A NULL inertial value is treated as False.
            inertial_case = peewee.Case(
                None,
                ((temp.inertial.cast("boolean").is_null(), False),),
                temp.inertial.cast("boolean"),
            )

            # List of columns and aliases for the final query table.
            query_columns = [
                Catalog.catalogid,
                temp.Gaia_DR3_Source_ID.alias("gaia_dr3_source_id"),
                temp.Gaia_DR2_Source_ID.alias("gaia_source_id"),
                temp.LegacySurvey_DR10_ID.alias("ls_id10"),
                temp.LegacySurvey_DR8_ID.alias("ls_id8"),
                temp.PanSTARRS_DR2_ID.alias("catid_objid"),
                temp.TwoMASS_ID.alias("designation"),
                Catalog.ra,
                Catalog.dec,
                temp.delta_ra.cast("double precision"),
                temp.delta_dec.cast("double precision"),
                inertial_case.alias("inertial"),
                temp.cadence,
                temp.priority,
                temp.instrument,
                temp.can_offset.cast("boolean").alias("can_offset"),
                peewee.Value(0).alias("value"),
            ]

            # For each identifier, the information that we need to create the subquery
            # for that table: the column in the temporary table, the CatalogToX model
            # to which we need to join, and the field on which to join in the parent
            # catalogue. In all cases except 2MASS, we join on the parent catalogue
            # primary key.
            join_sources = (
                ("Gaia_DR3_Source_ID", CatalogToGaia_DR3, Gaia_DR3.source_id),
                ("Gaia_DR2_Source_ID", CatalogToGaia_DR2, Gaia_DR2.source_id),
                ("LegacySurvey_DR10_ID", CatalogToLegacy_Survey_DR10, Legacy_Survey_DR10.ls_id),
                ("LegacySurvey_DR8_ID", CatalogToLegacy_Survey_DR8, Legacy_Survey_DR8.ls_id),
                ("PanSTARRS_DR2_ID", CatalogToPanstarrs1, Panstarrs1.catid_objid),
                ("TwoMASS_ID", CatalogToTwoMassPSC, TwoMassPSC.designation),
            )

            # Keep only the identifiers that have at least one target in the file.
            model_data = [source for source in join_sources if n_targets[source[0]] > 0]

            if len(model_data) == 0:
                raise TargetSelectionError(
                    "Error in get_file_carton(): no join model found for the file carton."
                )

            # Create a query for each join model. The final query will be the union of all.
            # For each join model, we need to account for cases in which an identifier is
            # associated with more than one catalogid via a window function (we select
            # either the phase 1 match, or the one with the smallest distance).
            queries = []

            for temp_column, catalog_to_model, parent_field in model_data:
                # Get the model class field for the column in the temporary table with
                # the identifier for this case.
                temp_field = getattr(temp, temp_column)

                # Get the parent model. The only reason why we need to join all the way
                # to the parent catalogue model is 2MASS for which the column TwoMASS_ID
                # in the file carton corresponds to the designation column in the
                # TwoMassPSC model, which is not the primary key.
                parent_model = parent_field.model

                # Create a subquery that ranks the rows by distance to the target. Since
                # temp is also used in the main query, we need to alias it.
                # Note that we are using ROW_NUMBER() and not RANK() because the latter
                # would assign the same rank to multiple rows with the same distance.
                # This can happen if two catalogids are associated with a target in
                # phase 1 (an example of this is a Gaia DR2 target that has been
                # deblended into two Gaia DR3 targets).
                # We also order by catalogid to ensure that the query is deterministic
                # but ultimately we are randomly selecting one of the targets here.
                temp_alias = temp.alias("temp_alias")
                temp_alias_field = getattr(temp_alias, temp_column)

                distance_rank_partition = peewee.fn.row_number().over(
                    partition_by=[catalog_to_model.target_id],
                    order_by=[
                        peewee.fn.coalesce(catalog_to_model.distance, 0.0).asc(),
                        catalog_to_model.catalogid.asc(),
                    ],
                )

                sub_query = (
                    temp_alias.select(
                        catalog_to_model.catalogid,
                        parent_field.alias("target_id"),
                        distance_rank_partition.alias("distance_rank"),
                    )
                    .join(parent_model, on=(temp_alias_field == parent_field))
                    .join(catalog_to_model)
                    .where(
                        catalog_to_model.best >> True,
                        catalog_to_model.version_id == version_id,
                    )
                ).alias("distance_rank_subquery")

                # Now add the main query. We join the subquery to the temporary table to
                # get all the relevant columns, but keep only the entries with
                # distance_rank=1.
                queries.append(
                    Catalog.select(*query_columns, sub_query.c.distance_rank)
                    .join(
                        sub_query,
                        on=(Catalog.catalogid == sub_query.c.catalogid),
                    )
                    .join(
                        temp,
                        on=(temp_field == sub_query.c.target_id),
                    )
                    .where(sub_query.c.distance_rank == 1)
                    .distinct(Catalog.catalogid)
                )

            # Union all queries.
            query_union = queries[0]
            for query in queries[1:]:
                query_union = query_union | query

            # It seems to work better if we disable sequential scans and force the use
            # of the indices.
            self.database.execute_sql("SET LOCAL enable_seqscan = off;")

            # Now just distinct on catalogid for all the unions. Although we already have
            # a distinct in each query, they can yield the same catalogid from different
            # queries.
            query_union = query_union.cte("query_union")

            return (
                query_union.select(query_union.__star__)
                .distinct(query_union.c.catalogid)
                .with_cte(query_union)
            )

        def post_process(self, model, **kwargs):
            """Runs sanity checks on the output of the query.

            Logs a warning if the number of rows in ``model`` differs from the
            number of rows in the input file.

            """

            n_file_table = len(self._table)
            n_query = model.select().count()

            if n_file_table != n_query:
                self.log.warning(
                    f"The number of rows in the file table ({n_file_table}) does not "
                    f"match the number of rows returned by the query ({n_query}). "
                    f"This may be due to duplicate or invalid external catalog IDs "
                    f"in the manual carton input fits file"
                )

    return FileCarton
def create_table_as(
    query,
    table_name,
    schema=None,
    temporary=False,
    database=None,
    execute=True,
    overwrite=False,
    indices=None,
    analyze=True,
):
    """Creates a table from a query.

    Parameters
    ----------
    query
        A Peewee ``ModelSelect`` or a string with the query to create a table
        from. A string is used verbatim as the body of the ``CREATE TABLE AS``
        statement.
    table_name
        The name of the table to create.
    schema
        The schema in which to create the table. If ``table_name`` is in the
        form ``schema.table``, the schema parameter is overridden by
        ``table_name``.
    temporary
        Whether to create a temporary table instead of a persistent one.
    database
        The database connection to use to execute the query. If not passed and
        the query is a ``ModelSelect``, the database will be inherited from
        the query model.
    execute
        Whether to actually execute the query. Requires ``database`` to be
        passed.
    overwrite
        If `True`, the table will be created even if a table with the same
        name already exists. Requires ``database`` or will be ignored.
    indices
        List of columns to create indices on. An element that is itself a list
        or tuple creates one multi-column index. Only relevant if
        ``execute=True``. Defaults to no indices.
    analyze
        Whether to ``VACUUM ANALYZE`` the new table. Only relevant if
        ``execute=True``.

    Returns
    -------
    create_query
        A tuple in which the first element is a Peewee ``Table`` for the
        created table (the table is bound to ``database`` if passed), and the
        ``CREATE TABLE AS`` query as a string.

    Raises
    ------
    RuntimeError
        If no database is available and one is needed, either to execute the
        query or to render a non-string query.

    """

    if "." in table_name:
        schema, table_name = table_name.split(".")

    # Temporary tables cannot be given a schema; persistent ones default to public.
    if schema is None and temporary is False:
        schema = "public"
    elif temporary is True:
        schema = None

    path = f"{schema}.{table_name}" if schema else table_name
    create_sql = f"CREATE {'TEMPORARY ' if temporary else ''}TABLE {path} AS "

    if database is None and isinstance(query, peewee.ModelSelect):
        database = query.model._meta.database

    # Fail early with a clear message instead of an AttributeError on ``None``.
    if database is None:
        if execute:
            raise RuntimeError("Cannot execute query without a database.")
        if not isinstance(query, str):
            raise RuntimeError("Cannot render the query without a database.")

    if isinstance(query, str):
        # Passing a plain string through the Peewee SQL context would bind it as
        # a single parameter instead of treating it as SQL, so use it as is.
        query_sql, params = query, None
        query_str = query
    else:
        query_sql, params = database.get_sql_context().sql(query).query()

        # Interpolate the parameters only to build the human-readable string.
        cursor = database.cursor()
        try:
            query_str = cursor.mogrify(query_sql, params).decode()
        finally:
            cursor.close()

    if overwrite and database:
        database.execute_sql(f"DROP TABLE IF EXISTS {path};")

    if execute:
        with database.atomic():
            database.execute_sql(create_sql + query_sql, params)

        for index in indices or ():
            if isinstance(index, (list, tuple)):
                index = ",".join(index)
            database.execute_sql(f"CREATE INDEX ON {path} ({index})")

        # VACUUM cannot run inside a transaction block, so it stays outside of
        # the atomic() context above.
        if analyze:
            database.execute_sql(f"VACUUM ANALYZE {path}")

    table = peewee.Table(table_name, schema=schema).bind(database)
    create_str = create_sql + query_str

    return table, create_str