Source code for target_selection.xmatch

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2020-04-06
# @Filename: xmatch.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

import copy
import hashlib
import inspect
import os
import re
import types
import uuid
import warnings

import networkx
import numpy
import peewee
import rich.markup
import yaml
from networkx.algorithms import shortest_path
from peewee import SQL, Case, Model, fn

from sdssdb.connection import PeeweeDatabaseConnection
from sdssdb.utils.internals import get_row_count
from sdsstools import merge_config
from sdsstools._vendor.color_print import color_text

import target_selection
from target_selection.exceptions import (
    TargetSelectionNotImplemented,
    TargetSelectionUserWarning,
    XMatchError,
)
from target_selection.utils import (
    Timer,
    get_configuration_values,
    get_epoch,
    is_view,
    sql_apply_pm,
    vacuum_outputs,
    vacuum_table,
)


#: Default reference epoch; used as the default value of the ``epoch``
#: argument of ``XMatchPlanner``.
EPOCH = 2016.0
#: Fallback cross-matching radius; ``XMatchPlanner`` uses it when no
#: ``query_radius`` is given (documented there as arcsec).
QUERY_RADIUS = 1.0

#: Number of positions the run id is left-shifted by ``XMatchPlanner``.
#: ``64 - 11`` leaves 11 bits of a 64-bit integer for the run id.
RUN_ID_BIT_SHIFT = 64 - 11

#: Default schema in which the temporary output table is created.
TEMP_SCHEMA = "sandbox"


class Version(peewee.Model):
    """Model for the version table.

    Each row pairs a cross-matching ``plan`` string with the ``tag`` of the
    code that was used to run it.
    """

    # Auto-incrementing primary key.
    id = peewee.AutoField()
    # Cross-matching plan string.
    plan = peewee.TextField()
    # Code version tag associated with the plan.
    tag = peewee.TextField()

    class Meta:
        table_name = "version"


class Catalog(peewee.Model):
    """Model for the output table.

    No ``Meta`` is declared here; the schema, table name, and database are
    assigned at run time by ``XMatchPlanner._prepare_models``.
    """

    # Identifier of the unique target; indexed and required.
    catalogid = peewee.BigIntegerField(index=True, null=False)
    iauname = peewee.TextField(null=True)
    # Coordinates are required; proper motions and parallax are optional.
    ra = peewee.DoubleField(null=False)
    dec = peewee.DoubleField(null=False)
    pmra = peewee.FloatField(null=True)
    pmdec = peewee.FloatField(null=True)
    parallax = peewee.FloatField(null=True)
    # Input catalogue from which the coordinates were obtained.
    lead = peewee.TextField(null=False)
    # Identifier of the row in the ``version`` table for this cross-match.
    version_id = peewee.IntegerField(null=False, index=True)


class TempCatalog(Catalog):
    """Temporary output table.

    Inherits the columns of ``Catalog`` but redeclares ``catalogid`` and
    ``version_id`` (the latter without an index) and has no primary key.
    """

    catalogid = peewee.BigIntegerField(index=True, null=False)
    # Unlike in ``Catalog``, ``version_id`` is not indexed here.
    version_id = peewee.IntegerField(index=False)

    class Meta:
        # The temporary table is created without a primary key.
        primary_key = False


def XMatchModel(
    Model,
    resolution=None,
    ra_column=None,
    dec_column=None,
    pmra_column=None,
    pmdec_column=None,
    is_pmra_cos=True,
    parallax_column=None,
    epoch_column=None,
    epoch=None,
    epoch_format="jyear",
    relational_table=None,
    has_duplicates=False,
    has_missing_coordinates=False,
    skip=False,
    skip_phases=None,
    query_radius=None,
    join_weight=1,
    allow_multiple_bests=False,
    database_options=None,
):
    """Attaches cross-matching parameters to a model's `peewee:Metadata`.

    A ``xmatch`` namespace is created in ``Model._meta`` and every parameter
    is stored in it under its own name, so that it can be retrieved as
    ``Model._meta.xmatch.<parameter>`` (e.g.,
    ``Model._meta.xmatch.has_duplicates``). The approximate number of rows
    of the table is also stored as ``Model._meta.xmatch.row_count``.

    Parameters
    ----------
    Model : :obj:`peewee:Model`
        The model class to annotate. It is modified in place.
    resolution : float
        The spatial resolution of the catalogue, in arcsec. Stored as NaN
        when not provided.
    ra_column : str
        The name of the right ascension column, assumed to be in degrees.
        If either ``ra_column`` or ``dec_column`` is missing, both are
        recovered from the table's Q3C index when one can be parsed.
    dec_column : str
        As ``ra_column``, for the declination column.
    pmra_column : str
        The RA proper motion column, assumed to be in milliarcseconds
        per year.
    pmdec_column : str
        As ``pmra_column`` for the declination proper motion.
    is_pmra_cos : bool
        Whether ``pmra_column`` provides the RA proper motion corrected
        from declination (``pmra * cos(dec)``) or not.
    parallax_column : str
        The column containing the parallax, assumed to be in arcsec.
    epoch_column : str
        The column containing the epoch of the target coordinates.
    epoch : float
        The epoch of the targets that applies to all the records in the
        table. ``epoch`` and ``epoch_column`` are mutually exclusive. If
        neither ``epoch_column`` nor ``epoch`` are defined, assumes that
        the epoch is ``2015.5``.
    epoch_format : str
        The format of the epoch. Either Julian year (``'jyear'``) or
        Julian date (``'jd'``).
    relational_table : str
        Stored verbatim as ``Model._meta.xmatch.relational_table``.
        Presumably names the relational table to use for this model; its
        consumers are outside this function.
    has_duplicates : bool
        Whether the table contains duplicates.
    has_missing_coordinates : bool
        Whether the catalogue contains rows in which the RA/Dec are null.
    skip : bool
        If `True`, the table will be used as a join node but will not be
        cross-matched. This is useful for testing and also if the table is
        in a previous version of the configuration file and you are using
        the ``base_version`` option but want to remove that table. It can
        also be used when setting a ``join_path`` for a model that should
        not otherwise be processed.
    skip_phases : list
        A list of cross-matching phases to be skipped for this model. Refer
        to the `.XMatchPlanner` documentation for definitions on what each
        phase does.
    query_radius : float
        The radius, in arcsec, to use in the radial query for
        cross-matching. If not provided defaults to the `.XMatchPlanner`
        value.
    join_weight : float
        The weight used by `.XMatchPlanner.get_join_paths` to determine
        the cost of using this table as a join. Lower weights translate to
        better chances of that join path to be selected.
    allow_multiple_bests : bool
        When set to False, the default, a catalogid can only be associated
        with one target in the model with ``best=True`` (the catalogid can
        be associated with multiple targets in phase 2, but only one will
        be marked as best). When set to True, a catalog id can be
        associated with multiple targets in the model with ``best=True``.
    database_options : dict
        A dictionary of database configuration parameters to be set locally
        for this model for each processing phase transaction, temporarily
        overriding the default database configuration. Keys must be the
        database parameter to modify. The value can be a simple string with
        the value to set, or a dictionary that more accurately defines when
        the parameter will be applied. See `.XMatchPlanner` for more
        information.

    Returns
    -------
    :obj:`peewee:Model`
        The same input model with the additional cross-matching parameters
        added to the metadata namespace ``xmatch``.

    """

    model_meta = Model._meta

    # The namespace is attached first and then filled in progressively.
    xmatch = types.SimpleNamespace()
    model_meta.xmatch = xmatch
    xmatch.resolution = resolution or numpy.nan

    if not (ra_column and dec_column):
        # Try to read the coordinate columns from the Q3C index definition.
        # If several indexes match, the last one wins.
        q3c_regex = r'.+q3c_ang2ipix\("*(\w+)"*, "*(\w+)"*\).+'
        for db_index in model_meta.database.get_indexes(
            model_meta.table_name,
            model_meta.schema,
        ):
            if "q3c" not in db_index.sql.lower():
                continue
            found = re.match(q3c_regex, db_index.sql)
            if found:
                ra_column, dec_column = found.groups()

    xmatch.__dict__.update(
        ra_column=ra_column,
        dec_column=dec_column,
        pmra_column=pmra_column,
        pmdec_column=pmdec_column,
        is_pmra_cos=is_pmra_cos,
        parallax_column=parallax_column,
    )

    # At most one of epoch / epoch_column may be provided.
    assert epoch is None or epoch_column is None, "epoch and epoch_column are mutually exclusive."

    xmatch.__dict__.update(
        epoch=epoch,
        epoch_column=epoch_column,
        epoch_format=epoch_format,
        relational_table=relational_table,
        has_duplicates=has_duplicates,
        has_missing_coordinates=has_missing_coordinates,
        skip=skip,
        skip_phases=skip_phases or [],
        query_radius=query_radius,
        allow_multiple_bests=allow_multiple_bests,
    )

    approximate_rows = get_row_count(
        model_meta.database,
        model_meta.table_name,
        schema=model_meta.schema,
        approximate=True,
    )
    xmatch.row_count = int(approximate_rows)

    xmatch.join_weight = join_weight
    xmatch.database_options = database_options or {}

    return Model
class XMatchPlanner(object):
    """Prepares and runs catalogue cross-matching.

    This class prepares the execution of a cross-matching between multiple
    catalogues, the result being an output table of unique targets linked to
    each input catalog via a relational table. Target coordinates are
    propagated to a common epoch when proper motions are available.
    Instantiating the class only prepares the process and sets up the
    processing order; cross-matching itself happens by calling the `.run`
    method.

    The output table contains a sequential integer identifier for each unique
    target (``catalogid``), along with the following columns: ``ra``,
    ``dec``, ``pmra``, ``pmdec``, ``parallax``, ``lead``, and ``version_id``.
    ``version_id`` relates the record in the ``version`` table which contains
    the cross-matching plan and the tag of the code used when it was run.
    ``lead`` indicates from which one of the input catalogues the coordinates
    were obtained.

    The output table is related to each input catalogue via a many-to-many
    table with the format ``<output_table>_to_<catalogue_table>`` where
    ``<output_table>`` is the table name of the output table and
    ``<catalogue_table>`` is the table name of the referred catalogue. Each
    relational table contains the unique identifier column, ``catalogid``,
    ``target_id`` pointing to the primary key of the referred catalogue,
    ``version_id``, and a boolean ``best`` indicating whether it is the best
    (closest) possible match when a target in the output table corresponds to
    multiple targets in the input catalogue.

    The cross-matching process roughly follows the following process:

    |
    .. image:: _static/Catalogdb_Crossmatch.png
        :scale: 75 %
        :align: center
    |

    In practice the cross-matching process begins by creating a graph of all
    nodes (tables to be processed and additional tables, ``extra_nodes``, in
    the database) and edges (foreign keys relating two tables). This graph is
    used to determine join conditions and to establish the order in which the
    input models will be processed.

    The processing order is determined by the ``order`` and ``key`` input
    parameters. When ``key='row_count'``, tables are sorted by decreasing
    number of rows so that tables with more targets are processed first (note
    that to speed things up the row count is always the latest known
    approximate determined by ``ANALYZE``); if ``key='resolution'`` the
    associated spatial resolution for a catalogue is used to process
    catalogues with high resolution first. If ``order='hierarchical'``, all
    the tables are divided into as many disconnected subgraphs as exist; then
    for each subgraph the total row count or minimum resolution is calculated
    (depending on the value of ``key``). Subgraphs are sorted based on this
    result and then tables belonging to each subgraph are sorted by key. If
    ``order='global'`` the ``key`` ordering is applied to all tables without
    taking into account subgraphs.

    To speed things up unique targets are initially inserted into a temporary
    table ``<output_table>_<uid>`` where ``<output_table>`` is the name of the
    output table and ``<uid>`` is a unique identifier based on the plan.

    Once the order has been determined and when `.run` is called, each table
    model is processed in order. The first model is just ingested completely
    into the temporary table and its associated relational table is created if
    it does not exist (the equivalent of *phase 3* below).

    For each additional model the following three stages are applied:

    - In *phase 1* we determine what targets in the input model have an
      existing cross-match to targets already ingested into the temporary
      table. To do that we build all possible joins between the model and the
      temporary table. If multiple joins are possible via a given table only
      the shortest is used (see `.get_join_paths`). For all matched targets we
      insert entries in the relational table. The *lead* of the original
      entries is not changed.

    - In *phase 2* we perform the actual cross-match between targets in the
      temporary table and the ones in the input catalogue. Currently the only
      cross-matching method available is a spatial cone query with radius
      ``query_radius``. All matched targets are added to the relational table
      and the one with the smallest distance is defined as *best*.

    - In *phase 3* we determine any target in the input catalogue that has not
      been cross-matched at this point and insert them into the temporary
      table as new entries. The *lead* is set to the input catalogue and the
      one-to-one match is added to the relational table.

    In phases 1 and 3 the queries are initially stored as a Postgresql
    temporary table for efficiency, and then copied to the relational table.
    After all the tables have been processed the output temporary table is
    inserted in bulk into the output table and dropped.

    In addition to the limitations of the spatial cone query method, the
    following caveats are known:

    - Input tables with duplicate targets are not currently supported.

    - In *phase 2* there is no current measure in place for the same target to
      be associated with more than one catalogid (i.e., each cross-match is
      performed independently).

    Parameters
    ----------
    database : peewee:PostgresqlDatabase
        A `peewee:PostgresqlDatabase` to the database containing the tables to
        cross-match. It must already be connected.
    models : list
        The list of `.XMatchModel` classes to be cross-matched. If a model
        corresponds to a non-existing table it is ignored (a warning is
        logged).
    plan : str
        The cross-matching plan version.
    run_id : int
        An integer to identify this run of cross-matching. The ID is bit
        shifted `.RUN_ID_BIT_SHIFT` positions and added to the catalogid.
        This allows to quickly associate a ``catalogid`` with a run without
        having to query ``catalogdb``. A ``run_id`` cannot be used if there
        are targets already in ``catalog`` using that same ``run_id``.
    version_id
        The ``catalogdb.version.id`` cross-match version to use. Normally this
        will be `None`, in which case a new version id will be created. For
        "addendum" runs, this should be set to the run to which to append.
    extra_nodes : list
        List of PeeWee models to be used as extra nodes for joins (i.e.,
        already established cross-matches between catalogues). These models
        are not processed or inserted into the output table. Defaults to no
        extra nodes.
    order : str or list
        The type of sorting to be applied to the input models to decide in
        what order to process them. Currently allowed values are
        ``'hierarchical'`` and ``'global'`` (refer to the description above).
        The order can also be a list of table names, in which case that order
        is used without any further sorting.
    key : str
        The key to be used while sorting. Can be ``'row_count'`` or
        ``'resolution'``.
    epoch : float
        The epoch to which to convert all the target coordinates as they are
        inserted into the output table.
    start_node : str
        If specified, the name of the table that will be inserted first
        regardless of the above sorting process.
    query_radius : float
        The radius, in arcsec, for cross-matching between existing targets.
        Used in phase 2. Defaults to 1 arcsec.
    schema : str
        The schema in which all the tables to cross-match live (multiple
        schemas are not supported), and the schema in which the output tables
        will be created.
    output_table : str
        The name of the output table. Defaults to ``catalog``.
    temp_schema
        The schema where the temporary ``catalog`` table will be initially
        created.
    log
        A logger to which to log messages. If not provided the
        ``target_selection`` logger is used.
    log_path : str
        The path to which to log or `False` to disable file logging.
    debug : bool or int
        Controls the level to which to log to the screen. If `False` no
        logging is done to ``stdout``. If `True` the logging level is set to
        ``debug`` (all messages). It's also possible to specify a numerical
        value for the `logging level <logging>`.
    show_sql : bool
        Whether to log the full SQL queries being run.
    sample_region : tuple
        Allows to specify a 3-element tuple with the
        (``ra``, ``dec``, ``radius``) of the region to which to limit the
        cross-match. All values must be in degrees. It can also be a list of
        tuples, in which case the union of all the regions will be sampled.
    path_mode : str
        The mode to look for join paths to link tables to output catalog.
        ``original`` mode uses a scorched earth algorithm that iteratively
        looks for the shortest path, removes the first node traversed, and
        starts again until no shortest paths are available.
        ``full`` mode retrieves all the paths that are not a subsample of
        another path.
        ``config_list`` mode takes the list of paths from the ``join_paths``
        parameter in the .yml configuration file.
    join_paths : list
        When using ``path_mode='config_list'``, the list of paths to link
        tables to the output catalog table in phase 1.
    database_options : dict
        A dictionary of database configuration parameters to be set locally
        during each phase transaction, temporarily overriding the default
        database configuration. Keys must be the database parameter to modify.
        The value can be a simple string with the value to set, or a
        dictionary that more accurately defines when the parameter will be
        applied. For example ::

            database_options:
                work_mem: '2GB'
                temp_buffers : {value: '500MB', phases: [3]}

        In this case the ``temp_buffers='500MB'`` option will only be set for
        phase 3. Configuration options to be used for a specific table can be
        set up when defining the `.XMatchModel`.

    """

    def __init__(
        self,
        database,
        models,
        plan,
        run_id,
        version_id=None,
        extra_nodes=None,
        order="hierarchical",
        key="row_count",
        epoch=EPOCH,
        start_node=None,
        query_radius=None,
        schema="catalogdb",
        temp_schema=TEMP_SCHEMA,
        output_table="catalog",
        log=None,
        log_path="./xmatch_{plan}.log",
        debug=False,
        show_sql=False,
        sample_region=None,
        database_options=None,
        path_mode="full",
        join_paths=None,
    ):
        self.log = log or target_selection.log
        self.log.header = ""

        if log_path:
            log_path = os.path.realpath(log_path)
            # Replace any file handler already attached to the logger.
            if self.log.fh:
                self.log.removeHandler(self.log.fh)
                self.log.fh = None
            self.log.start_file_logger(log_path.format(plan=plan), rotating=False, mode="a")

        # True/False are handled by identity so that an integer level such
        # as 1 or 0 is passed through as a numerical logging level.
        if debug is True:
            self.log.sh.setLevel(0)
        elif debug is False:
            self.log.sh.setLevel(100)
        else:
            self.log.sh.setLevel(debug)

        self.schema = schema
        self.temp_schema = temp_schema
        self.output_table = output_table

        # The temporary table name is derived from the plan string.
        self.md5 = hashlib.md5(plan.encode()).hexdigest()[0:16]
        self._temp_table = self.output_table + "_" + self.md5

        self.database = database
        assert self.database.connected, "database is not connected."

        self.plan = plan
        self.run_id = run_id
        self.version_id = version_id
        self.is_addendum = version_id is not None

        self.tag = target_selection.__version__

        self.log.info(f"plan = {self.plan!r}; run_id = {self.run_id}; tag = {self.tag!r}.")
        self.log.info(f"Reference Epoch = {epoch}")

        # extra_nodes defaults to None instead of a shared mutable list.
        extra_nodes = [] if extra_nodes is None else extra_nodes

        self.models = {model._meta.table_name: model for model in models}
        self.extra_nodes = {model._meta.table_name: model for model in extra_nodes}

        self._check_models()

        self._options = {
            "query_radius": query_radius or QUERY_RADIUS,
            "show_sql": show_sql,
            "sample_region": sample_region,
            "epoch": epoch,
            "database_options": database_options or None,
        }

        qrd = self._options["query_radius"]
        self.log.info(f"Query radius = {qrd}")

        if self._options["sample_region"]:
            sample_region = self._options["sample_region"]
            self.log.warning(f"Using sample region {sample_region!r}.")

        self._log_db_configuration()

        self.model_graph = None
        self.update_model_graph()

        self.process_order = []
        self.set_process_order(order=order, key=key, start_node=start_node)

        # Lower bound of the catalogids for this run: the run id occupies
        # the bits above RUN_ID_BIT_SHIFT.
        self._max_cid = self.run_id << RUN_ID_BIT_SHIFT

        self.path_mode = path_mode
        if path_mode == "config_list":
            join_paths_warning = "join_paths needed for path_mode=config_list"
            assert join_paths is not None, join_paths_warning

        # Always defined (possibly None) so that the attribute can be read
        # regardless of the path mode.
        self.join_paths = join_paths
@classmethod
def read(cls, in_models, plan, config_file=None, **kwargs):
    """Instantiates `.XMatchPlanner` from a configuration file.

    The YAML configuration file must be organised by plan string (multiple
    plans can live in the same file). Any parameter that `.XMatchPlanner`
    accepts can be passed via the configuration file. Additionally, the
    configuration file accepts these extra parameters: ``exclude_nodes``,
    a list of table names that will be ignored (this is useful if you pass
    a database or base class as ``in_models`` and want to ignore some of the
    models); ``tables``, a dictionary of table names with parameters to be
    passed to `.XMatchModel` for its corresponding model; and
    ``extra_nodes``, a list of table names that are wrapped with
    `.XMatchModel` using its default parameters. ``schema`` is required.
    An example of a valid configuration file is:

    .. code-block:: yaml

        '0.1.0':
            order: hierarchical
            key: resolution
            query_radius: 1.
            schema: catalogdb
            output_table: catalog
            start_node: tic_v8
            debug: true
            log_path: false
            exclude_nodes: ['catwise']
            tables:
                tic_v8:
                    ra_column: ra
                    dec_column: dec
                    pmra_column: pmra
                    pmdec_column: pmdec
                    is_pmra_cos: true
                    parallax_column: plx
                    epoch: 2015.5
                gaia_dr2_source:
                    ra_column: ra
                    dec_column: dec
                    pmra_column: pmra
                    pmdec_column: pmdec
                    is_pmra_cos: true
                    parallax_column: parallax
                    epoch: 2015.5
                    skip: true

    It is also possible to use a parameter ``base_plan`` pointing to a
    previous plan string. In that case the previous plan configuration will
    be used as base and the new values will be merged (the update happens
    recursively as with normal Python dictionaries).

    Note that only models that match the table names in ``tables`` will be
    passed to `.XMatchPlanner` to be processed; any other table will be used
    as an extra joining node unless it's listed in ``exclude_nodes``, in
    which case it will be ignored completely. It's possible to set the
    ``skip`` option for a table; this has the same effect as removing the
    entry in ``tables``.

    Parameters
    ----------
    in_models
        The models to cross-match. Can be a list or tuple of PeeWee
        :obj:`peewee:Model` classes, a base class from which all the models
        to use subclass, or a `~sdssdb.connection.PeeweeDatabaseConnection`
        to the database containing the models. In the latter case, the
        models must have been imported so that they are available via the
        ``models`` attribute. When a list or tuple is passed, the database
        bound to its first model is used.
    plan : str
        The cross-matching plan.
    config_file : str or dict
        The path to the configuration file to use. Defaults to
        ``config/xmatch.yml``. The file must contain a hash with the
        cross-match plan.
    kwargs : dict
        User arguments that will override the configuration file values.

    Raises
    ------
    TypeError
        If ``in_models`` is not one of the supported types.
    XMatchError
        If ``in_models`` is an empty list or tuple.

    """

    # HACK: this ensures that the catalogdb.models are populated. In
    # principle this would not work if schema != catalogdb but anyway many
    # other things would fail in that case.
    from sdssdb.peewee.sdss5db import catalogdb  # noqa

    if isinstance(in_models, (list, tuple)):
        models = in_models
        if len(models) == 0:
            raise XMatchError("no models to cross-match.")
        # A plain sequence has no connection of its own; take the one the
        # first model is bound to. Previously ``database`` was left
        # undefined in this branch and the check below failed.
        database = models[0]._meta.database
    elif inspect.isclass(in_models) and issubclass(in_models, Model):
        database = in_models._meta.database
        # Collect all the descendants of the base class, at any depth.
        models = set(in_models.__subclasses__())
        while True:
            old_models = models.copy()
            for model in old_models:
                models |= set(model.__subclasses__())
            if models == old_models:
                break
    elif isinstance(in_models, PeeweeDatabaseConnection):
        database = in_models
        models = database.models.values()
    else:
        raise TypeError(f"invalid input of type {type(in_models)!r}")

    assert database.connected, "database is not connected."

    if config_file is None:
        config_file = os.path.join(
            os.path.dirname(target_selection.__file__),
            "config",
            "xmatch.yml",
        )

    config = XMatchPlanner._read_config(config_file, plan)

    # ``or`` guards against keys that are present but empty in the YAML.
    table_config = config.pop("tables", {}) or {}
    exclude_nodes = config.pop("exclude_nodes", []) or []
    extra_nodes_config = config.pop("extra_nodes", []) or []

    assert "schema" in config, "schema is required in configuration."
    schema = config["schema"]

    # Only models living in the configured schema are considered.
    models = {model._meta.table_name: model for model in models if model._meta.schema == schema}

    xmatch_models = {}
    for table_name in table_config:
        if table_name not in models:
            continue
        table_params = table_config[table_name] or {}
        xmatch_models[table_name] = XMatchModel(models[table_name], **table_params)

    # NOTE(review): tables listed under ``extra_nodes`` are added to
    # ``xmatch_models``, which is what gets passed as ``models`` below,
    # and replace any entry configured in ``tables`` -- confirm this is
    # the intended handling.
    for table_name in extra_nodes_config:
        if table_name in models:
            xmatch_models[table_name] = XMatchModel(models[table_name])
        else:
            warnings.warn(
                f"Cannot find model for extra node {table_name!r}.",
                TargetSelectionUserWarning,
            )

    # Everything else in the schema becomes a join-only node.
    extra_nodes = [
        models[table_name]
        for table_name in models
        if table_name not in xmatch_models and table_name not in exclude_nodes
    ]

    # User arguments take precedence over the configuration file.
    config.update(kwargs)

    signature = inspect.signature(XMatchPlanner)

    valid_kw = {}
    for kw in config:
        if kw not in signature.parameters:
            warnings.warn(
                f"ignoring invalid configuration value {kw!r}.",
                TargetSelectionUserWarning,
            )
            continue
        valid_kw[kw] = config[kw]

    return cls(database, xmatch_models.values(), plan, extra_nodes=extra_nodes, **valid_kw)
@staticmethod
def _read_config(file_, plan):
    """Reads the configuration for a plan, resolving ``base_plan`` recursively.

    Parameters
    ----------
    file_ : str or dict
        The path to a YAML configuration file, or an already loaded
        configuration dictionary keyed by plan. A dictionary is deep-copied
        and never modified.
    plan : str
        The plan whose configuration will be returned.

    Returns
    -------
    dict
        The configuration for ``plan``. If it defines ``base_plan``, the
        result of merging it with the configuration of that base plan using
        ``merge_config``.

    """

    if isinstance(file_, dict):
        config = copy.deepcopy(file_)
    else:
        # The context manager closes the file even if parsing fails.
        # SafeLoader is used, so no arbitrary objects are constructed.
        with open(file_, "r") as stream:
            config = yaml.load(stream, Loader=yaml.SafeLoader)

    # An empty YAML document loads as None; treat it as having no plans so
    # that the assertion below reports the missing plan.
    config = config or {}

    assert plan in config, f"plan {plan!r} not found in configuration."

    plan_config = config[plan]
    base_plan = plan_config.pop("base_plan", None)

    if base_plan:
        return merge_config(plan_config, XMatchPlanner._read_config(file_, base_plan))

    return plan_config


def _check_models(self):
    """Checks the input models and drops the ones that cannot be processed.

    Extra nodes are filtered to exclude the output table, its relational
    tables, and tables that do not exist. A model is removed from
    ``self.models`` when its table (or view) does not exist, when it is the
    output table or one of its relational tables, or when it is marked with
    ``skip``; in the last case it is kept as an extra node.

    Raises
    ------
    XMatchError
        If no models are left to cross-match.

    """

    catalog_tname = self.output_table
    relational_prefix = catalog_tname + "_to_"

    # Remove extra nodes that are relational tables because we'll use
    # temporary relational tables instead.
    self.extra_nodes = {
        tname: model
        for tname, model in self.extra_nodes.items()
        if (
            not tname.startswith(relational_prefix)
            and not tname == catalog_tname
            and model.table_exists()
        )
    }

    # Iterate over a copy of the keys since entries are popped in the loop.
    for tname in list(self.models):
        model = self.models[tname]

        # The view lookups only run when there is no table, and stop at
        # the first hit (materialised views are checked first).
        if not model.table_exists() and not any(
            is_view(self.database, tname, self.schema, mat) for mat in (True, False)
        ):
            self.log.warning(f"table {tname!r} does not exist.")
        elif tname == catalog_tname or tname.startswith(relational_prefix):
            pass
        elif model._meta.xmatch.skip:
            # Skipped models remain available as join nodes.
            self.extra_nodes[tname] = model
        else:
            continue

        self.models.pop(tname)

    if len(self.models) == 0:
        raise XMatchError("no models to cross-match.")


def _log_db_configuration(self):
    """Logs, at debug level, some key database configuration parameters."""

    parameters = [
        "shared_buffers",
        "effective_cache_size",
        "wal_buffers",
        "effective_io_concurrency",
        "work_mem",
        "max_worker_processes",
        "random_page_cost",
        "seq_page_cost",
        "cpu_index_tuple_cost",
        "cpu_operator_cost",
        "default_statistics_target",
        "temp_buffers",
        "plan_cache_mode",
        "geqo_effort",
        "force_parallel_mode",
        "enable_seqscan",
        "enable_nestloop",
    ]

    values = get_configuration_values(self.database, parameters)

    self.log.debug("Current database configuration parameters.")
    for parameter in values:
        log_str = f"{parameter} = {values[parameter]}"
        self.log.debug(log_str)
def update_model_graph(self):
    """Rebuilds the model graph with tables as nodes and foreign keys as edges.

    The graph contains the temporary output table, every model to process,
    and every extra node. Two tables are connected when one references the
    other and both are in the graph; the edge ``join_weight`` is the mean of
    the ``join_weight`` of both tables (1 for tables without ``xmatch``
    metadata). For each model to process, its relational table is also
    added and connected to both the model and the temporary table.

    Returns
    -------
    networkx.Graph
        The new graph, also stored in ``self.model_graph``.

    """

    def _node_weight(node_model):
        """Returns the join weight of a model, or 1 if it has none."""
        node_meta = node_model._meta
        if hasattr(node_meta, "xmatch"):
            return node_meta.xmatch.join_weight
        return 1

    graph = networkx.Graph()
    self.model_graph = graph

    graph.add_node(self._temp_table, model=TempCatalog)

    node_models = [*self.models.values(), *self.extra_nodes.values()]

    # All the nodes go in first so that the edges can be filtered by
    # membership below.
    for node_model in node_models:
        graph.add_node(node_model._meta.table_name, model=node_model)

    for node_model in node_models:
        tname = node_model._meta.table_name

        for referenced in node_model._meta.model_refs:
            referenced_tname = referenced._meta.table_name
            if referenced_tname in graph.nodes:
                # Edge weight is the average of the weights of both ends.
                weight = 0.5 * (_node_weight(node_model) + _node_weight(referenced))
                graph.add_edge(tname, referenced_tname, join_weight=weight)

        if node_model not in self.models.values():
            continue

        # Models to process are linked to the temporary table through
        # their relational table.
        relational = self.get_relational_model(node_model, sandboxed=False)
        relational._meta.schema = self.schema
        relational_tname = relational._meta.table_name
        graph.add_node(relational_tname, model=relational)
        graph.add_edge(self._temp_table, relational_tname)
        graph.add_edge(tname, relational_tname)

    return self.model_graph
def set_process_order(self, order="hierarchical", key="row_count", start_node=None):
    """Sets and returns the order in which tables will be processed.

    See `.XMatchPlanner` for details on how the order is decided depending
    on the input parameters.

    Parameters
    ----------
    order : str or list
        ``'hierarchical'``, ``'global'``, or an explicit sequence of table
        names. An explicit sequence is stored and returned unchanged.
    key : str
        ``'row_count'`` (more rows first) or ``'resolution'`` (smaller
        resolution value first; tables without a resolution go last).
    start_node : str
        The table to process first. In hierarchical mode the group that
        contains it is also processed first.

    Returns
    -------
    list
        The table names in processing order, also stored in
        ``self.process_order``. Ties are broken alphabetically.

    """

    if isinstance(order, (list, tuple)):
        self.log.info(f"processing order: {order}")
        self.process_order = order
        return order

    assert order in ["hierarchical", "global"], f"invalid order {order!r}."
    self.log.info(f"processing order mode is {order!r}")

    assert key in ["row_count", "resolution"], f"invalid key {key}."
    self.log.info(f"ordering key is {key!r}.")

    def _table_key(table_name):
        """Sort value for a table; lower values are processed first."""
        xmatch = self.models[table_name]._meta.xmatch
        if key == "row_count":
            # Negated to avoid needing reverse order.
            return -xmatch.row_count
        resolution = xmatch.resolution
        # NaN does not sort consistently, so an unknown resolution is
        # mapped to infinity and ends up last.
        if resolution is None or numpy.isnan(resolution):
            return float("inf")
        return float(resolution)

    def _group_key(table_names):
        """Sort value for a group: total row count or best resolution."""
        table_keys = [_table_key(table_name) for table_name in table_names]
        return sum(table_keys) if key == "row_count" else min(table_keys)

    def _is_start(table_name):
        return bool(start_node) and table_name == start_node

    if order == "global":
        # Every table is its own group.
        groups = [(table_name,) for table_name in self.models]
    else:
        graph = self.model_graph.copy()
        # Keep only the tables to process. The graph also holds the extra
        # nodes, the temporary table, and the relational tables; those are
        # not in self.models and would otherwise link every table into a
        # single group.
        graph.remove_nodes_from([node for node in graph.nodes if node not in self.models])
        groups = [tuple(sorted(component)) for component in networkx.connected_components(graph)]

    # Rank 0 puts the group with the start node first. The sorted tuple of
    # names is the final, deterministic tie-breaker.
    groups_ordered = sorted(
        groups,
        key=lambda group: (
            0 if any(_is_start(table_name) for table_name in group) else 1,
            _group_key(group),
            group,
        ),
    )

    ordered_tables = []
    for group in groups_ordered:
        # Use table name as last sorting order
        # to use alphabetic order in case of draw.
        ordered_tables.extend(
            sorted(
                group,
                key=lambda table_name: (
                    0 if _is_start(table_name) else 1,
                    _table_key(table_name),
                    table_name,
                ),
            )
        )

    self.log.info(f"processing order: {ordered_tables}")
    self.process_order = ordered_tables

    return ordered_tables
def get_join_paths(self, source, return_models=False, mode="full"):
    """Determines all possible join paths between two tables.

    Mode ``original`` follows a scorched earth approach in which once
    a node has been used for a join it cannot be used again. This
    produces only distinct joins between the two nodes.
    Mode ``full`` includes all possible paths that are not a
    subsample of another path.
    In both modes paths that include only the source and destination
    through their relational table are ignored, as are paths in which
    the last node before the output table has not yet been processed.
    Finally mode ``config_list`` takes the paths from the list indicated
    in the configuration file.

    Weights can be defined by setting the ``join_weight`` value in
    `.XMatchModel`. The weight for each edge is the average of the
    ``join_weight`` of the two nodes joined. The weight defaults to 1.
    Lower weights translate to better chances of that join path to be
    selected.

    Parameters
    ----------
    source : str
        The initial table for the path.
    return_models : bool
        If `True`, returns each path as a list of models. Otherwise returns
        the table names.
    mode : str
        The method used to obtain the paths. Possible values are
        ``original``, ``full``, and ``config_list``. Any other value
        produces no paths.

    Returns
    -------
    `list`
        A list in which each item is a path with the table names or models
        joining ``source`` to the temporary output table. In ``original``
        mode paths appear in the order they are found by repeated
        shortest-path searches; in ``full`` mode paths through one, two,
        and three tables are listed in that order; in ``config_list``
        mode the configuration order is kept.

    """

    graph = self.model_graph.copy()
    porder = self.process_order
    paths = []

    if mode == "original":
        dest = self._temp_table

        # Drop the direct link through the source's own relational table.
        # NOTE(review): this uses the temporary table name as prefix while
        # the ``full`` mode uses the output table name -- confirm which
        # one matches the relational nodes in the graph.
        rel_table_name = self._temp_table + "_to_" + source
        if rel_table_name in graph:
            graph.remove_node(rel_table_name)

        while True:
            try:
                spath = shortest_path(graph, source, dest, weight="join_weight")
            except networkx.NetworkXNoPath:
                break
            # spath[-3] is the last catalogue before the relational table
            # and the temporary table. Keep the path only if that
            # catalogue is processed before the source.
            if spath[-3] in porder and porder.index(spath[-3]) < porder.index(source):
                paths.append(spath)
            # Scorched earth: the first node after the source cannot be
            # used again, whether the path was kept or not.
            graph.remove_node(spath[1])

    elif mode == "full":
        # Only tables processed before the source can be destinations.
        processed = porder[: porder.index(source)]

        # Direct links (source -> dest).
        for dest in processed:
            paths1 = list(networkx.all_simple_paths(graph, source, dest, cutoff=1))
            if len(paths1) > 0:
                paths += paths1
                graph.remove_node(paths1[0][1])  # Only 1 possible path

        # Links through one intermediate table.
        for dest in processed:
            if dest not in graph:
                # Already linked directly and removed above.
                continue
            paths2orless = list(networkx.all_simple_paths(graph, source, dest, cutoff=2))
            paths2 = [el for el in paths2orless if len(el) == 3]
            if len(paths2) > 0:
                paths += paths2
                used2_edges = [(el[1], el[2]) for el in paths2]
                graph.remove_edges_from(used2_edges)

        # Links through two intermediate tables, unless the first
        # intermediate and the destination are already joined by a
        # three-table path. Only four-table paths are added below, so the
        # forbidden pairs do not change during the loop.
        forb_couples = [set(el[1:]) for el in paths if len(el) == 3]
        for dest in processed:
            if dest not in graph:
                continue
            paths3orless = list(networkx.all_simple_paths(graph, source, dest, cutoff=3))
            paths3 = [
                el for el in paths3orless if len(el) == 4 and {el[1], el[3]} not in forb_couples
            ]
            if len(paths3) > 0:
                paths += paths3

        # Complete each path with the relational table of its last
        # catalogue and the temporary output table. The relational table
        # prefix follows the configured output table name.
        rel_prefix = self.output_table + "_to_"
        paths = [path + [rel_prefix + path[-1], self._temp_table] for path in paths]

    elif mode == "config_list":
        # join_paths may be None or unset when no paths were configured.
        all_paths = getattr(self, "join_paths", None) or []
        paths = [path for path in all_paths if path and path[0] == source]

    if len(paths) == 0:
        return []

    if return_models:
        nodes = self.model_graph.nodes
        paths = [[nodes[node]["model"] for node in path] for path in paths]

    return paths
def _prepare_models(self):
    """Prepare the Catalog, TempCatalog, and Version models.

    Sets the schema, table name and database of the module-level models
    from this plan: ``Catalog`` goes to ``self.schema``/``self.output_table``,
    ``TempCatalog`` to ``self.temp_schema``/``self._temp_table``, and
    ``Version`` to ``self.schema`` (keeping its declared table name).
    """

    Catalog._meta.schema = self.schema
    Catalog._meta.table_name = self.output_table
    Catalog._meta.set_database(self.database)

    TempCatalog._meta.schema = self.temp_schema
    TempCatalog._meta.table_name = self._temp_table
    TempCatalog._meta.set_database(self.database)

    Version._meta.schema = self.schema
    Version._meta.set_database(self.database)


def _check_version(self, model, force=False):
    """Checks if a model contains records for this plan version.

    Parameters
    ----------
    model
        The model whose ``version_id`` column is checked against
        ``self.version_id``.
    force : bool
        If `True`, only logs a warning when records are found.

    Raises
    ------
    XMatchError
        If records for this version exist and ``force`` is not set.
    """

    with self.database.atomic():
        # Disable sequential scans for this transaction only.
        self.database.execute_sql("SET LOCAL enable_seqscan = off;")

        has_version = model.select(SQL("1")).where(model.version_id == self.version_id)
        vexists = (
            peewee.Select(columns=[fn.EXISTS(has_version)]).tuples().execute(self.database)
        )[0][0]

    if not vexists:
        return

    msg = (
        f"{model._meta.table_name!r} contains records for this "
        f"cross-matching plan ({self.plan})."
    )
    if force:
        self.log.warning(msg)
    else:
        raise XMatchError(msg)


def _create_models(self, force=False):
    """Creates models and performs some checks.

    Binds the models to the database, makes sure the version, output and
    temporary tables exist, registers (or reuses) the version record, and
    initialises ``self._temp_count``.

    Parameters
    ----------
    force : bool
        If `True`, a pre-existing temporary table, or pre-existing records
        for this version, only produce warnings.

    Raises
    ------
    XMatchError
        If the temporary table already exists, or an output table already
        has records for this version, and ``force`` is not set.
    """

    self._prepare_models()

    # Bind models
    self.database.bind([Catalog, TempCatalog, Version])

    if not Version.table_exists():
        self.database.create_tables([Version])
        self.log.info(f"Created table {Version._meta.table_name}.")

    if self.version_id is None:
        version, vcreated = Version.get_or_create(plan=self.plan, tag=self.tag)
        self.version_id = version.id
    else:
        vcreated = False

    vmsg = "Added version record " if vcreated else "Using version record "
    self.log.info(vmsg + f"({self.version_id}, {self.plan}, {self.tag}).")

    # Make sure the output table exists.
    if not Catalog.table_exists():
        # Add Q3C index for Catalog
        Catalog.add_index(
            SQL(
                f"CREATE INDEX catalog_q3c_idx ON "
                f"{self.schema}.{self.output_table} "
                f"(q3c_ang2ipix(ra, dec))"
            )
        )

        self.database.create_tables([Catalog])
        self.log.info(f"Created table {self.output_table}.")

    # Check if Catalog already has entries for this xmatch version.
    if Catalog.table_exists():
        self._check_version(Catalog, force or self.is_addendum)

    if TempCatalog.table_exists():
        msg = f"Temporary table {self._temp_table} already exists."
        if force:
            self.log.warning(msg)
        else:
            raise XMatchError(msg)

        self._check_version(TempCatalog, force)

        # The temporary table lives in the temporary schema (see
        # _prepare_models), so that is where its rows must be counted.
        try:
            self._temp_count = int(
                get_row_count(
                    self.database,
                    self._temp_table,
                    schema=self.temp_schema,
                    approximate=True,
                )
            )
        except ValueError:
            self._temp_count = 0

    else:
        # Add Q3C index for TempCatalog
        TempCatalog.add_index(
            SQL(
                f"CREATE INDEX IF NOT EXISTS "
                f"{self._temp_table}_q3c_idx ON "
                f"{self.temp_schema}.{self._temp_table} "
                f"(q3c_ang2ipix(ra, dec))"
            )
        )

        self.database.create_tables([TempCatalog])
        self.log.info(f"Created table {self._temp_table}.")

        self._temp_count = 0
def run(
    self,
    vacuum=False,
    analyze=False,
    from_=None,
    force=False,
    dry_run=False,
    keep_temp=False,
):
    """Runs the cross-matching process.

    Parameters
    ----------
    vacuum : bool
        Vacuum all output tables before and after processing new
        catalogues.
    analyze : bool
        Analyze all output tables before and after processing new
        catalogues.
    from_ : str
        Table from which to start running the process. Useful in reruns
        to skip tables already processed.
    force : bool
        Allows to continue even if the temporary table exists or the
        output table contains records for this version. ``force=True``
        is assumed if ``from_`` is defined.
    dry_run : bool
        If `False`, loads the temporary tables into ``Catalog`` and
        ``CatalogToXXX``. `True` implies ``keep_temp=True``; all the
        cross-matching steps will be run but the original tables won't
        be modified. A dry run can only be executed for a plan with a
        single catalogue since processing multiple catalogues requires
        the final tables to have been updated for successive catalogues.
    keep_temp : bool
        Whether to keep the temporary table or to drop it after the cross
        matching is done.

    Raises
    ------
    RuntimeError
        If a dry run is requested for a plan with more than one catalogue.

    """

    if len(self.process_order) > 1 and dry_run is True:
        raise RuntimeError("Cannot dry run with a plan that includes more than one catalogue.")

    if dry_run:
        keep_temp = True

    if vacuum or analyze:
        cmd = " ".join(("VACUUM" if vacuum else "", "ANALYZE" if analyze else "")).strip()
        self.log.info(f"Running {cmd} on output tables.")
        vacuum_outputs(
            self.database,
            vacuum=vacuum,
            analyze=analyze,
            schema=Catalog._meta.schema,
            table=Catalog._meta.table_name,
        )

        if TempCatalog.table_exists():
            # The temporary table is created in the temporary schema, not
            # in the output schema.
            vacuum_table(self.database, f"{self.temp_schema}.{self._temp_table}")

    # Checks if there are multiple cross-matching plans running at the same
    # time. This is problematic because there can be catalogid collisions.
    # This should only be allowed if we define the starting catalogid
    # manually and make sure there cannot be collisions.
    # temp_tables = [table for table in self.database.get_tables(self.schema)
    #                if table.startswith(self.output_table + '_') and
    #                not table.startswith(self.output_table + '_to_') and
    #                table != self._temp_table]
    # if len(temp_tables) > 0:
    #     raise XMatchError('Another cross-match plan is currently running.')

    self._create_models(force or (from_ is not None))

    if from_ is not None:
        max_cid = TempCatalog.select(fn.MAX(TempCatalog.catalogid)).scalar()
        # MAX() is NULL for an empty temporary table; in that case keep the
        # current starting catalogid instead of failing on ``None + 1``.
        if max_cid is not None:
            self._max_cid = max_cid + 1  # just to be sure numbers dont overlap

    with Timer() as timer:
        p_order = self.process_order

        for norder, table_name in enumerate(p_order):
            if dry_run and norder > 0:
                raise RuntimeError("Cannot dry run more than one catalogue.")

            if from_ and p_order.index(table_name) < p_order.index(from_):
                self.log.warning(f"Skipping table {table_name}.")
                continue

            model = self.models[table_name]
            self.process_model(model, force=force)

            if not dry_run:
                self.load_output_tables(model, keep_temp=keep_temp, vacuum=vacuum)

    self.log.info(f"Cross-matching completed in {timer.interval:.3f} s.")
def process_model(self, model, force=False):
    """Processes a model, loading it into the output table.

    The first model of a plan that is not an addendum only goes through
    phase 3; any other model goes through phases 1, 2 and 3 in order. The
    model graph is updated once the model has been processed.

    Parameters
    ----------
    model
        The model to cross-match.
    force : bool
        If `False` (the default), an already existing sandboxed relational
        table for the model is an error; otherwise it only logs a warning.

    Raises
    ------
    TargetSelectionNotImplemented
        If the model is flagged as having duplicates.
    RuntimeError
        If the sandboxed relational table exists and ``force`` is `False`.

    """

    table_name = model._meta.table_name

    self.log.header = f"[{table_name.upper()}] "
    self.log.info(f"Processing table {table_name}.")
    self._log_table_configuration(model)

    if model._meta.xmatch.has_duplicates:
        raise TargetSelectionNotImplemented(
            "handling of tables with duplicates is not implemented."
        )

    sandboxed_rel = self.get_relational_model(model, sandboxed=True, create=False)
    sandboxed_rel_name = sandboxed_rel._meta.table_name

    if sandboxed_rel.table_exists():
        if force is False:
            raise RuntimeError(
                f"Sandboxed relational table {sandboxed_rel_name} exists. "
                "Delete it manually before continuing."
            )
        self.log.warning(f"Sandboxed relational table {sandboxed_rel_name} exists.")

    # The first table of a fresh (non-addendum) plan has nothing to match
    # against, so only phase 3 applies to it.
    position = self.process_order.index(model._meta.table_name)
    is_first_model = position == 0 and not self.is_addendum

    self._phases_run = set()

    with Timer() as timer:
        if not is_first_model:
            self._run_phase_1(model)
            self._run_phase_2(model)
        self._run_phase_3(model)

    self.log.info(f"Fully processed {table_name} in {timer.elapsed:.0f} s.")

    self.update_model_graph()

    self.log.header = ""
def _get_model_fields(self, model):
    """Returns the model fields needed to populate Catalog.

    The returned list becomes the SELECT clause: always ``ra`` and ``dec``,
    followed by ``pmra`` and ``pmdec`` when the model defines proper
    motions, and ``parallax`` when a parallax column is configured. When
    proper motions are available and the model epoch differs from the
    catalogue epoch (or is given by a column), the coordinates are
    propagated to the catalogue epoch with ``sql_apply_pm``.
    """

    meta = model._meta
    xmatch = meta.xmatch
    fields = meta.fields

    ra_field = fields[xmatch.ra_column]
    dec_field = fields[xmatch.dec_column]

    to_epoch = self._options["epoch"]

    # Stay ``None`` when the model has no proper motions.
    pmra_field = None
    pmdec_field = None

    # RA, Dec, and proper motion fields.
    if meta.table_name == "tic_v8":
        # TODO: this should be handled in a way that can be opted-in from
        # the configuration, but for now I'll just hardcode it here.
        pmra_field = fields[xmatch.pmra_column]
        pmdec_field = fields[xmatch.pmdec_column]

        delta_years2000 = to_epoch - get_epoch(model)
        delta_years2015p5 = to_epoch - 2015.5

        racorr2000_field, deccorr2000_field = sql_apply_pm(
            ra_field,
            dec_field,
            pmra_field,
            pmdec_field,
            delta_years2000,
            xmatch.is_pmra_cos,
        )
        racorr2015p5_field, deccorr2015p5_field = sql_apply_pm(
            model.ra_orig,
            model.dec_orig,
            pmra_field,
            pmdec_field,
            delta_years2015p5,
            xmatch.is_pmra_cos,
        )

        # Rows flagged "gaia2" use the original coordinates propagated from
        # 2015.5; every other row uses the model epoch.
        ra_field = Case(
            None,
            [(model.posflag == "gaia2", racorr2015p5_field)],
            racorr2000_field,
        )
        dec_field = Case(
            None,
            [(model.posflag == "gaia2", deccorr2015p5_field)],
            deccorr2000_field,
        )

    elif xmatch.pmra_column:
        pmra_field = fields[xmatch.pmra_column]
        pmdec_field = fields[xmatch.pmdec_column]

        if (xmatch.epoch and xmatch.epoch != to_epoch) or xmatch.epoch_column:
            delta_years = to_epoch - get_epoch(model)
            ra_field, dec_field = sql_apply_pm(
                ra_field,
                dec_field,
                pmra_field,
                pmdec_field,
                delta_years,
                xmatch.is_pmra_cos,
            )

        if not xmatch.is_pmra_cos:
            pmra_field *= fn.cos(fn.radians(dec_field))

    # List of fields that will become the SELECT clause.
    model_fields = [ra_field.alias("ra"), dec_field.alias("dec")]
    if pmra_field is not None:
        model_fields.extend([pmra_field.alias("pmra"), pmdec_field.alias("pmdec")])

    # Parallax
    if xmatch.parallax_column:
        model_fields.append(fields[xmatch.parallax_column].alias("parallax"))

    return model_fields
def get_output_model(self, temporary=False):
    """Returns the output model: ``TempCatalog`` if ``temporary``, else ``Catalog``."""

    return TempCatalog if temporary else Catalog
def get_relational_model(self, model, sandboxed=False, temp=False, create=False):
    """Gets or creates a relational table for a given model.

    The relational model links ``catalogid`` values with the primary key of
    ``model`` (``target_id``).

    Parameters
    ----------
    model
        The model for which to build the relational model.
    sandboxed : bool
        If `True`, the model uses the temporary schema and its table name
        is suffixed with ``self.md5``. Otherwise it uses the schema of
        ``model``.
    temp : bool
        If `True`, returns a model with a random eight-character table name
        and no schema, without checking or creating any table. ``target_id``
        keeps the plain field class in this case.
    create : bool
        If `True`, creates the table when it does not exist.

    Raises
    ------
    XMatchError
        If the primary key of ``model`` is a composite key.

    """

    prefix = Catalog._meta.table_name + "_to_"

    meta = model._meta
    pk = meta.primary_key

    if isinstance(pk, str) and pk == "__composite_key__":
        raise XMatchError(f"composite pk found for model {model.__name__!r}.")

    # Auto/Serial are automatically PKs. Convert them to integers
    # to avoid having two pks in the relational table.
    pk_class = pk.__class__
    integer_equivalents = {
        "AUTO": peewee.IntegerField,
        "BIGAUTO": peewee.BigIntegerField,
    }
    model_pk_class = integer_equivalents.get(pk_class.field_type, pk_class)

    rel_schema = self.temp_schema if sandboxed else meta.schema

    class BaseModel(peewee.Model):
        catalogid = peewee.BigIntegerField(null=False, index=True)
        target_id = model_pk_class(null=False, index=True)
        version_id = peewee.SmallIntegerField(null=False, index=True)
        distance = peewee.DoubleField(null=True)
        best = peewee.BooleanField(null=False)
        plan_id = peewee.TextField(null=True)
        added_by_phase = peewee.SmallIntegerField(null=True)

        class Meta:
            database = meta.database
            schema = rel_schema
            primary_key = False

    # Camel-case version of the prefix; empty chunks become underscores.
    prefix_chunks = prefix.rstrip().split("_")
    model_prefix = "".join(chunk.capitalize() or "_" for chunk in prefix_chunks)

    RelationalModel = type(model_prefix + model.__name__, (BaseModel,), {})

    if temp:
        rel_meta = RelationalModel._meta
        rel_meta.primary_key = False
        rel_meta.composite_key = False
        rel_meta.set_table_name(uuid.uuid4().hex[0:8])
        rel_meta.schema = None
        return RelationalModel

    suffix = f"_{self.md5}" if sandboxed else ""
    RelationalModel._meta.table_name = prefix + meta.table_name + suffix

    if create and not RelationalModel.table_exists():
        RelationalModel.create_table()

    # Add foreign key field here. We want to avoid Peewee creating it
    # as a constraint and index if the table is created because that would
    # slow down inserts. We'll created them manually with add_fks.
    # Note that we do not create an FK between the relational model and
    # Catalog because the relationship is only unique on
    # (catalogid, version_id).
    foreign_key = peewee.ForeignKeyField(model, column_name="target_id", backref="+")
    RelationalModel._meta.remove_field("target_id")
    RelationalModel._meta.add_field("target_id", foreign_key)

    return RelationalModel
def _build_join(self, path):
    """Returns a build query for a given join path.

    ``path`` is a list of models. The query selects from the first one and
    joins each following model in order.
    """

    model = path[0]

    query = model.select()
    for inode in range(1, len(path)):
        if path[inode] is TempCatalog:
            # TempCatalog is joined explicitly on catalogid with the
            # previous model in the path.
            query = query.join(
                TempCatalog, on=(TempCatalog.catalogid == path[inode - 1].catalogid)
            )
        elif path[inode]._meta.table_name == "gaia_dr2_neighbourhood":
            query = (
                query.join(path[inode])
                .where(path[inode].angular_distance < 200)  # For 1-1
                .distinct(path[inode].dr2_source_id)
            )  # To confirm
        else:
            query = query.join(path[inode])

    return query


def _run_phase_1(self, model):
    """Runs the linking against matched catalogids stage.

    For each join path returned by ``get_join_paths`` the matches are
    selected into a temporary table and then copied into the sandboxed
    relational table with ``added_by_phase=1``.

    Returns `None` when the phase is skipped by configuration or completes,
    and `False` when no join paths are found.
    """

    xmatch = model._meta.xmatch
    table_name = model._meta.table_name

    rel_model_sb = self.get_relational_model(model, create=True, sandboxed=True)
    rel_model = self.get_relational_model(model, create=False, sandboxed=False)

    model_pk = model._meta.primary_key

    self.log.info("Phase 1: linking existing targets.")

    if 1 in model._meta.xmatch.skip_phases:
        self.log.warning("Skipping due to configuration.")
        return

    path_mode = self.path_mode
    join_paths = self.get_join_paths(table_name, mode=path_mode)

    if len(join_paths) == 0:
        self.log.debug(f"No paths found between {table_name} and temporary output table.")
        return False

    self.log.debug(
        f"Found {len(join_paths)} paths between {table_name} and temporary output table."
    )

    for n_path, path in enumerate(join_paths):
        # Remove the temporary catalog table at the end of the join path
        # because we only need catalogid and we can get that from the
        # relational model, saving us one join.
        path = path[0:-1]

        join_models = [self.model_graph.nodes[node]["model"] for node in path]

        # Get the relational model that leads to the temporary catalog
        # table in the join. We'll want to filter on version_id to avoid
        # a sequential scan.
        join_rel_model = join_models[-1]

        # Whether to allow multiple best matches for the same target.
        catalogid_condition = rel_model_sb.catalogid == join_rel_model.catalogid
        if xmatch.allow_multiple_bests:
            catalogid_condition = peewee.SQL("FALSE")

        query = (
            self._build_join(join_models)
            .select(
                model_pk.alias("target_id"),
                join_rel_model.catalogid,
                peewee.Value(True).alias("best"),
            )
            .where(
                join_rel_model.version_id == self.version_id,
                join_rel_model.best >> True,
            )
            .where(
                ~fn.EXISTS(
                    rel_model_sb.select(SQL("1")).where(
                        rel_model_sb.version_id == self.version_id,
                        ((rel_model_sb.target_id == model_pk) | catalogid_condition),
                    )
                )
            )
            # Select only one match per target in the catalogue we are cross-matching.
            .distinct(model_pk)
        )

        # Deal with duplicates in LS8
        if table_name == "legacy_survey_dr8":
            query = query.where(self._get_ls8_where(model))

        # Remove Duplicates and non-primary entries from LS10
        if table_name == "legacy_survey_dr10":
            query = query.where(
                model.survey_primary >> True, fn.coalesce(model.ref_cat, "") != "T2"
            )

        # If the real relational model exists, exclude any matches that already exist there.
        if rel_model.table_exists():
            query = query.where(
                ~fn.EXISTS(
                    rel_model.select(SQL("1")).where(
                        rel_model.version_id == self.version_id,
                        rel_model.target_id == model_pk,
                    )
                )
            )

        # In query we do not include a Q3C where for the sample region because
        # TempCatalog for this plan should already be sample region limited.

        with Timer() as timer:
            with self.database.atomic():
                temp_model = self.get_relational_model(model, temp=True, sandboxed=True)
                temp_table = temp_model._meta.table_name

                self._setup_transaction(model, phase=1)

                self.log.debug(
                    f"Selecting linked targets into temporary "
                    f"table {temp_table!r} with join path "
                    f"{path}{self._get_sql(query)}"
                )

                query.create_table(temp_table, temporary=True)

                self.log.debug(
                    f"Copying data into relational model {rel_model_sb._meta.table_name!r}."
                )

                # Columns of the sandboxed relational table to fill, in the
                # same order as the values selected below.
                fields = [
                    temp_model.target_id,
                    temp_model.catalogid,
                    temp_model.version_id,
                    temp_model.best,
                    temp_model.plan_id,
                    temp_model.added_by_phase,
                ]

                nids = (
                    rel_model_sb.insert_from(
                        temp_model.select(
                            temp_model.target_id,
                            temp_model.catalogid,
                            peewee.Value(self.version_id),
                            temp_model.best,
                            # plan_id is only recorded for addendum runs.
                            peewee.Value(self.plan) if self.is_addendum else None,
                            peewee.Value(1),
                        ),
                        fields,
                    )
                    .returning()
                    .execute()
                )

        self.log.debug(f"Linked {nids.rowcount:,} records in {timer.interval:.3f} s.")

        self._phases_run.add(1)

        if nids.rowcount > 0:
            self._analyze(rel_model_sb)


def _run_phase_2(self, model, source=TempCatalog):
    """Associates existing targets in Catalog with entries in the model.

    Here ``source`` is the catalogue with which we are spatially
    cross-matching. Normally this is the temporary catalog table which
    we are building for this cross-match run. But when we are doing an
    addendum, that table is going to be empty (at least for the first table
    in the addendum), so we need to also have the option of using the
    real ``catalog`` table as the source. This method will call itself
    recursively with ``source=Catalog`` if this is an addendum run.

    """

    meta = model._meta
    xmatch = meta.xmatch

    table_name = meta.table_name
    s_table_name = source._meta.table_name

    self.log.info(f"Phase 2: cross-matching against existing targets ({s_table_name}).")

    if 2 in xmatch.skip_phases:
        self.log.warning("Skipping due to configuration.")
        return

    rel_model_sb = self.get_relational_model(model, create=True, sandboxed=True)
    rel_sb_table_name = rel_model_sb._meta.table_name

    rel_model = self.get_relational_model(model, create=False, sandboxed=False)

    model_pk = meta.primary_key
    model_ra = meta.fields[xmatch.ra_column]
    model_dec = meta.fields[xmatch.dec_column]

    catalog_epoch = self._options["epoch"]

    # The per-table radius takes precedence over the plan-wide option.
    query_radius = xmatch.query_radius or self._options["query_radius"]

    # Should we use proper motions?
    model_epoch = get_epoch(model)

    is_model_expression = isinstance(model_epoch, (peewee.Expression, peewee.Function))

    use_pm = model_epoch and (is_model_expression or (model_epoch != catalog_epoch))

    if use_pm:
        self.log.debug("Determining maximum epoch delta between catalogues.")

        if isinstance(model_epoch, (int, float)):
            max_delta_epoch = float(abs(model_epoch - catalog_epoch))
        else:
            # The epoch is an SQL expression: query its maximum difference
            # over the sample region.
            # NOTE(review): scalar() presumably returns None when no rows
            # match, which would make float() fail -- confirm.
            max_delta_epoch = float(
                model.select(fn.MAX(fn.ABS(model_epoch - catalog_epoch)))
                .where(self._get_sample_where(model_ra, model_dec))
                .scalar()
            )

        max_delta_epoch += 0.1  # Add .1 yr to be sure it's an upper bound

        self.log.debug(f"Maximum epoch delta: {max_delta_epoch:.3f} (+ 0.1 year).")

    if use_pm:
        model_pmra = meta.fields[xmatch.pmra_column]
        model_pmdec = meta.fields[xmatch.pmdec_column]
        model_is_pmra_cos = int(xmatch.is_pmra_cos)

        q3c_dist = fn.q3c_dist_pm(
            model_ra,
            model_dec,
            model_pmra,
            model_pmdec,
            model_is_pmra_cos,
            model_epoch,
            source.ra,
            source.dec,
            catalog_epoch,
        )
        q3c_join = fn.q3c_join_pm(
            model_ra,
            model_dec,
            model_pmra,
            model_pmdec,
            model_is_pmra_cos,
            model_epoch,
            source.ra,
            source.dec,
            catalog_epoch,
            max_delta_epoch,
            query_radius / 3600.0,
        )
    else:
        q3c_dist = fn.q3c_dist(model_ra, model_dec, source.ra, source.dec)
        q3c_join = fn.q3c_join(
            model_ra,
            model_dec,
            source.ra,
            source.dec,
            query_radius / 3600.0,
        )

    # Get the cross-matched catalogid and model target pk (target_id),
    # and their distance.
    xmatched = (
        source.select(
            source.catalogid,
            model_pk.alias("target_id"),
            q3c_dist.alias("distance"),
            source.version_id,
        )
        .join(model, peewee.JOIN.CROSS)
        .where(q3c_join)
        .where(self._get_sample_where(model_ra, model_dec))
    )

    if table_name == "legacy_survey_dr8":
        xmatched = xmatched.where(self._get_ls8_where(model))

    if table_name == "legacy_survey_dr10":
        xmatched = xmatched.where(
            model.survey_primary >> True,
            fn.coalesce(model.ref_cat, "") != "T2",
        )

    # This may break the use of the index but I think it's needed if
    # the model is the second table in q3c_join and has empty RA/Dec.
    if xmatch.has_missing_coordinates and use_pm:
        xmatched = xmatched.where(model_ra.is_null(False), model_dec.is_null(False))

    xmatched = xmatched.cte("xmatched", materialized=True)

    # We'll partition over each group of targets that match the same
    # catalogid and mark the one with the smallest distance to it as best.
    partition = fn.first_value(xmatched.c.target_id).over(
        partition_by=[xmatched.c.catalogid],
        order_by=[xmatched.c.distance.asc()],
    )
    best = peewee.Value(partition == xmatched.c.target_id)

    # Select the values to insert. Remove target_ids that were already
    # present in the relational table due to phase 1.
    # We separate the filter in two IF NOT EXISTS clauses to
    # be sure the query planner uses the indexes for each (it
    # won't necessarily do it if we do an OR). Also, make sure
    # we compare version_id, target_id and not the other way
    # because that's the order in which we defined the index.
    in_query = xmatched.select(
        xmatched.c.target_id,
        xmatched.c.catalogid,
        peewee.Value(self.version_id).alias("version_id"),
        xmatched.c.distance.alias("distance"),
        best.alias("best"),
        # plan_id is only recorded for addendum runs.
        self.plan if self.is_addendum else None,
        peewee.Value(2).alias("added_by_phase"),
    )

    # Whether to allow multiple best matches for the same target.
    catalogid_condition = rel_model_sb.catalogid == xmatched.c.catalogid
    if xmatch.allow_multiple_bests:
        catalogid_condition = peewee.SQL("FALSE")

    # We only need to care about already linked targets if phase 1 run.
    if 1 in self._phases_run:
        # NOTE(review): the version_id filter here is applied again,
        # unconditionally, further below; it looks redundant -- confirm.
        in_query = in_query.where(
            xmatched.c.version_id == self.version_id,
            ~fn.EXISTS(
                rel_model_sb.select(SQL("1")).where(
                    (rel_model_sb.version_id == self.version_id)
                    & (catalogid_condition | (rel_model_sb.target_id == xmatched.c.target_id))
                )
            ),
        )

    # Whether to allow multiple best matches for the same target.
    catalogid_condition = rel_model.catalogid == xmatched.c.catalogid
    if xmatch.allow_multiple_bests:
        catalogid_condition = peewee.SQL("FALSE")

    # Same exclusion, this time against the real (non-sandboxed) relational
    # table when it exists.
    if rel_model.table_exists():
        in_query = in_query.where(
            ~fn.EXISTS(
                rel_model.select(SQL("1")).where(
                    (rel_model.version_id == self.version_id)
                    & (catalogid_condition | (rel_model.target_id == xmatched.c.target_id))
                )
            )
        )

    # Make sure we are only spatially cross-matching against targets from the same
    # cross-match. This only matters for addendum runs because in those we are also
    # cross-matching against the existing catalogdb.catalog table which contains multiple
    # versions.
    in_query = in_query.where(xmatched.c.version_id == self.version_id)

    with Timer() as timer:
        with self.database.atomic():
            # 1. Tweak database configuration for this transaction to
            #    ensure Q3C index is used.

            # May need to increase work_mem during this transaction to
            # make sure the Q3C index is used.
            self._setup_transaction(model, phase=2)

            # 2. Run cross-match and insert data into relational model.

            fields = [
                rel_model_sb.target_id,
                rel_model_sb.catalogid,
                rel_model_sb.version_id,
                rel_model_sb.distance,
                rel_model_sb.best,
                rel_model_sb.plan_id,
                rel_model_sb.added_by_phase,
            ]

            in_query = rel_model_sb.insert_from(
                in_query.with_cte(xmatched),
                fields,
            ).returning()

            self.log.debug(
                f"Running Q3C query and inserting cross-matched data into "
                f"relational table {rel_sb_table_name!r}: "
                f"{self._get_sql(in_query)}"
            )

            n_catalogid = in_query.execute().rowcount

    self.log.debug(
        f"Cross-matched {source._meta.table_name} with "
        f"{n_catalogid:,} targets in {table_name}. "
        f"Run in {timer.interval:.3f} s."
    )

    # Phase 2 only counts as run if it actually inserted matches.
    if n_catalogid > 0:
        self._phases_run.add(2)
        self._analyze(rel_model_sb)

    # For addenda it's not sufficient to cross-match with the temporary table, because that
    # does not contain all the cumulated targets from this cross-match version. We need to
    # also spatially cross-match with the real catalog table (but only for the targets with
    # version_id=<this-version-id>).
    if self.is_addendum and source != Catalog:
        self._run_phase_2(model, source=Catalog)


def _run_phase_3(self, model):
    """Add non-matched targets to Catalog and the relational table.

    Targets of ``model`` without a best match in the relational tables are
    assigned new catalogids (row number plus ``self._max_cid``), inserted
    into the sandboxed relational table with ``added_by_phase=3``, and
    added to the temporary catalog table with ``model`` as lead.
    """

    meta = model._meta
    xmatch = meta.xmatch

    self.log.info("Phase 3: adding non cross-matched targets.")

    rel_model_sb = self.get_relational_model(model, create=True, sandboxed=True)
    rel_sb_table_name = rel_model_sb._meta.table_name

    rel_model = self.get_relational_model(model, create=False, sandboxed=False)

    if 3 in xmatch.skip_phases:
        self.log.warning("Skipping due to configuration.")
        return

    table_name = meta.table_name

    model_fields = self._get_model_fields(model)

    model_pk = meta.primary_key
    model_ra = meta.fields[xmatch.ra_column]
    model_dec = meta.fields[xmatch.dec_column]

    unmatched = model.select(
        (fn.row_number().over() + self._max_cid).alias("catalogid"),
        model_pk.alias("target_id"),
        *model_fields,
    ).where(self._get_sample_where(model_ra, model_dec))

    # The sandboxed relational table can only contain matches for this
    # model if phase 1 or 2 have been recorded as run.
    if 1 in self._phases_run or 2 in self._phases_run:
        unmatched = unmatched.where(
            ~fn.EXISTS(
                rel_model_sb.select(SQL("1")).where(
                    rel_model_sb.version_id == self.version_id,
                    rel_model_sb.target_id == model_pk,
                    rel_model_sb.best >> True,
                )
            )
        )

    if rel_model.table_exists():
        unmatched = unmatched.where(
            ~fn.EXISTS(
                rel_model.select(SQL("1")).where(
                    rel_model.version_id == self.version_id,
                    rel_model.target_id == model_pk,
                    rel_model.best >> True,
                )
            )
        )

    if xmatch.has_missing_coordinates:
        unmatched = unmatched.where(model_ra.is_null(False), model_dec.is_null(False))

    # TODO: this is horrible and should be moved to the configuration.
    if model._meta.table_name == "tic_v8":
        unmatched = unmatched.where(model.objtype != "EXTENDED")

    if table_name == "legacy_survey_dr8":
        unmatched = unmatched.where(self._get_ls8_where(model))

    if table_name == "legacy_survey_dr10a":
        unmatched = unmatched.where(model.ref_cat != "T2")

    if table_name == "legacy_survey_dr10":
        unmatched = unmatched.where(
            model.survey_primary >> True, fn.coalesce(model.ref_cat, "") != "T2"
        )

    with Timer() as timer:
        with self.database.atomic():
            # TODO: Not sure it's worth using a temporary table here.

            # 1. Run link query and create temporary table with results.

            self._setup_transaction(model, phase=3)

            temp_model = self.get_relational_model(model, temp=True, sandboxed=True)
            temp_model_name = temp_model._meta.table_name

            self.log.debug(
                f"Selecting unique targets "
                f"into temporary table "
                f"{temp_model_name!r}{self._get_sql(unmatched)}"
            )

            unmatched.create_table(temp_model_name, temporary=True)

            # Analyze the temporary table to gather stats.
            # self.log.debug('Running ANALYZE on temporary table.')
            # self.database.execute_sql(f'ANALYZE "{temp_model_name}";')

            # 2. Copy data from temporary table to relational table. Add
            # catalogid at this point.

            # NOTE(review): this list mixes fields of temp_model and
            # rel_model_sb; presumably only the column names matter for
            # the INSERT -- confirm.
            fields = [
                temp_model.catalogid,
                temp_model.target_id,
                temp_model.version_id,
                temp_model.best,
                rel_model_sb.plan_id,
                rel_model_sb.added_by_phase,
            ]

            rel_insert_query = rel_model_sb.insert_from(
                temp_model.select(
                    temp_model.catalogid,
                    temp_model.target_id,
                    self.version_id,
                    peewee.SQL("true"),
                    self.plan if self.is_addendum else None,
                    peewee.Value(3).alias("added_by_phase"),
                ),
                fields,
            ).returning()

            self.log.debug(
                f"Copying data into relational model "
                f"{rel_sb_table_name!r}"
                f"{self._get_sql(rel_insert_query)}"
            )

            cursor = rel_insert_query.execute()
            n_rows = cursor.rowcount

            self.log.debug(
                f"Insertion into {rel_sb_table_name} completed "
                f"with {n_rows:,} rows in "
                f"{timer.elapsed:.3f} s."
            )

            # 3. Fill out the temporary catalog table with the information
            # from the unique targets.

            temp_table = peewee.Table(temp_model_name)

            fields = [
                TempCatalog.catalogid,
                TempCatalog.lead,
                TempCatalog.version_id,
            ]

            select_columns = [
                temp_table.c.catalogid,
                peewee.Value(table_name),
                self.version_id,
            ]

            # Copy only the coordinate columns that were selected for this
            # model, identified by their alias.
            for field in model_fields:
                if field._alias == "ra":
                    fields.append(TempCatalog.ra)
                    select_columns.append(temp_table.c.ra)
                elif field._alias == "dec":
                    fields.append(TempCatalog.dec)
                    select_columns.append(temp_table.c.dec)
                elif field._alias == "pmra":
                    fields.append(TempCatalog.pmra)
                    select_columns.append(temp_table.c.pmra)
                elif field._alias == "pmdec":
                    fields.append(TempCatalog.pmdec)
                    select_columns.append(temp_table.c.pmdec)
                elif field._alias == "parallax":
                    fields.append(TempCatalog.parallax)
                    select_columns.append(temp_table.c.parallax)

            insert_query = TempCatalog.insert_from(
                temp_table.select(*select_columns), fields
            ).returning()

            self.log.debug(
                f"Running INSERT query into {self._temp_table}{self._get_sql(insert_query)}"
            )

            cursor = insert_query.execute()
            n_rows = cursor.rowcount

    # Avoid having to calculate max_cid again
    self._max_cid += n_rows
    self._temp_count += n_rows

    self.log.debug(f"Inserted {n_rows:,} rows. Total time: {timer.elapsed:.3f} s.")

    self._phases_run.add(3)

    # Note that _temp_count already includes the rows just inserted.
    if n_rows > 0.5 * self._temp_count:  # Cluster if > 50% of rows are new
        self.log.debug(f"Running CLUSTER on {self._temp_table} with q3c index.")
        self.database.execute_sql(
            f"CLUSTER {self.temp_schema}.{self._temp_table} using {self._temp_table}_q3c_idx;"
        )
        self.log.debug(f"Running ANALYZE on {self._temp_table}.")
        self.database.execute_sql(f"ANALYZE {self.temp_schema}.{self._temp_table};")

    self._analyze(rel_model_sb, catalog=False)
def load_output_tables(self, model, keep_temp=False, vacuum=True):
    """Loads the temporary tables into the output tables.

    First copies ``TempCatalog`` into ``Catalog``, then the sandboxed
    relational table of ``model`` into the real relational table, which is
    created if needed. ``keep_temp`` and ``vacuum`` are passed unchanged to
    both copies.
    """

    copy_options = {"keep_temp": keep_temp, "vacuum": vacuum}

    self._load_output_table(TempCatalog, Catalog, **copy_options)

    sandboxed_rel = self.get_relational_model(model, sandboxed=True, create=False)
    output_rel = self.get_relational_model(model, sandboxed=False, create=True)
    self._load_output_table(sandboxed_rel, output_rel, **copy_options)
def _load_output_table(self, from_model, to_model, keep_temp=False, vacuum=True):
    """Copies the temporary table to the real output table.

    Parameters
    ----------
    from_model
        Model of the table to copy from.
    to_model
        Model of the table to insert into.
    keep_temp : bool
        If `False`, the source table is dropped after the copy.
    vacuum : bool
        If `True` and at least one row was inserted, runs VACUUM ANALYZE
        on the destination table.
    """

    to_table = f"{to_model._meta.schema}.{to_model._meta.table_name}"
    from_table = f"{from_model._meta.schema}.{from_model._meta.table_name}"

    self.log.info(f"Copying {from_table} to {to_table}.")

    with Timer() as timer:
        with self.database.atomic():
            self._setup_transaction()

            # Insert every column selected by default from the source model.
            source_query = from_model.select()
            insert_query = to_model.insert_from(
                source_query, source_query._returning
            ).returning()

            self.log.debug(
                f"Running INSERT query into {to_table} {self._get_sql(insert_query)}"
            )

            cursor = insert_query.execute()
            n_rows = cursor.rowcount

    self.log.debug(f"Inserted {n_rows:,} rows in {timer.elapsed:.3f} s.")

    if not keep_temp:
        self.database.drop_tables([from_model])
        self.log.info(f"Dropped temporary table {from_table}.")

    if n_rows > 0 and vacuum:
        self.log.debug(f"Running VACUUM ANALYZE on {to_table}.")
        vacuum_table(self.database, to_table, vacuum=True, analyze=True)


def _get_sql(self, query, return_string=False):
    """Returns colourised SQL text for logging.

    The query parameters are interpolated into the SQL string (string
    parameters wrapped in single quotes) and every ``None`` is shown as
    ``Null``. The result is meant for display only.

    Parameters
    ----------
    query
        An object whose ``sql()`` method returns the SQL string and its
        parameters.
    return_string : bool
        If `True`, returns the plain interpolated SQL.

    Returns
    -------
    `str`
        The plain SQL if ``return_string``; otherwise ``": <sql>"`` when
        the log uses a rich console or the ``show_sql`` option is set, and
        ``"."`` in any other case.
    """

    query_str, query_params = query.sql()

    if query_params:
        # Build a new tuple instead of editing the parameters in place, so
        # this also works if they are not a mutable list.
        display_params = tuple(
            f"'{param}'" if isinstance(param, str) else param for param in query_params
        )
        query_str = query_str % display_params

    query_str = query_str.replace("None", "Null")

    if return_string:
        return query_str
    elif self.log.rich_console:
        return f": {rich.markup.escape(query_str)}"
    elif self._options["show_sql"]:
        return f": {color_text(query_str, 'blue')}"
    else:
        return "."


def _setup_transaction(self, model=None, phase=None):
    """Sets database parameters for the transaction.

    Issues a ``SET LOCAL`` for each entry in the ``database_options``
    option, updated with the options of ``model`` if it defines any.
    ``maintenance_work_mem`` is never set here. An entry can be a plain
    value or a dictionary with a ``value`` and, optionally, the ``phase``
    values for which it applies; in the latter case it is skipped unless
    ``phase`` is one of them.
    """

    if not self._options["database_options"]:
        return

    options = self._options["database_options"].copy()

    if model and model._meta.xmatch.database_options:
        options.update(model._meta.xmatch.database_options)

    for param, param_config in options.items():
        if param == "maintenance_work_mem":
            continue

        if isinstance(param_config, dict):
            value = param_config.get("value")
            param_phases = param_config.get("phase", None)
            if param_phases and (not phase or phase not in param_phases):
                continue
        else:
            value = param_config

        stm = f"SET LOCAL {param}={value!r};"
        self.database.execute_sql(stm)
        self.log.debug(stm)


def _get_sample_where(self, ra_field, dec_field):
    """Returns the list of conditions to sample a model.

    Returns `True` if no ``sample_region`` is configured. The region can be
    a single ``(ra, dec, radius)`` or a sequence of them; in the latter
    case the radial queries are combined with OR.
    """

    sample_region = self._options["sample_region"]

    if sample_region is None:
        return True

    # A flat three-element region is a single (ra, dec, radius).
    if len(sample_region) == 3 and not isinstance(sample_region[0], (list, tuple)):
        ra, dec, radius = sample_region[0], sample_region[1], sample_region[2]
        return fn.q3c_radial_query(ra_field, dec_field, ra, dec, radius)

    sample_conds = peewee.SQL("false")

    for ra, dec, radius in sample_region:
        sample_conds |= fn.q3c_radial_query(ra_field, dec_field, ra, dec, radius)

    return sample_conds


def _get_ls8_where(self, model):
    """Removes duplicates from LS8 queries.

    Returns a condition that rejects rows of release 8000 with
    ``dec > 32.375`` and ``100 < ra < 300``, and rows of release 8001 with
    ``dec < 32.375``.
    """

    return ~(
        (
            (model.release == 8000)
            & (model.dec > 32.375)
            & (model.ra > 100.0)
            & (model.ra < 300.0)
        )
        | ((model.release == 8001) & (model.dec < 32.375))
    )
def show_join_paths(self):
    """Prints all the available join paths.

    Useful before calling `.run` to check that the join paths that will be
    used are correct, or to adjust the table weights otherwise. Paths that
    start at the first model to be processed are not shown, and nothing is
    printed for a plan with a single table.

    """

    order = self.process_order
    if len(order) == 1:
        return

    mode = self.path_mode

    for table in order[1:]:
        for path in self.get_join_paths(table, mode=mode) or ():
            print(path)
def _analyze(self, rel_model, vacuum=False, catalog=False):
    """Analyses a relational model after insertion.

    Parameters
    ----------
    rel_model
        The relational model whose table is analysed.
    vacuum : bool
        Whether to also vacuum the tables.
    catalog : bool
        If `True`, the temporary catalog table is processed as well.
    """

    schema = rel_model._meta.schema
    table_name = rel_model._meta.table_name

    # maintenance_work_mem is applied here with a plain SET before the
    # ANALYZE, if configured.
    db_opts = self._options["database_options"]
    if db_opts:
        work_mem = db_opts.get("maintenance_work_mem", None)
        if work_mem:
            self.database.execute_sql(f"SET maintenance_work_mem = {work_mem!r}")

    self.log.debug(f"Running ANALYZE on {table_name}.")
    vacuum_table(self.database, f"{schema}.{table_name}", vacuum=vacuum, analyze=True)

    if catalog:
        # The temporary catalog table lives in the temporary schema.
        self.log.debug(f"Running ANALYZE on {self._temp_table}.")
        vacuum_table(
            self.database,
            f"{self.temp_schema}.{self._temp_table}",
            vacuum=vacuum,
            analyze=True,
        )


def _log_table_configuration(self, model):
    """Logs the configuration used to cross-match a table.

    Each cross-matching parameter of the model is logged at debug level.
    Missing parameters are logged as ``none``, booleans in lower case, and
    ``row_count`` with thousands separators when it is a number.
    """

    xmatch = model._meta.xmatch

    parameters = [
        "ra_column",
        "dec_column",
        "pmra_column",
        "pmdec_column",
        "is_pmra_cos",
        "parallax_column",
        "epoch",
        "epoch_column",
        "epoch_format",
        "has_duplicates",
        "skip",
        "skip_phases",
        "query_radius",
        "row_count",
        "resolution",
        "join_weight",
        "has_missing_coordinates",
    ]

    self.log.debug("Table cross-matching parameters:")

    for parameter in parameters:
        value = getattr(xmatch, parameter, None)

        if parameter == "query_radius":
            value = value or QUERY_RADIUS

        if value is True or value is False or value is None:
            value = str(value).lower()

        text = f"{value}"
        if parameter == "row_count":
            # An unset row count has just become the string "none", which
            # does not accept the "," format; keep the plain text then.
            try:
                text = f"{value:,}"
            except (TypeError, ValueError):
                pass

        self.log.debug(f"{parameter}: {text}")