"""
Code for argparse-related functionality.
"""
import _io
import argparse
import inspect
import json
import logging
import os
import pgtoolkit.pgpass
import psycopg2.extensions
import pyspark.sql
from collections.abc import Iterable
from functools import partial
from typing import Optional
from . import (
utils,
database_functions as udbf,
logging_functions as ulf,
spark_functions as usf,
)
# walk up the call stack to the __main__ module so its docstring can serve as
# the default parser description
current_frame = inspect.currentframe()
while current_frame.f_globals["__name__"] != "__main__":
current_frame = current_frame.f_back
class ZFileType(argparse.FileType):
"""
    :class:`argparse.FileType` that opens the specified argument using
    :func:`~actio_python_utils.utils.zopen`
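
    For example::

        parser.add_argument("input", type=ZFileType())

    opens the ``input`` argument with
    :func:`~actio_python_utils.utils.zopen`, while ``-`` still maps to
    standard input.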
"""
def __call__(self, string: str) -> _io.TextIOWrapper:
"""
:param string: The argument to pass to
:func:`~actio_python_utils.utils.zopen` for opening
        :return: A file handle for ``string``, opened appropriately
"""
if string == "-":
return super().__call__(string)
else:
return utils.zopen(string, self._mode)
def key_value_pair(arg: str, sep: str = "=") -> tuple[str, str]:
"""
    Splits ``arg`` once on ``sep`` and returns the resulting pair

    :param arg: The string to split
    :param sep: The separator to split the string on
    :raises ValueError: If ``sep`` does not occur in ``arg``
    :return: The two components of ``arg`` split on ``sep`` once
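
    For example::

        >>> key_value_pair("spark.driver.memory=4g")
        ('spark.driver.memory', '4g')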
"""
if "=" in arg:
return arg.split("=", 1)
raise ValueError("Argument must be formatted as KEY=VALUE.")
def file_exists(fn: str) -> str:
"""
    Returns the real path to file ``fn`` if it exists

:param fn: The file name to check
:raises OSError: If ``fn`` doesn't exist
:return: The real path to ``fn``
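
    For example, as an argparse ``type``::

        parser.add_argument("-f", "--input-file", type=file_exists)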
"""
if os.path.isfile(fn):
return os.path.realpath(fn)
else:
raise OSError(f"{fn} does not exist.")
def dir_exists(dirn: str) -> str:
"""
    Returns the real path to directory ``dirn`` if it exists

:param dirn: The directory name to check
:raises OSError: If ``dirn`` doesn't exist
:return: The real path to ``dirn``
"""
if os.path.isdir(dirn):
return os.path.realpath(dirn)
else:
raise OSError(f"{dirn} does not exist.")
def str_from_file(fn: str) -> str:
"""
    Returns the text content of file ``fn``

    :param fn: The file name to read
    :raises OSError: If ``fn`` doesn't exist
    :return: The string content of ``fn``
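
    For example, to read an argument's value from a file (hypothetical
    usage)::

        parser.add_argument("--query", type=str_from_file)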
"""
if os.path.isfile(fn):
with utils.zopen(fn) as fh:
return fh.read()
else:
raise OSError(f"{fn} does not exist.")
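

# Assumed definition: ``CustomFormatter`` is referenced below, and the
# :class:`EnhancedArgumentParser` docstring states that both of these
# formatters are used, so a minimal combination is sketched here.
class CustomFormatter(
    argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter
):
    """
    Help formatter that appends argument defaults to help messages while
    preserving raw text formatting.
    """
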
class EnhancedArgumentParser(argparse.ArgumentParser):
r"""
    Customized :class:`argparse.ArgumentParser` that sets the description
    automatically, combines the
    :class:`argparse.ArgumentDefaultsHelpFormatter` and
    :class:`argparse.RawTextHelpFormatter` formatters, and optionally sets up
    logging, database, and PySpark connections.

    :param \*args: Optional positional arguments passed to the
        :class:`argparse.ArgumentParser` constructor
    :param description: Passed to the :class:`argparse.ArgumentParser`
        constructor
    :param formatter_class: The help formatter to use
    :param use_logging: Adds log level and log format arguments, then sets up
        logging when :meth:`parse_args` is called
    :param use_database: Adds a database service argument, then creates a
        connection to the specified database with the attribute name ``db``
        when :meth:`parse_args` is called
    :param use_spark: Adds spark cores, spark memory, and spark config
        arguments, then creates a PySpark session with the attribute name
        ``spark`` when :meth:`parse_args` is called
    :param use_xml: Adds dependencies to PySpark to parse XML files; sets
        ``use_spark = True``
    :param use_glow: Adds dependencies to PySpark to use glow, e.g. to parse
        VCF files; sets ``use_spark = True``
    :param use_spark_db: Adds dependencies to PySpark to connect to a
        database; sets ``use_spark = True`` and creates an object to create a
        database connection with PySpark with the attribute name ``spark_db``
        when :meth:`parse_args` is called
    :param dont_create_db_connection: Don't create a database connection
        even if ``use_database = True``
    :param spark_extra_packages: Adds additional Spark package dependencies
        to initialize; sets ``use_spark = True``
    :param \**kwargs: Any additional named arguments
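
    Example (a minimal sketch, assuming a configured PostgreSQL service and
    local Spark installation)::

        parser = EnhancedArgumentParser(use_logging=True, use_database=True)
        parser.add_argument("-i", "--input", type=ZFileType())
        args = parser.parse_args()
        # args.db is an open psycopg2 connection; args.input is a file handle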
"""
def __init__(
self,
*args,
        description: str = current_frame.f_globals.get("__doc__") or "",
        formatter_class: type[argparse.HelpFormatter] = CustomFormatter,
use_logging: bool = False,
use_database: bool = False,
use_spark: bool = False,
use_xml: bool = False,
use_glow: bool = False,
use_spark_db: bool = False,
dont_create_db_connection: bool = False,
spark_extra_packages: Optional[Iterable[tuple[str, str]]] = None,
**kwargs,
) -> None:
super().__init__(
description=description, formatter_class=formatter_class, **kwargs
)
self.use_logging = use_logging
self.use_database = use_database
self.use_spark = any(
[use_spark, use_xml, use_glow, use_spark_db, spark_extra_packages]
)
self.use_xml = use_xml
self.use_glow = use_glow
self.use_spark_db = use_spark_db
self.dont_create_db_connection = dont_create_db_connection
self.spark_extra_packages = spark_extra_packages
if self.use_logging:
self.add_log_level_argument()
self.add_log_format_argument()
if self.use_database or self.use_spark_db:
self.add_db_service_argument()
if self.use_spark:
self.add_spark_cores_argument()
self.add_spark_memory_argument()
self.add_spark_config_argument()
@staticmethod
def sanitize_argument(long_arg: str) -> str:
"""
        Converts the argument name to the variable name actually used

:param long_arg: The argument name
:return: The reformatted argument
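
        For example::

            >>> EnhancedArgumentParser.sanitize_argument("--log-level")
            'log_level'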
"""
return long_arg.lstrip("-").replace("-", "_")
def add_argument(
self,
short_arg: Optional[str] = None,
long_arg: Optional[str] = None,
*args,
**kwargs,
) -> None:
r"""
        Adds an argument, using the sanitized argument name as ``metavar`` so
        the help message displays it rather than the uppercased ``dest``

:param short_arg: The short argument name
:param long_arg: The long argument name
:param \*args: Any additional positional arguments
:param \**kwargs: Any additional named arguments
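
        For example, ``parser.add_argument("-i", "--input-file")`` renders as
        ``-i input_file, --input-file input_file`` in the help message instead
        of ``-i INPUT_FILE, --input-file INPUT_FILE``.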
"""
call = partial(
super().add_argument,
*[arg for arg in [short_arg, long_arg] if arg],
*args,
**kwargs,
)
        # these actions take no value, so they do not accept a metavar
        if kwargs.get("action") in (
            "help", "version", "store_true", "store_false", "count"
        ):
call()
else:
call(
metavar=EnhancedArgumentParser.sanitize_argument(
utils.coalesce(long_arg, short_arg, "")
)
)
def parse_args(
self,
*args,
db_connection_name: str = "db",
spark_name: str = "spark",
spark_db_name: str = "spark_db",
**kwargs,
) -> argparse.Namespace:
r"""
Parses arguments while optionally setting up logging, database, and/or
        PySpark.

:param \*args: Any additional positional arguments
:param db_connection_name: The ``args`` attribute name to give to a
created database connection
:param spark_name: The ``args`` attribute name to give to a created
PySpark session
:param spark_db_name: The ``args`` attribute name to give to PostgreSQL
login credentials for use with PySpark
:param \**kwargs: Any additional named arguments
        :return: Parsed arguments, additionally with a database connection
            attribute (named by ``db_connection_name``) if
            ``use_database = True``, a PySpark session attribute (named by
            ``spark_name``) if ``use_spark = True``, and a login credentials
            attribute (named by ``spark_db_name``) if ``use_spark_db = True``
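
        For example::

            args = parser.parse_args(db_connection_name="warehouse_db")
            # the database connection is then available as args.warehouse_db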
"""
args = super().parse_args(*args, **kwargs)
if self.use_logging:
self.setup_logging(args)
if self.use_database and not self.dont_create_db_connection:
setattr(args, db_connection_name, self.setup_database(args))
if self.use_spark:
spark, pgpass_record = self.setup_spark(args)
setattr(args, spark_name, spark)
if pgpass_record:
setattr(args, spark_db_name, pgpass_record)
return args
def add_log_level_argument(
self,
short_arg: Optional[str] = "-l",
long_arg: Optional[str] = "--log-level",
default: str = utils.cfg["logging"]["level"],
**kwargs,
) -> None:
r"""
        Adds an argument to set the logging level, converts it to the proper
        integer, and sets ``dest = "log_level"``

:param short_arg: Short argument name to use
:param long_arg: Long argument name to use
:param default: Default logging level value
:param \**kwargs: Any additional named arguments
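
        For example, passing ``-l DEBUG`` on the command line yields
        ``args.log_level == logging.DEBUG``.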
"""
self.add_argument(
short_arg=short_arg,
long_arg=long_arg,
type=utils.DictToFunc(logging._nameToLevel),
default=default,
dest="log_level",
help=f"the logging level to use (choices are: {{{','.join(logging._nameToLevel)}}})",
**kwargs,
)
def add_db_service_argument(
self,
short_arg: Optional[str] = "-s",
long_arg: Optional[str] = "--service",
default: Optional[str] = None,
**kwargs,
    ) -> None:
        r"""
        Adds an argument to set the database service name and sets
        ``dest = "db_service"``

:param short_arg: Short argument name to use
:param long_arg: Long argument name to use
:param default: Default service
:param \**kwargs: Any additional named arguments
"""
self.add_argument(
short_arg=short_arg,
long_arg=long_arg,
default=default,
dest="db_service",
help="PostgreSQL service name to log in with",
**kwargs,
)
def add_spark_cores_argument(
self,
short_arg: Optional[str] = "-c",
long_arg: Optional[str] = "--spark-cores",
default: int | str = utils.cfg["spark"]["cores"],
**kwargs,
) -> None:
r"""
        Adds an argument to set the number of PySpark cores to use and sets
        ``dest = "spark_cores"``

:param short_arg: Short argument name to use
:param long_arg: Long argument name to use
:param default: Default cores
:param \**kwargs: Any additional named arguments
"""
self.add_argument(
short_arg=short_arg,
long_arg=long_arg,
default=default,
dest="spark_cores",
help="the number of cores to provide to Spark",
**kwargs,
)
def add_spark_memory_argument(
self,
short_arg: Optional[str] = "-m",
long_arg: Optional[str] = "--spark-memory",
default: str = utils.cfg["spark"]["memory"],
**kwargs,
) -> None:
r"""
        Adds an argument to set the amount of memory to give to PySpark and
        sets ``dest = "spark_memory"``

:param short_arg: Short argument name to use
:param long_arg: Long argument name to use
:param default: Default memory to use
:param \**kwargs: Any additional named arguments
"""
self.add_argument(
short_arg=short_arg,
long_arg=long_arg,
default=default,
dest="spark_memory",
help="the amount of memory to provide to Spark",
**kwargs,
)
def add_spark_config_argument(
self,
short_arg: Optional[str] = None,
long_arg: Optional[str] = "--spark-config",
**kwargs,
) -> None:
r"""
        Adds an argument to provide 0 or more options to initialize the
        PySpark session with and sets ``dest = "spark_config"``

:param short_arg: Short argument name to use
:param long_arg: Long argument name to use
:param \**kwargs: Any additional named arguments
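
        For example, passing ``--spark-config spark.sql.shuffle.partitions=200``
        appends ``("spark.sql.shuffle.partitions", "200")`` to
        ``args.spark_config``.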
"""
self.add_argument(
short_arg=short_arg,
long_arg=long_arg,
type=key_value_pair,
dest="spark_config",
action="append",
help="any additional config options to pass to spark (format is KEY=VALUE)",
**kwargs,
)
def add_spark_load_config_argument(
self,
short_arg: Optional[str] = None,
long_arg: Optional[str] = "--spark-load-config",
**kwargs,
) -> None:
r"""
        Adds an argument to provide 0 or more options to load a dataframe
        in PySpark with and sets ``dest = "spark_load_config"``

:param short_arg: Short argument name to use
:param long_arg: Long argument name to use
:param \**kwargs: Any additional named arguments
"""
self.add_argument(
short_arg=short_arg,
long_arg=long_arg,
type=key_value_pair,
dest="spark_load_config",
action="append",
help="any options required to load the the data (format is KEY=VALUE)",
**kwargs,
)
def setup_logging(
self,
args: argparse.Namespace,
name: str = "root",
stream: Optional[_io.TextIOWrapper] = None,
stream_handler_logging_level: Optional[str | int] = None,
) -> None:
"""
        Sets up logging with
        :func:`~actio_python_utils.logging_functions.setup_logging` and the
        specified log level and format

:param args: Parsed arguments from :meth:`parse_args`
:param name: Logger name to initialize
:param stream: Stream to log to
:param stream_handler_logging_level: Logging level to use for stream
"""
ulf.setup_logging(
logging_level=(
args.log_level if "log_level" in args else utils.cfg["logging"]["level"]
),
name=name,
stream=stream,
stream_handler_logging_level=stream_handler_logging_level,
format_string=(
args.log_format
if "log_format" in args
else utils.cfg["logging"]["format"]
),
)
def setup_database(
self, args: argparse.Namespace
) -> psycopg2.extensions.connection:
"""
        Returns a psycopg2 connection to the database specified in
        ``args.db_service``

:param args: Parsed arguments from :meth:`parse_args`
:return: The psycopg2 connection
"""
return udbf.connect_to_db(args.db_service if "db_service" in args else None)
def setup_spark(
self, args: argparse.Namespace
    ) -> tuple[
        pyspark.sql.session.SparkSession, Optional[pgtoolkit.pgpass.PassEntry]
    ]:
        """
        Returns a tuple with a created PySpark session and, if
        ``use_spark_db = True``, a PostgreSQL login record (``None`` otherwise)

        :param args: Parsed arguments from :meth:`parse_args`
        :return: A tuple with the created PySpark session and either a
            :class:`pgtoolkit.pgpass.PassEntry` record or ``None``
        """
        spark = usf.setup_spark(
            cores=(
                args.spark_cores
                if "spark_cores" in args
                else utils.cfg["spark"]["cores"]
            ),
            memory=(
                args.spark_memory
                if "spark_memory" in args
                else utils.cfg["spark"]["memory"]
            ),
            use_xml=self.use_xml,
            use_glow=self.use_glow,
            use_db=self.use_spark_db,
            extra_options=args.spark_config if "spark_config" in args else None,
            extra_packages=self.spark_extra_packages,
        )
        # only look up PostgreSQL credentials when a Spark database
        # connection was requested
        pgpass_record = (
            udbf.get_pg_config(args.db_service if "db_service" in args else None)
            if self.use_spark_db
            else None
        )
        return spark, pgpass_record