from __future__ import annotations
import dataclasses
# from .persisted_query import PersistedQuery
import itertools
from abc import ABCMeta
from collections.abc import Collection, Hashable, Iterable, Mapping
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
NamedTuple,
Protocol,
TypedDict,
TypeVar,
cast,
)
from ..input.strparsing import InstanceTracking
from ..utils import unique
from .dialect import SQLDialect
if TYPE_CHECKING:
import pandas as pd # pyright: ignore[reportMissingTypeStubs]
import polars as pl
from typing_extensions import Self
# import public interface so we can avoid internal ._.... appearing in
# function signatures, doco, etc.
import csql
import csql.dialect
import csql.overrides
import csql.persist
from .overrides import Overrides
# A single (non-collection) parameter value; the only requirement is hashability.
ScalarParameterValue = Hashable
# Flat, ordered tuple of scalar values; this is the type of RenderedQuery.parameters.
ParameterList = tuple[ScalarParameterValue, ...]
[docs]
class RenderedQuery(NamedTuple):
    """
    The result of :meth:`Query.build`: rendered SQL text together with the
    parameter values (and the names they were supplied under), ready to be
    handed to a database driver.

    The properties below repackage the same data into the argument shapes
    used by a few common libraries, so the result can be splatted directly
    into their calls.
    """

    sql: str
    """ The rendered SQL, ready to be passed to a database. """
    parameters: ParameterList
    """ A tuple of parameters, to go along with the SQL. """
    parameter_names: tuple[str | None, ...]
    """ A tuple of parameter names that the parameters were passed as. """

    def __repr__(self) -> str:
        # parameter_names is deliberately not part of the repr.
        return f"RenderedQuery({self.sql!r}, {self.parameters!r})"

    # utility properties for easy splatting

    @property
    def db(self) -> tuple[str, ParameterList]:
        """
        Returns a tuple of (sql, params), for usage like:

        >>> con = my_connection()
        >>> q = Q('select 123')
        >>> con.cursor().execute(*q.build().db) # doctest: +IGNORE_RESULT
        """
        sql_text, values, _names = self
        return sql_text, values

    @property
    def pd(self) -> dict[str, Any]:
        """
        Gives dict of ``{'sql':sql, 'params':params}``, for usage like:

        >>> con = my_connection()
        >>> q = Q('select 123')
        >>> pd.read_sql(**q.build().pd, con=con) # doctest: +IGNORE_RESULT
        """
        sql_text, values, _names = self
        return dict(sql=sql_text, params=values)

    @property
    def params_dict(self) -> dict[str, Hashable]:
        """
        Gives the parameters as a ``{name: value}`` dict, pairing
        ``parameter_names`` with ``parameters`` positionally. Entries whose
        name is ``None`` are left out.
        """
        named: dict[str, Hashable] = {}
        for name, value in zip(self.parameter_names, self.parameters):
            if name is None:
                continue
            named[name] = value
        return named

    @property
    def ch(self) -> ClickhouseQueryArgs:
        """
        Gives dict of ``{'query':sql, 'parameters':params_dict}``, where the
        parameters are the named ones from :attr:`params_dict`.
        """
        named = self.params_dict
        return {"query": self.sql, "parameters": named}

    @property
    def ddb(self) -> DuckDBQueryArgs:
        """
        Gives dict of ``{'query':sql, 'params':params}``, with the parameters
        as the positional tuple.
        """
        sql_text, values, _names = self
        return {"query": sql_text, "params": values}

    @property
    def pl(self) -> PolarsQueryArgs:
        """
        Gives dict of ``{'query':sql, 'execute_options':{'parameters':params}}``,
        with the parameters as the positional tuple.
        """
        options = {"parameters": self.parameters}
        return {"query": self.sql, "execute_options": options}
[docs]
class ClickhouseQueryArgs(TypedDict):
    """
    Shape of the dict returned by :attr:`RenderedQuery.ch`: the SQL text plus
    the parameters as a ``{name: value}`` dict.
    """

    query: str
    parameters: dict[str, Hashable]
[docs]
class DuckDBQueryArgs(TypedDict):
    """
    Shape of the dict returned by :attr:`RenderedQuery.ddb`: the SQL text plus
    the parameters as a positional tuple.
    """

    query: str
    params: ParameterList
[docs]
class PolarsQueryArgs(TypedDict):
    """
    Shape of the dict returned by :attr:`RenderedQuery.pl`: the SQL text plus
    an ``execute_options`` dict (which :attr:`RenderedQuery.pl` fills with a
    single ``'parameters'`` entry).
    """

    query: str
    execute_options: dict[str, Any]
[docs]
class QueryBit(metaclass=ABCMeta):
    """
    Marker base class for the non-string pieces a query is assembled from.

    It defines no behaviour of its own; :class:`Query` and
    :class:`ParameterPlaceholder` both derive from it, and
    ``Query.queryParts`` is a tuple of ``str | QueryBit``.
    """
[docs]
class QueryExtension(metaclass=ABCMeta):
    """
    Marker base class for extra data attached to a :class:`Query`.

    It defines no behaviour of its own; instances are kept in
    ``Query._extensions`` and looked up by their concrete type.
    """
# Type variable used by Query._get_extension so that asking for extension
# type T is typed as returning a T (or None).
QE = TypeVar(
    "QE", bound=QueryExtension
)  # should be bound=Intersection[QueryExtension, NamedTuple]
class PreBuildHook(Protocol):
    """
    Structural type for the callable stored on :class:`PreBuild`: anything
    callable with no arguments that returns a :class:`Query` or ``None``.
    """

    def __call__(self) -> Query | None:
        """Run the hook, giving back a :class:`Query` or ``None``."""
@dataclass(frozen=True)
class PreBuild(QueryExtension):
    """
    A :class:`QueryExtension` that carries a :class:`PreBuildHook`.

    NOTE(review): presumably consumed by ``pre_build_replacer`` during
    :meth:`Query.build` - confirm against ``query_replacers``.
    """

    # The zero-argument callable attached to the query.
    hook: PreBuildHook
# Sentinel default for the positional argument of Parameters.add; add() compares
# against it by identity to decide whether a value was passed.
NOVALUE = "_csql_novalue"
[docs]
@dataclass(frozen=True)
class Query(QueryBit, InstanceTracking):
    """
    A Query is CSQL's structured concept of a SQL query. You should not create these directly,
    instead you should use :func:`csql.Q`.
    """

    queryParts: tuple[str | QueryBit, ...]
    ":meta private:"
    default_dialect: csql.dialect.SQLDialect | csql.dialect.InferOrDefault
    ":meta private:"
    default_overrides: Overrides | csql.overrides.InferOrDefault | None
    ":meta private:"
    _extensions: frozenset[QueryExtension]
    ":meta private:"

    ## deps

    def _getDeps_(self) -> Iterable[Query]:
        """
        Yield every :class:`Query` nested in ``queryParts``, recursively.

        A dependency's own dependencies are yielded before the dependency
        itself. The same query may be yielded more than once; see
        :meth:`_getDeps` for the de-duplicated view.
        """
        queryDeps = (part for part in self.queryParts if isinstance(part, Query))
        for dep in queryDeps:
            yield from dep._getDeps_()
            yield dep

    def _getDeps(self) -> Iterable[Query]:
        """Return the dependencies from :meth:`_getDeps_`, passed through ``unique`` keyed on ``id``."""
        return unique(self._getDeps_(), fn=id)

    ## extensions

    def _get_extension(self, t: type[QE]) -> QE | None:
        """Return the attached extension whose exact type is ``t``, or ``None`` if there is none."""
        exts = {type(e): e for e in self._extensions}  # could memoize this
        return cast(QE, exts.get(t))  # mypy sucks

    def _add_extensions(self, *e: QueryExtension) -> Query:
        """Return a copy of this query with the given extensions added to ``_extensions``."""
        return dataclasses.replace(self, _extensions=self._extensions | set(e))

    def _default_dialect(self) -> SQLDialect:
        """Return ``default_dialect``, unwrapping it first if it is an ``InferOrDefault``."""
        from .dialect import InferOrDefault

        d = self.default_dialect
        return d.dialect if isinstance(d, InferOrDefault) else d

    def _default_overrides(self) -> Overrides | None:
        """Return ``default_overrides``, unwrapping it first if it is an ``InferOrDefault``."""
        from .overrides import InferOrDefault

        o = self.default_overrides
        return o.overrides if isinstance(o, InferOrDefault) else o

    def preview_pd(
        self,
        con: Any,
        rows: int | None = 10,
        dialect: csql.dialect.SQLDialect | None = None,
        newParams: Mapping[str, ParameterValue] | None = None,
        overrides: csql.overrides.Overrides | None = None,
    ) -> pd.DataFrame:
        """
        Return a small dataframe to preview the results of this query.

        Usage:

        >>> c = my_connection()
        >>> q = Q(f'''select 123 as val''')
        >>> print(q.preview_pd(c))
           val
        0  123

        :param con: A DBAPI-compliant connection, passed directly to ``con`` arg of :func:`pandas.read_sql`.
        :param rows: The number of rows to pull.
        :param dialect: Passed on to :meth:`build` (and used when limiting the query).
        :param newParams: Passed on to :meth:`build`.
        :param overrides: Passed on to :meth:`build`.
        :rtype: :class:`pandas.DataFrame`
        """
        import pandas as pd  # pyright: ignore[reportMissingTypeStubs]

        from ..utils import limit_query

        previewQ = limit_query(self, rows, dialect)
        return pd.read_sql(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            **previewQ.build(
                dialect=dialect, newParams=newParams, overrides=overrides
            ).pd,
            con=con,
        )

    def preview_pl(
        self,
        con: Any,
        rows: int | None = 10,
        dialect: csql.dialect.SQLDialect | None = None,
        newParams: Mapping[str, ParameterValue] | None = None,
        overrides: csql.overrides.Overrides | None = None,
    ) -> pl.DataFrame:
        """
        Return a small polars DataFrame to preview the results of this query.

        ``clickhouse_connect`` clients and ``duckdb`` connections are queried
        through their own APIs (when those packages are importable); any other
        connection is handed to :func:`polars.read_database`.

        :param con: The connection to run the preview against.
        :param rows: The number of rows to pull.
        :param dialect: Passed on to :meth:`build` (and used when limiting the query).
        :param newParams: Passed on to :meth:`build`.
        :param overrides: Passed on to :meth:`build`.
        :rtype: :class:`polars.DataFrame`
        """
        import polars as pl

        from ..utils import limit_query

        previewQ = limit_query(self, rows, dialect)
        preview = previewQ.build(
            dialect=dialect, newParams=newParams, overrides=overrides
        )

        # Both driver packages are optional: only special-case a connection
        # type when its package can actually be imported.
        try:
            import clickhouse_connect  # pyright: ignore[reportMissingTypeStubs]
            import clickhouse_connect.driver  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            clickhouse_connect = None
        if clickhouse_connect is not None and isinstance(
            con, clickhouse_connect.driver.Client
        ):
            return con.query_df_arrow(**preview.ch, dataframe_library="polars")  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] # mypy: ignore[no-any-return]

        try:
            import duckdb
        except ImportError:
            duckdb = None  # mypy: ignore[assignment]
        if duckdb is not None and isinstance(con, duckdb.DuckDBPyConnection):
            return con.query(**preview.ddb).pl()

        return pl.read_database(  # pyright: ignore[reportUnknownMemberType]
            preview.sql, con, execute_options={"parameters": preview.parameters}
        )

    def build(
        self,
        *,
        dialect: csql.dialect.SQLDialect | None = None,
        newParams: Mapping[str, ParameterValue] | None = None,
        overrides: csql.overrides.Overrides | None = None,
    ) -> csql.RenderedQuery:
        """
        Build this :class:`csql.Query` into a :class:`csql.RenderedQuery`.

        While you can specify parameters to manually override how this Query is rendered, it's normally
        better to just supply these as defaults when you create your Queries in the first place. See: :ref:`sql-dialects`.

        :param dialect: An optional :class:`csql.dialect.SQLDialect` to render as. See :ref:`sql-dialects`.
        :param newParams: A dictionary of ``{'key': value}`` to override any parameters. See: :ref:`reparam`.
        :param overrides: An optional :class:`csql.overrides.Overrides` to override how rendering works. See: :ref:`overrides`.
        :raises TypeError: If the parameter renderer or query renderer selected
            via ``overrides`` is not a subclass of the expected renderer base class.
        """
        dialect = dialect or self._default_dialect()

        from ..persist import cache_replacer
        from ..renderer.parameters import ParameterRenderer
        from ..renderer.query import BoringSQLRenderer, QueryRenderer
        from .overrides import Overrides
        from .query_replacers import (
            params_replacer,
            pre_build_replacer,
            replace_queries_in_tree,
        )

        # Explicit argument wins, then this query's default, then a blank Overrides.
        overrides = overrides or self._default_overrides() or Overrides()

        ParamRenderer = (
            overrides.paramRenderer
            if overrides.paramRenderer is not None
            else ParameterRenderer.get(dialect)
        )
        if not issubclass(ParamRenderer, ParameterRenderer):  # pyright: ignore[reportUnnecessaryIsInstance]
            raise TypeError(
                f"{ParamRenderer} needs to be a subclass of csql.ParameterRenderer"
            )

        QR: type[QueryRenderer] = (
            overrides.queryRenderer
            if overrides.queryRenderer is not None
            else BoringSQLRenderer
        )
        if not issubclass(QR, QueryRenderer):  # pyright: ignore[reportUnnecessaryIsInstance]
            # Report the offending class (QR), not the base class it failed to extend.
            raise TypeError(f"{QR} needs to be a subclass of csql.SQLRenderer")

        # This first renderer is only handed to the cache replacer.
        queryRenderer = QR(ParamRenderer, dialect=dialect)
        new_self = self
        new_self = replace_queries_in_tree(params_replacer(newParams), new_self)
        new_self = replace_queries_in_tree(cache_replacer(queryRenderer), new_self)
        new_self = replace_queries_in_tree(pre_build_replacer(), new_self)

        # NOTE(review): a second renderer is built for the final render rather
        # than reusing the one above - presumably so the final render starts
        # from fresh renderer state; confirm before merging the two.
        queryRenderer = QR(ParamRenderer, dialect=dialect)
        return queryRenderer.render(new_self)

    @property
    def pd(self) -> dict[str, Any]:
        """
        Convenience wrapper for :meth:`Query.build().pd<RenderedQuery.pd>`.

        Returns a dict of ``{'sql':sql, 'params':params}``, for usage like:

        >>> import pandas as pd
        >>> con = my_connection()
        >>> q = Q('select 123')
        >>> pd.read_sql(**q.pd, con=con) # doctest: +IGNORE_RESULT
        """
        return self.build().pd

    @property
    def db(self) -> tuple[str, ParameterList]:
        """
        Convenience wrapper for :meth:`Query.build().db<RenderedQuery.db>`.

        Returns a tuple of (sql, params), for usage like:

        >>> con = my_connection()
        >>> q = Q('select 123')
        >>> con.cursor().execute(*q.db) # doctest: +IGNORE_RESULT
        """
        return self.build().db

    @property
    def pl(self) -> PolarsQueryArgs:
        """
        Convenience wrapper for :meth:`Query.build().pl<RenderedQuery.pl>`.

        Returns a dict of ``{'query':query, 'execute_options':{'parameters':params}}``,
        for usage like:

        >>> import polars as pl
        >>> con = my_connection()
        >>> q = Q('select 123')
        >>> pl.read_database(**q.pl, connection=con) # doctest: +IGNORE_RESULT
        """
        return self.build().pl

    @property
    def ddb(self) -> DuckDBQueryArgs:
        """
        Convenience wrapper for :meth:`Query.build().ddb<RenderedQuery.ddb>`.

        Returns a dict of ``{'query':query, 'params':params}``,
        for usage like:

        >>> import duckdb
        >>> con = duckdb.connect()
        >>> q = Q('select 123')
        >>> con.query(**q.ddb).pl() # doctest: +IGNORE_RESULT
        """
        return self.build().ddb

    @property
    def ch(self) -> ClickhouseQueryArgs:
        """
        Convenience wrapper for :meth:`Query.build().ch<RenderedQuery.ch>`.

        Returns a dict of ``{'query':query, 'parameters':params}``,
        for usage like:

        >>> # doctest: +SKIP - needs a running clickhouse server
        >>> import clickhouse_connect
        >>> ch = clickhouse_connect.create_client(password='asdf')
        >>> q = Q('select 123')
        >>> ch.query_df_arrow(**q.ch, dataframe_library='polars') # doctest: +IGNORE_RESULT
        """
        return self.build().ch

    def persist(
        self, cacher: csql.persist.Cacher, tag: str | None = None
    ) -> csql.Query:
        """
        Marks this query for persistence with the given :class:`csql.persist.Cacher`.

        See: :ref:`persist`

        Usage:

        >>> con = some_connection()
        >>> cache = csql.contrib.persist.TempTableCacher(con)
        >>> q = Q(f'select 123 from something_slow').persist(cache)
        >>> q.preview_pd(con) # slow # doctest: +IGNORE_RESULT
        >>> q.preview_pd(con) # fast # doctest: +IGNORE_RESULT
        >>> q2 = Q(f'select count(*) from {q}')
        >>> q2.preview_pd(con) # also fast # doctest: +IGNORE_RESULT

        :param cacher: The cacher to delegate to; its ``persist(query, tag)`` result is returned.
        :param tag: Optional tag, passed through to the cacher unchanged.
        """
        return cacher.persist(self, tag)
# What callers may supply as a parameter: a single hashable value, or a
# collection of hashable values.
ParameterValue = Hashable | Collection[Hashable]
[docs]
@dataclass(frozen=True)
class ParameterPlaceholder(QueryBit, InstanceTracking):
    """
    A ParameterPlaceholder is what you get when you get an individual parameter by
    name from a :class:`Parameters` object, like ``p['param_you_want']``. The only thing
    you should need to do with it is interpolate it into a query:

    >>> p = Parameters(param_you_want=123)
    >>> q = Q(f'select {p["param_you_want"]}')
    >>> q.db
    ('select :1', (123,))
    """

    key: str | AutoKey
    ":meta private:"
    value: csql.ParameterValue
    ":meta private:"
    # Parameters.__getitem__ sets this to id() of the Parameters object the
    # placeholder was taken from.
    _key_context: (
        int | None
    )  # allow people to pass multiple distinct parameters with the same key into a Query.
    # Parameters.__getitem__ creates placeholders with fmt=""; _withFmt replaces it.
    fmt: str

    def _withFmt(self, fmt: str) -> Self:
        """Return a copy of this placeholder with ``fmt`` replaced by the given value."""
        return dataclasses.replace(self, fmt=fmt)
[docs]
@dataclass(frozen=True)
class AutoKey:
    """
    A wrapper for a parameter key, indicating it was generated automatically by :meth:`csql.Parameters.add`.
    """

    # The generated key text; Parameters.add produces keys of the form "_add_<n>".
    k: str
[docs]
class Parameters:
    """
    Parameters let you quickly initialize a bunch of params to pass into your queries.

    Once parameters have been added in the Parameters constructor or with :meth:`add`, they
    can be pulled out by their ``p['parameter name']``, for use in a :func:`Query<Q>`.

    Usage:

    >>> p = Parameters(
    ...     start=date(2019,1,1),
    ...     end=date(2020,1,1)
    ... )
    >>> q = Q(f"select * from customers where {p['start']} <= date and date < {p['end']}")

    See: :ref:`reparam`
    """

    params: dict[str | AutoKey, ParameterValue]
    ":meta private:"

    class _MissingParameterError(AttributeError, KeyError):
        """
        Raised by attribute-style access to a parameter that does not exist.

        It is an ``AttributeError`` so that ``hasattr``, ``getattr(obj, name, default)``
        and the ``copy``/``pickle`` machinery behave normally, and also a
        ``KeyError`` so that callers who caught the ``KeyError`` this access
        used to raise keep working.
        """

    def __init__(self, **kwargs: ParameterValue):
        """Create a set of parameters from ``name=value`` pairs; every value must be hashable."""
        self.params = {k: self._check_hashable_value(k, v) for k, v in kwargs.items()}

    @staticmethod
    def _check_hashable_value(key: str | AutoKey, val: Any) -> Hashable:
        """
        Normalise and validate a parameter value.

        Any non-``str`` collection is converted to a ``tuple``; the result
        must then be hashable.

        :param key: The parameter's key, used only in the error message.
        :param val: The candidate value.
        :returns: The (possibly tuple-converted) value.
        :raises ValueError: If the value is not hashable.
        """
        if isinstance(val, Collection) and not isinstance(val, str):
            val = tuple(val)  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
        try:
            hash(val)
        except TypeError as e:
            raise ValueError(
                f"Refusing to add {key}:{val} - parameter values need to be hashable."
            ) from e
        return cast(Hashable, val)

    def _add(self, key: str | AutoKey, val: Any) -> ParameterPlaceholder:
        """
        Store ``val`` under ``key`` and return its placeholder.

        :raises ValueError: If ``key`` is already present, or ``val`` is not hashable.
        """
        if key in self.params:
            raise ValueError(
                f"Refusing to add {key}: it is already in this set of Parameters (with value {self.params[key]})."
            )
        val = self._check_hashable_value(key, val)
        self.params[key] = val
        return self[key]

    def add(
        self, value: csql.ParameterValue = NOVALUE, /, **kwargs: csql.ParameterValue
    ) -> csql.ParameterPlaceholder:
        """
        Adds a single parameter into this Parameters, and returns it.

        You don't normally need this (just add them directly when building :class:`Parameters`), but
        it can be useful in loops where you need to build a query based on an unknown number of params.

        Can be called as

        >>> p.add('value') # doctest: +IGNORE_RESULT
        ... # which will add a single parameter with an autogenerated name.

        Can also be called as

        >>> p.add(key='value') # doctest: +IGNORE_RESULT
        ... # which will add a named parameter.

        Useful in loops:

        >>> p = Parameters()
        >>> licence_cancellations = [
        ...     ('Shazza', date(2019, 1, 1)),
        ...     ('Bazza', date(2019, 1, 26)),
        ...     ('Azza', date(2022, 1, 3))
        ... ]
        >>> where_clause = ' or '.join(
        ...     f'(name = {p.add(name)} and timestamp > {p.add(date)})'
        ...     for name, date in licence_cancellations
        ... )
        >>> query = Q(f'select * from frankston_traffic_log where {where_clause}')

        :param value: A single parameter to add: ``add(123)``. Cannot be used with ``kwargs``.
        :param kwargs: A single key and parameter to add: ``add(my_fav_number=123)``. Cannot be used with ``value``.
        :raises ValueError: If the call is not exactly one of ``add(val)`` or
            ``add(key=val)``, if the key already exists, or if the value is not hashable.
        """
        passed_arg = value is not NOVALUE
        # Exactly one of "a positional value" or "a single keyword" is allowed.
        # Checking bool(kwargs) (not len(kwargs) == 1) means add(val, a=1, b=2)
        # is rejected instead of silently dropping the keywords.
        if passed_arg == bool(kwargs) or len(kwargs) > 1:
            raise ValueError("You need to call either add(val) or add(key=val)")

        if passed_arg:
            # First "_add_<n>" key that is not already taken.
            generate_keys = (AutoKey(f"_add_{i}") for i in itertools.count())
            auto_key = next(k for k in generate_keys if k not in self.params)
            return self._add(auto_key, value)

        [(key, val)] = kwargs.items()
        return self._add(key, val)

    def __contains__(self, key: str | AutoKey) -> bool:
        """Return whether a parameter with this key has been added."""
        return key in self.params

    def __getitem__(self, key: str | AutoKey) -> ParameterPlaceholder:
        """
        Return a placeholder for the parameter stored under ``key``.

        :raises KeyError: If there is no such parameter.
        """
        paramVal = self.params[key]  # check existence
        return ParameterPlaceholder(
            key=key, value=paramVal, _key_context=id(self), fmt=""
        )

    def __getattr__(self, key: str) -> ParameterPlaceholder:
        """
        Attribute-style access to a parameter: ``p.start`` is ``p['start']``.

        :raises AttributeError: (also a ``KeyError``) if there is no such parameter.
        """
        # __getattr__ only runs when normal lookup has failed. Read `params`
        # from the instance dict directly: `self.params` on an instance that
        # has no `params` yet (as created by copy/pickle before state is
        # restored) would re-enter __getattr__ and recurse without end.
        params = self.__dict__.get("params")
        if params is None or key not in params:
            raise self._MissingParameterError(key)
        return self[key]