before send to remote

This commit is contained in:
2022-08-26 16:41:18 +06:00
commit 3814beb3e0
5408 changed files with 652023 additions and 0 deletions

View File

@@ -0,0 +1,190 @@
from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
from .datetime import (
Extract,
ExtractDay,
ExtractHour,
ExtractIsoWeekDay,
ExtractIsoYear,
ExtractMinute,
ExtractMonth,
ExtractQuarter,
ExtractSecond,
ExtractWeek,
ExtractWeekDay,
ExtractYear,
Now,
Trunc,
TruncDate,
TruncDay,
TruncHour,
TruncMinute,
TruncMonth,
TruncQuarter,
TruncSecond,
TruncTime,
TruncWeek,
TruncYear,
)
from .math import (
Abs,
ACos,
ASin,
ATan,
ATan2,
Ceil,
Cos,
Cot,
Degrees,
Exp,
Floor,
Ln,
Log,
Mod,
Pi,
Power,
Radians,
Random,
Round,
Sign,
Sin,
Sqrt,
Tan,
)
from .text import (
MD5,
SHA1,
SHA224,
SHA256,
SHA384,
SHA512,
Chr,
Concat,
ConcatPair,
Left,
Length,
Lower,
LPad,
LTrim,
Ord,
Repeat,
Replace,
Reverse,
Right,
RPad,
RTrim,
StrIndex,
Substr,
Trim,
Upper,
)
from .window import (
CumeDist,
DenseRank,
FirstValue,
Lag,
LastValue,
Lead,
NthValue,
Ntile,
PercentRank,
Rank,
RowNumber,
)
__all__ = [
# comparison and conversion
"Cast",
"Coalesce",
"Collate",
"Greatest",
"JSONObject",
"Least",
"NullIf",
# datetime
"Extract",
"ExtractDay",
"ExtractHour",
"ExtractMinute",
"ExtractMonth",
"ExtractQuarter",
"ExtractSecond",
"ExtractWeek",
"ExtractIsoWeekDay",
"ExtractWeekDay",
"ExtractIsoYear",
"ExtractYear",
"Now",
"Trunc",
"TruncDate",
"TruncDay",
"TruncHour",
"TruncMinute",
"TruncMonth",
"TruncQuarter",
"TruncSecond",
"TruncTime",
"TruncWeek",
"TruncYear",
# math
"Abs",
"ACos",
"ASin",
"ATan",
"ATan2",
"Ceil",
"Cos",
"Cot",
"Degrees",
"Exp",
"Floor",
"Ln",
"Log",
"Mod",
"Pi",
"Power",
"Radians",
"Random",
"Round",
"Sign",
"Sin",
"Sqrt",
"Tan",
# text
"MD5",
"SHA1",
"SHA224",
"SHA256",
"SHA384",
"SHA512",
"Chr",
"Concat",
"ConcatPair",
"Left",
"Length",
"Lower",
"LPad",
"LTrim",
"Ord",
"Repeat",
"Replace",
"Reverse",
"Right",
"RPad",
"RTrim",
"StrIndex",
"Substr",
"Trim",
"Upper",
# window
"CumeDist",
"DenseRank",
"FirstValue",
"Lag",
"LastValue",
"Lead",
"NthValue",
"Ntile",
"PercentRank",
"Rank",
"RowNumber",
]

View File

@@ -0,0 +1,212 @@
"""Database functions that do comparisons or type conversions."""
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields.json import JSONField
from django.utils.regex_helper import _lazy_re_compile
class Cast(Func):
"""Coerce an expression to a new field type."""
function = "CAST"
template = "%(function)s(%(expressions)s AS %(db_type)s)"
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
extra_context["db_type"] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
db_type = self.output_field.db_type(connection)
if db_type in {"datetime", "time"}:
# Use strftime as datetime/time don't keep fractional seconds.
template = "strftime(%%s, %(expressions)s)"
sql, params = super().as_sql(
compiler, connection, template=template, **extra_context
)
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
params.insert(0, format_string)
return sql, params
elif db_type == "date":
template = "date(%(expressions)s)"
return super().as_sql(
compiler, connection, template=template, **extra_context
)
return self.as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
template = None
output_type = self.output_field.get_internal_type()
# MySQL doesn't support explicit cast to float.
if output_type == "FloatField":
template = "(%(expressions)s + 0.0)"
# MariaDB doesn't support explicit cast to JSON.
elif output_type == "JSONField" and connection.mysql_is_mariadb:
template = "JSON_EXTRACT(%(expressions)s, '$')"
return self.as_sql(compiler, connection, template=template, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
return self.as_sql(
compiler,
connection,
template="(%(expressions)s)::%(db_type)s",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == "JSONField":
# Oracle doesn't support explicit cast to JSON.
template = "JSON_QUERY(%(expressions)s, '$')"
return super().as_sql(
compiler, connection, template=template, **extra_context
)
return self.as_sql(compiler, connection, **extra_context)
class Coalesce(Func):
"""Return, from left to right, the first non-null expression."""
function = "COALESCE"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Coalesce must take at least two expressions")
super().__init__(*expressions, **extra)
@property
def empty_result_set_value(self):
for expression in self.get_source_expressions():
result = expression.empty_result_set_value
if result is NotImplemented or result is not None:
return result
return None
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == "TextField":
clone = self.copy()
clone.set_source_expressions(
[
Func(expression, function="TO_NCLOB")
for expression in self.get_source_expressions()
]
)
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Collate(Func):
function = "COLLATE"
template = "%(expressions)s %(function)s %(collation)s"
# Inspired from
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
collation_re = _lazy_re_compile(r"^[\w\-]+$")
def __init__(self, expression, collation):
if not (collation and self.collation_re.match(collation)):
raise ValueError("Invalid collation name: %r." % collation)
self.collation = collation
super().__init__(expression)
def as_sql(self, compiler, connection, **extra_context):
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
return super().as_sql(compiler, connection, **extra_context)
class Greatest(Func):
"""
Return the maximum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, the maximum not-null expression is returned.
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
"""
function = "GREATEST"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Greatest must take at least two expressions")
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
class JSONObject(Func):
function = "JSON_OBJECT"
output_field = JSONField()
def __init__(self, **fields):
expressions = []
for key, value in fields.items():
expressions.extend((Value(key), value))
super().__init__(*expressions)
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.has_json_object_function:
raise NotSupportedError(
"JSONObject() is not supported on this database backend."
)
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
return self.as_sql(
compiler,
connection,
function="JSONB_BUILD_OBJECT",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
class ArgJoiner:
def join(self, args):
args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
return ", ".join(args)
return self.as_sql(
compiler,
connection,
arg_joiner=ArgJoiner(),
template="%(function)s(%(expressions)s RETURNING CLOB)",
**extra_context,
)
class Least(Func):
"""
Return the minimum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, return the minimum not-null expression.
On MySQL, Oracle, and SQLite, if any expression is null, return null.
"""
function = "LEAST"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Least must take at least two expressions")
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
class NullIf(Func):
function = "NULLIF"
arity = 2
def as_oracle(self, compiler, connection, **extra_context):
expression1 = self.get_source_expressions()[0]
if isinstance(expression1, Value) and expression1.value is None:
raise ValueError("Oracle does not allow Value(None) for expression1.")
return super().as_sql(compiler, connection, **extra_context)

View File

@@ -0,0 +1,427 @@
from datetime import datetime
from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
DateField,
DateTimeField,
DurationField,
Field,
IntegerField,
TimeField,
)
from django.db.models.lookups import (
Transform,
YearExact,
YearGt,
YearGte,
YearLt,
YearLte,
)
from django.utils import timezone
class TimezoneMixin:
tzinfo = None
def get_tzname(self):
# Timezone conversions must happen to the input datetime *before*
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
# database as 2016-01-01 01:00:00 +00:00. Any results should be
# based on the input datetime not the stored datetime.
tzname = None
if settings.USE_TZ:
if self.tzinfo is None:
tzname = timezone.get_current_timezone_name()
else:
tzname = timezone._get_timezone_name(self.tzinfo)
return tzname
class Extract(TimezoneMixin, Transform):
lookup_name = None
output_field = IntegerField()
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
raise ValueError("lookup_name must be provided")
self.tzinfo = tzinfo
super().__init__(expression, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
sql, params = connection.ops.datetime_extract_sql(
self.lookup_name, sql, tuple(params), tzname
)
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
elif isinstance(lhs_output_field, DateField):
sql, params = connection.ops.date_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, TimeField):
sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
raise ValueError(
"Extract requires native DurationField database support."
)
sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
else:
# resolve_expression has already validated the output_field so this
# assert should never be hit.
assert False, "Tried to Extract from an invalid type."
return sql, params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
copy = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
field = getattr(copy.lhs, "output_field", None)
if field is None:
return copy
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
raise ValueError(
"Extract input expression must be DateField, DateTimeField, "
"TimeField, or DurationField."
)
# Passing dates to functions expecting datetimes is most likely a mistake.
if type(field) == DateField and copy.lookup_name in (
"hour",
"minute",
"second",
):
raise ValueError(
"Cannot extract time component '%s' from DateField '%s'."
% (copy.lookup_name, field.name)
)
if isinstance(field, DurationField) and copy.lookup_name in (
"year",
"iso_year",
"month",
"week",
"week_day",
"iso_week_day",
"quarter",
):
raise ValueError(
"Cannot extract component '%s' from DurationField '%s'."
% (copy.lookup_name, field.name)
)
return copy
class ExtractYear(Extract):
lookup_name = "year"
class ExtractIsoYear(Extract):
"""Return the ISO-8601 week-numbering year."""
lookup_name = "iso_year"
class ExtractMonth(Extract):
lookup_name = "month"
class ExtractDay(Extract):
lookup_name = "day"
class ExtractWeek(Extract):
"""
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
week.
"""
lookup_name = "week"
class ExtractWeekDay(Extract):
"""
Return Sunday=1 through Saturday=7.
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
lookup_name = "week_day"
class ExtractIsoWeekDay(Extract):
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
lookup_name = "iso_week_day"
class ExtractQuarter(Extract):
lookup_name = "quarter"
class ExtractHour(Extract):
lookup_name = "hour"
class ExtractMinute(Extract):
lookup_name = "minute"
class ExtractSecond(Extract):
lookup_name = "second"
DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)
TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)
ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)
ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)
class Now(Func):
template = "CURRENT_TIMESTAMP"
output_field = DateTimeField()
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
return self.as_sql(
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
)
class TruncBase(TimezoneMixin, Transform):
kind = None
tzinfo = None
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(
self,
expression,
output_field=None,
tzinfo=None,
is_dst=timezone.NOT_PASSED,
**extra,
):
self.tzinfo = tzinfo
self.is_dst = is_dst
super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
if isinstance(self.output_field, DateTimeField):
sql, params = connection.ops.datetime_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, DateField):
sql, params = connection.ops.date_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, TimeField):
sql, params = connection.ops.time_trunc_sql(
self.kind, sql, tuple(params), tzname
)
else:
raise ValueError(
"Trunc only valid on DateField, TimeField, or DateTimeField."
)
return sql, params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
copy = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
if not isinstance(field, (DateField, TimeField)):
raise TypeError(
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
)
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
raise ValueError(
"output_field must be either DateField, TimeField, or DateTimeField"
)
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
class_output_field = (
self.__class__.output_field
if isinstance(self.__class__.output_field, Field)
else None
)
output_field = class_output_field or copy.output_field
has_explicit_output_field = (
class_output_field or field.__class__ is not copy.output_field.__class__
)
if type(field) == DateField and (
isinstance(output_field, DateTimeField)
or copy.kind in ("hour", "minute", "second", "time")
):
raise ValueError(
"Cannot truncate DateField '%s' to %s."
% (
field.name,
output_field.__class__.__name__
if has_explicit_output_field
else "DateTimeField",
)
)
elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField)
or copy.kind in ("year", "quarter", "month", "week", "day", "date")
):
raise ValueError(
"Cannot truncate TimeField '%s' to %s."
% (
field.name,
output_field.__class__.__name__
if has_explicit_output_field
else "DateTimeField",
)
)
return copy
def convert_value(self, value, expression, connection):
if isinstance(self.output_field, DateTimeField):
if not settings.USE_TZ:
pass
elif value is not None:
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
"Database returned an invalid datetime value. Are time "
"zone definitions for your database installed?"
)
elif isinstance(value, datetime):
if value is None:
pass
elif isinstance(self.output_field, DateField):
value = value.date()
elif isinstance(self.output_field, TimeField):
value = value.time()
return value
class Trunc(TruncBase):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(
self,
expression,
kind,
output_field=None,
tzinfo=None,
is_dst=timezone.NOT_PASSED,
**extra,
):
self.kind = kind
super().__init__(
expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
)
class TruncYear(TruncBase):
kind = "year"
class TruncQuarter(TruncBase):
kind = "quarter"
class TruncMonth(TruncBase):
kind = "month"
class TruncWeek(TruncBase):
"""Truncate to midnight on the Monday of the week."""
kind = "week"
class TruncDay(TruncBase):
kind = "day"
class TruncDate(TruncBase):
kind = "date"
lookup_name = "date"
output_field = DateField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname()
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
class TruncTime(TruncBase):
kind = "time"
lookup_name = "time"
output_field = TimeField()
def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time.
sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname()
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
class TruncHour(TruncBase):
kind = "hour"
class TruncMinute(TruncBase):
kind = "minute"
class TruncSecond(TruncBase):
kind = "second"
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)

View File

@@ -0,0 +1,212 @@
import math
from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
FixDecimalInputMixin,
NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform
class Abs(Transform):
function = "ABS"
lookup_name = "abs"
class ACos(NumericOutputFieldMixin, Transform):
function = "ACOS"
lookup_name = "acos"
class ASin(NumericOutputFieldMixin, Transform):
function = "ASIN"
lookup_name = "asin"
class ATan(NumericOutputFieldMixin, Transform):
function = "ATAN"
lookup_name = "atan"
class ATan2(NumericOutputFieldMixin, Func):
function = "ATAN2"
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(
connection.ops, "spatialite", False
) or connection.ops.spatial_version >= (5, 0, 0):
return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
# Cast integers to float to avoid inconsistent/buggy behavior if the
# arguments are mixed between integer and float or decimal.
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
clone = self.copy()
clone.set_source_expressions(
[
Cast(expression, FloatField())
if isinstance(expression.output_field, IntegerField)
else expression
for expression in self.get_source_expressions()[::-1]
]
)
return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform):
function = "CEILING"
lookup_name = "ceil"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="CEIL", **extra_context)
class Cos(NumericOutputFieldMixin, Transform):
function = "COS"
lookup_name = "cos"
class Cot(NumericOutputFieldMixin, Transform):
function = "COT"
lookup_name = "cot"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
)
class Degrees(NumericOutputFieldMixin, Transform):
function = "DEGREES"
lookup_name = "degrees"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="((%%(expressions)s) * 180 / %s)" % math.pi,
**extra_context,
)
class Exp(NumericOutputFieldMixin, Transform):
function = "EXP"
lookup_name = "exp"
class Floor(Transform):
function = "FLOOR"
lookup_name = "floor"
class Ln(NumericOutputFieldMixin, Transform):
function = "LN"
lookup_name = "ln"
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = "LOG"
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, "spatialite", False):
return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b).
clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[::-1])
return clone.as_sql(compiler, connection, **extra_context)
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = "MOD"
arity = 2
class Pi(NumericOutputFieldMixin, Func):
function = "PI"
arity = 0
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, template=str(math.pi), **extra_context
)
class Power(NumericOutputFieldMixin, Func):
function = "POWER"
arity = 2
class Radians(NumericOutputFieldMixin, Transform):
function = "RADIANS"
lookup_name = "radians"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="((%%(expressions)s) * %s / 180)" % math.pi,
**extra_context,
)
class Random(NumericOutputFieldMixin, Func):
function = "RANDOM"
arity = 0
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="RAND", **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="RAND", **extra_context)
def get_group_by_cols(self, alias=None):
return []
class Round(FixDecimalInputMixin, Transform):
function = "ROUND"
lookup_name = "round"
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
super().__init__(expression, precision, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
raise ValueError("SQLite does not support negative precision.")
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
source = self.get_source_expressions()[0]
return source.output_field
class Sign(Transform):
function = "SIGN"
lookup_name = "sign"
class Sin(NumericOutputFieldMixin, Transform):
function = "SIN"
lookup_name = "sin"
class Sqrt(NumericOutputFieldMixin, Transform):
function = "SQRT"
lookup_name = "sqrt"
class Tan(NumericOutputFieldMixin, Transform):
function = "TAN"
lookup_name = "tan"

View File

@@ -0,0 +1,57 @@
import sys
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast
class FixDecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
# - LOG(double, double)
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions(
[
Cast(expression, output_field)
if isinstance(expression.output_field, FloatField)
else expression
for expression in self.get_source_expressions()
]
)
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == "DurationField":
sql = "CAST(%s AS SIGNED)" % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == "DurationField":
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import (
IntervalToSeconds,
SecondsToInterval,
)
return compiler.compile(
SecondsToInterval(
self.__class__(IntervalToSeconds(expression), **options)
)
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):
source_fields = self.get_source_fields()
if any(isinstance(s, DecimalField) for s in source_fields):
return DecimalField()
if any(isinstance(s, IntegerField) for s in source_fields):
return FloatField()
return super()._resolve_output_field() if source_fields else FloatField()

View File

@@ -0,0 +1,351 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import CharField, IntegerField
from django.db.models.functions import Coalesce
from django.db.models.lookups import Transform
class MySQLSHA2Mixin:
def as_mysql(self, compiler, connection, **extra_content):
return super().as_sql(
compiler,
connection,
template="SHA2(%%(expressions)s, %s)" % self.function[3:],
**extra_content,
)
class OracleHashMixin:
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template=(
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
),
**extra_context,
)
class PostgreSQLSHAMixin:
def as_postgresql(self, compiler, connection, **extra_content):
return super().as_sql(
compiler,
connection,
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
function=self.function.lower(),
**extra_content,
)
class Chr(Transform):
function = "CHR"
lookup_name = "chr"
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
function="CHAR",
template="%(function)s(%(expressions)s USING utf16)",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="%(function)s(%(expressions)s USING NCHAR_CS)",
**extra_context,
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
class ConcatPair(Func):
"""
Concatenate two arguments together. This is used by `Concat` because not
all backend databases support more than two arguments.
"""
function = "CONCAT"
def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler,
connection,
template="%(expressions)s",
arg_joiner=" || ",
**extra_context,
)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql(
compiler,
connection,
function="CONCAT_WS",
template="%(function)s('', %(expressions)s)",
**extra_context,
)
def coalesce(self):
# null on either side results in null for expression, wrap with coalesce
c = self.copy()
c.set_source_expressions(
[
Coalesce(expression, Value(""))
for expression in c.get_source_expressions()
]
)
return c
class Concat(Func):
"""
Concatenate text fields together. Backends that result in an entire
null expression when any arguments are null will wrap each argument in
coalesce functions to ensure a non-null result.
"""
function = None
template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Concat must take at least two expressions")
paired = self._paired(expressions)
super().__init__(paired, **extra)
def _paired(self, expressions):
# wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))
class Left(Func):
function = "LEFT"
arity = 2
output_field = CharField()
def __init__(self, expression, length, **extra):
"""
expression: the name of a field, or an expression returning a string
length: the number of characters to return from the start of the string
"""
if not hasattr(length, "resolve_expression"):
if length < 1:
raise ValueError("'length' must be greater than 0.")
super().__init__(expression, length, **extra)
def get_substr(self):
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
def as_oracle(self, compiler, connection, **extra_context):
return self.get_substr().as_oracle(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
class Length(Transform):
"""Return the number of characters in the expression."""
function = "LENGTH"
lookup_name = "length"
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function="CHAR_LENGTH", **extra_context
)
class Lower(Transform):
function = "LOWER"
lookup_name = "lower"
class LPad(Func):
function = "LPAD"
output_field = CharField()
def __init__(self, expression, length, fill_text=Value(" "), **extra):
if (
not hasattr(length, "resolve_expression")
and length is not None
and length < 0
):
raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra)
class LTrim(Transform):
function = "LTRIM"
lookup_name = "ltrim"
class MD5(OracleHashMixin, Transform):
function = "MD5"
lookup_name = "md5"
class Ord(Transform):
function = "ASCII"
lookup_name = "ord"
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="ORD", **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
class Repeat(Func):
function = "REPEAT"
output_field = CharField()
def __init__(self, expression, number, **extra):
if (
not hasattr(number, "resolve_expression")
and number is not None
and number < 0
):
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions
length = None if number is None else Length(expression) * number
rpad = RPad(expression, length, expression)
return rpad.as_sql(compiler, connection, **extra_context)
class Replace(Func):
function = "REPLACE"
def __init__(self, expression, text, replacement=Value(""), **extra):
super().__init__(expression, text, replacement, **extra)
class Reverse(Transform):
function = "REVERSE"
lookup_name = "reverse"
def as_oracle(self, compiler, connection, **extra_context):
# REVERSE in Oracle is undocumented and doesn't support multi-byte
# strings. Use a special subquery instead.
return super().as_sql(
compiler,
connection,
template=(
"(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
"FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
"GROUP BY %(expressions)s)"
),
**extra_context,
)
class Right(Left):
function = "RIGHT"
def get_substr(self):
return Substr(
self.source_expressions[0], self.source_expressions[1] * Value(-1)
)
class RPad(LPad):
function = "RPAD"
class RTrim(Transform):
function = "RTRIM"
lookup_name = "rtrim"
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA1"
lookup_name = "sha1"
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
function = "SHA224"
lookup_name = "sha224"
def as_oracle(self, compiler, connection, **extra_context):
raise NotSupportedError("SHA224 is not supported on Oracle.")
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA256"
lookup_name = "sha256"
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA384"
lookup_name = "sha384"
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA512"
lookup_name = "sha512"
class StrIndex(Func):
"""
Return a positive integer corresponding to the 1-indexed position of the
first occurrence of a substring inside another string, or 0 if the
substring is not found.
"""
function = "INSTR"
arity = 2
output_field = IntegerField()
def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
class Substr(Func):
function = "SUBSTRING"
output_field = CharField()
def __init__(self, expression, pos, length=None, **extra):
"""
expression: the name of a field, or an expression returning a string
pos: an integer > 0, or an expression returning an integer
length: an optional number of characters to return
"""
if not hasattr(pos, "resolve_expression"):
if pos < 1:
raise ValueError("'pos' must be greater than 0")
expressions = [expression, pos]
if length is not None:
expressions.append(length)
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
class Trim(Transform):
function = "TRIM"
lookup_name = "trim"
class Upper(Transform):
function = "UPPER"
lookup_name = "upper"

View File

@@ -0,0 +1,120 @@
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
__all__ = [
"CumeDist",
"DenseRank",
"FirstValue",
"Lag",
"LastValue",
"Lead",
"NthValue",
"Ntile",
"PercentRank",
"Rank",
"RowNumber",
]
class CumeDist(Func):
function = "CUME_DIST"
output_field = FloatField()
window_compatible = True
class DenseRank(Func):
function = "DENSE_RANK"
output_field = IntegerField()
window_compatible = True
class FirstValue(Func):
arity = 1
function = "FIRST_VALUE"
window_compatible = True
class LagLeadFunction(Func):
window_compatible = True
def __init__(self, expression, offset=1, default=None, **extra):
if expression is None:
raise ValueError(
"%s requires a non-null source expression." % self.__class__.__name__
)
if offset is None or offset <= 0:
raise ValueError(
"%s requires a positive integer for the offset."
% self.__class__.__name__
)
args = (expression, offset)
if default is not None:
args += (default,)
super().__init__(*args, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Lag(LagLeadFunction):
function = "LAG"
class LastValue(Func):
arity = 1
function = "LAST_VALUE"
window_compatible = True
class Lead(LagLeadFunction):
function = "LEAD"
class NthValue(Func):
function = "NTH_VALUE"
window_compatible = True
def __init__(self, expression, nth=1, **extra):
if expression is None:
raise ValueError(
"%s requires a non-null source expression." % self.__class__.__name__
)
if nth is None or nth <= 0:
raise ValueError(
"%s requires a positive integer as for nth." % self.__class__.__name__
)
super().__init__(expression, nth, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Ntile(Func):
function = "NTILE"
output_field = IntegerField()
window_compatible = True
def __init__(self, num_buckets=1, **extra):
if num_buckets <= 0:
raise ValueError("num_buckets must be greater than 0.")
super().__init__(num_buckets, **extra)
class PercentRank(Func):
function = "PERCENT_RANK"
output_field = FloatField()
window_compatible = True
class Rank(Func):
function = "RANK"
output_field = IntegerField()
window_compatible = True
class RowNumber(Func):
function = "ROW_NUMBER"
output_field = IntegerField()
window_compatible = True