before send to remote

This commit is contained in:
2022-08-26 16:41:18 +06:00
commit 3814beb3e0
5408 changed files with 652023 additions and 0 deletions

View File

@@ -0,0 +1,444 @@
"""
MySQL database backend for Django.
Requires mysqlclient: https://pypi.org/project/mysqlclient/
"""
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile
try:
import MySQLdb as Database
except ImportError as err:
raise ImproperlyConfigured(
"Error loading MySQLdb module.\nDid you install mysqlclient?"
) from err
from MySQLdb.constants import CLIENT, FIELD_TYPE
from MySQLdb.converters import conversions
# Some of these import MySQLdb, so import them after checking if it's installed.
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .validation import DatabaseValidation
version = Database.version_info
if version < (1, 4, 0):
raise ImproperlyConfigured(
"mysqlclient 1.4.0 or newer is required; you have %s." % Database.__version__
)
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Django
# expects time.
django_conversions = {
**conversions,
**{FIELD_TYPE.TIME: backend_utils.typecast_time},
}
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
class CursorWrapper:
"""
A thin wrapper around MySQLdb's normal cursor class that catches particular
exception instances and reraises them with the correct types.
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
to the particular underlying representation returned by Connection.cursor().
"""
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
3819, # CHECK constraint is violated
4025, # CHECK constraint failed
)
def __init__(self, cursor):
self.cursor = cursor
def execute(self, query, args=None):
try:
# args is None means no string interpolation
return self.cursor.execute(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def executemany(self, query, args):
try:
return self.cursor.executemany(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "mysql"
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
"AutoField": "integer AUTO_INCREMENT",
"BigAutoField": "bigint AUTO_INCREMENT",
"BinaryField": "longblob",
"BooleanField": "bool",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "datetime(6)",
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
"DurationField": "bigint",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "double precision",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "char(15)",
"GenericIPAddressField": "char(39)",
"JSONField": "json",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint UNSIGNED",
"PositiveIntegerField": "integer UNSIGNED",
"PositiveSmallIntegerField": "smallint UNSIGNED",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallint AUTO_INCREMENT",
"SmallIntegerField": "smallint",
"TextField": "longtext",
"TimeField": "time(6)",
"UUIDField": "char(32)",
}
# For these data types:
# - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
# as nullable
# - all versions of MySQL and MariaDB don't support full width database
# indexes
_limited_data_types = (
"tinyblob",
"blob",
"mediumblob",
"longblob",
"tinytext",
"text",
"mediumtext",
"longtext",
"json",
)
operators = {
"exact": "= %s",
"iexact": "LIKE %s",
"contains": "LIKE BINARY %s",
"icontains": "LIKE %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE BINARY %s",
"endswith": "LIKE BINARY %s",
"istartswith": "LIKE %s",
"iendswith": "LIKE %s",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
"contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
"icontains": "LIKE CONCAT('%%', {}, '%%')",
"startswith": "LIKE BINARY CONCAT({}, '%%')",
"istartswith": "LIKE CONCAT({}, '%%')",
"endswith": "LIKE BINARY CONCAT('%%', {})",
"iendswith": "LIKE CONCAT('%%', {})",
}
isolation_levels = {
"read uncommitted",
"read committed",
"repeatable read",
"serializable",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def get_database_version(self):
return self.mysql_version
def get_connection_params(self):
kwargs = {
"conv": django_conversions,
"charset": "utf8",
}
settings_dict = self.settings_dict
if settings_dict["USER"]:
kwargs["user"] = settings_dict["USER"]
if settings_dict["NAME"]:
kwargs["database"] = settings_dict["NAME"]
if settings_dict["PASSWORD"]:
kwargs["password"] = settings_dict["PASSWORD"]
if settings_dict["HOST"].startswith("/"):
kwargs["unix_socket"] = settings_dict["HOST"]
elif settings_dict["HOST"]:
kwargs["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
kwargs["port"] = int(settings_dict["PORT"])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs["client_flag"] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
options = settings_dict["OPTIONS"].copy()
isolation_level = options.pop("isolation_level", "read committed")
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
"Use one of %s, or None."
% (
isolation_level,
", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
)
)
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
connection = Database.connect(**conn_params)
# bytes encoder in mysqlclient doesn't work and was added only to
# prevent KeyErrors in Django < 2.0. We can remove this workaround when
# mysqlclient 2.1 becomes the minimal mysqlclient supported by Django.
# See https://github.com/PyMySQL/mysqlclient/issues/489
if connection.encoders.get(bytes) is bytes:
connection.encoders.pop(bytes)
return connection
def init_connection_state(self):
super().init_connection_state()
assignments = []
if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
assignments.append("SET SQL_AUTO_IS_NULL = 0")
if self.isolation_level:
assignments.append(
"SET SESSION TRANSACTION ISOLATION LEVEL %s"
% self.isolation_level.upper()
)
if assignments:
with self.cursor() as cursor:
cursor.execute("; ".join(assignments))
@async_unsafe
def create_cursor(self, name=None):
cursor = self.connection.cursor()
return CursorWrapper(cursor)
def _rollback(self):
try:
BaseDatabaseWrapper._rollback(self)
except Database.NotSupportedError:
pass
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit(autocommit)
def disable_constraint_checking(self):
"""
Disable foreign key checks, primarily for use in adding rows with
forward references. Always return True to indicate constraint checks
need to be re-enabled.
"""
with self.cursor() as cursor:
cursor.execute("SET foreign_key_checks=0")
return True
def enable_constraint_checking(self):
"""
Re-enable foreign key checks after they have been disabled.
"""
# Override needs_rollback in case constraint_checks_disabled is
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
with self.cursor() as cursor:
cursor.execute("SET foreign_key_checks=1")
finally:
self.needs_rollback = needs_rollback
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
if not primary_key_column_name:
continue
relations = self.introspection.get_relations(cursor, table_name)
for column_name, (
referenced_column_name,
referenced_table_name,
) in relations.items():
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name,
column_name,
table_name,
referenced_table_name,
column_name,
referenced_column_name,
column_name,
referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
bad_row[0],
table_name,
column_name,
bad_row[1],
referenced_table_name,
referenced_column_name,
)
)
def is_usable(self):
try:
self.connection.ping()
except Database.Error:
return False
else:
return True
@cached_property
def display_name(self):
return "MariaDB" if self.mysql_is_mariadb else "MySQL"
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
check_constraints = {
"PositiveBigIntegerField": "`%(column)s` >= 0",
"PositiveIntegerField": "`%(column)s` >= 0",
"PositiveSmallIntegerField": "`%(column)s` >= 0",
}
if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
# MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
# a check constraint.
check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
return check_constraints
return {}
@cached_property
def mysql_server_data(self):
with self.temporary_connection() as cursor:
# Select some server variables and test if the time zone
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
# timezone isn't loaded into the mysql.time_zone table.
cursor.execute(
"""
SELECT VERSION(),
@@sql_mode,
@@default_storage_engine,
@@sql_auto_is_null,
@@lower_case_table_names,
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
"""
)
row = cursor.fetchone()
return {
"version": row[0],
"sql_mode": row[1],
"default_storage_engine": row[2],
"sql_auto_is_null": bool(row[3]),
"lower_case_table_names": bool(row[4]),
"has_zoneinfo_database": bool(row[5]),
}
@cached_property
def mysql_server_info(self):
return self.mysql_server_data["version"]
@cached_property
def mysql_version(self):
match = server_version_re.match(self.mysql_server_info)
if not match:
raise Exception(
"Unable to determine MySQL version from version string %r"
% self.mysql_server_info
)
return tuple(int(x) for x in match.groups())
@cached_property
def mysql_is_mariadb(self):
return "mariadb" in self.mysql_server_info.lower()
@cached_property
def sql_mode(self):
sql_mode = self.mysql_server_data["sql_mode"]
return set(sql_mode.split(",") if sql_mode else ())

View File

@@ -0,0 +1,60 @@
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "mysql"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
env = None
database = settings_dict["OPTIONS"].get(
"database",
settings_dict["OPTIONS"].get("db", settings_dict["NAME"]),
)
user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
password = settings_dict["OPTIONS"].get(
"password",
settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]),
)
host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert")
client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key")
defaults_file = settings_dict["OPTIONS"].get("read_default_file")
charset = settings_dict["OPTIONS"].get("charset")
# Seems to be no good way to set sql_mode with CLI.
if defaults_file:
args += ["--defaults-file=%s" % defaults_file]
if user:
args += ["--user=%s" % user]
if password:
# The MYSQL_PWD environment variable usage is discouraged per
# MySQL's documentation due to the possibility of exposure through
# `ps` on old Unix flavors but --password suffers from the same
# flaw on even more systems. Usage of an environment variable also
# prevents password exposure if the subprocess.run(check=True) call
# raises a CalledProcessError since the string representation of
# the latter includes all of the provided `args`.
env = {"MYSQL_PWD": password}
if host:
if "/" in host:
args += ["--socket=%s" % host]
else:
args += ["--host=%s" % host]
if port:
args += ["--port=%s" % port]
if server_ca:
args += ["--ssl-ca=%s" % server_ca]
if client_cert:
args += ["--ssl-cert=%s" % client_cert]
if client_key:
args += ["--ssl-key=%s" % client_key]
if charset:
args += ["--default-character-set=%s" % charset]
if database:
args += [database]
args.extend(parameters)
return args, env

View File

@@ -0,0 +1,77 @@
from django.core.exceptions import FieldError
from django.db.models.expressions import Col
from django.db.models.sql import compiler
class SQLCompiler(compiler.SQLCompiler):
def as_subquery_condition(self, alias, columns, compiler):
qn = compiler.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
return (
"(%s) IN (%s)"
% (
", ".join("%s.%s" % (qn(alias), qn2(column)) for column in columns),
sql,
),
params,
)
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def as_sql(self):
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
# the SQLDeleteCompiler's default implementation when multiple tables
# are involved since MySQL/MariaDB will generate a more efficient query
# plan than when using a subquery.
where, having = self.query.where.split_having()
if self.single_alias or having:
# DELETE FROM cannot be used when filtering against aggregates
# since it doesn't allow for GROUP BY and HAVING clauses.
return super().as_sql()
result = [
"DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias())
]
from_sql, from_params = self.get_from_clause()
result.extend(from_sql)
where_sql, where_params = self.compile(where)
if where_sql:
result.append("WHERE %s" % where_sql)
return " ".join(result), tuple(from_params) + tuple(where_params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
def as_sql(self):
update_query, update_params = super().as_sql()
# MySQL and MariaDB support UPDATE ... ORDER BY syntax.
if self.query.order_by:
order_by_sql = []
order_by_params = []
db_table = self.query.get_meta().db_table
try:
for resolved, (sql, params, _) in self.get_order_by():
if (
isinstance(resolved.expression, Col)
and resolved.expression.alias != db_table
):
# Ignore ordering if it contains joined fields, because
# they cannot be used in the ORDER BY clause.
raise FieldError
order_by_sql.append(sql)
order_by_params.extend(params)
update_query += " ORDER BY " + ", ".join(order_by_sql)
update_params += tuple(order_by_params)
except FieldError:
# Ignore ordering if it contains annotations, because they're
# removed in .update() and cannot be resolved.
pass
return update_query, update_params
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass

View File

@@ -0,0 +1,87 @@
import os
import subprocess
import sys
from django.db.backends.base.creation import BaseDatabaseCreation
from .client import DatabaseClient
class DatabaseCreation(BaseDatabaseCreation):
def sql_table_creation_suffix(self):
suffix = []
test_settings = self.connection.settings_dict["TEST"]
if test_settings["CHARSET"]:
suffix.append("CHARACTER SET %s" % test_settings["CHARSET"])
if test_settings["COLLATION"]:
suffix.append("COLLATE %s" % test_settings["COLLATION"])
return " ".join(suffix)
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if len(e.args) < 1 or e.args[0] != 1007:
# All errors except "database exists" (1007) cancel tests.
self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
else:
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_db_params = {
"dbname": self.connection.ops.quote_name(target_database_name),
"suffix": self.sql_table_creation_suffix(),
}
with self._nodb_cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
if keepdb:
# If the database should be kept, skip everything else.
return
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error recreating the test database: %s" % e)
sys.exit(2)
self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name):
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(
self.connection.settings_dict, []
)
dump_cmd = [
"mysqldump",
*cmd_args[1:-1],
"--routines",
"--events",
source_database_name,
]
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name
with subprocess.Popen(
dump_cmd, stdout=subprocess.PIPE, env=dump_env
) as dump_proc:
with subprocess.Popen(
load_cmd,
stdin=dump_proc.stdout,
stdout=subprocess.DEVNULL,
env=load_env,
):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()

View File

@@ -0,0 +1,380 @@
import operator
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
empty_fetchmany_value = ()
allows_group_by_pk = True
related_fields_match_type = True
# MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
allow_sliced_subqueries_with_in = False
has_select_for_update = True
supports_forward_references = False
supports_regex_backreferencing = False
supports_date_lookup_using_string = False
supports_timezones = False
requires_explicit_null_ordering_when_grouping = True
can_release_savepoints = True
atomic_transactions = False
can_clone_databases = True
supports_temporal_subtraction = True
supports_select_intersection = False
supports_select_difference = False
supports_slicing_ordering_in_compound = True
supports_index_on_text_field = False
supports_update_conflicts = True
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN
DECLARE V_I INTEGER;
SET V_I = 1;
END;
"""
create_test_procedure_with_int_param_sql = """
CREATE PROCEDURE test_procedure (P_I INTEGER)
BEGIN
DECLARE V_I INTEGER;
SET V_I = P_I;
END;
"""
# Neither MySQL nor MariaDB support partial indexes.
supports_partial_indexes = False
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
# indexed expression.
collate_as_index_expression = True
supports_order_by_nulls_modifier = False
order_by_nulls_first = True
supports_logical_xor = True
@cached_property
def minimum_database_version(self):
if self.connection.mysql_is_mariadb:
return (10, 3)
else:
return (5, 7)
@cached_property
def bare_select_suffix(self):
if (
self.connection.mysql_is_mariadb and self.connection.mysql_version < (10, 4)
) or (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version < (8,)
):
return " FROM DUAL"
return ""
@cached_property
def test_collations(self):
charset = "utf8"
if (
self.connection.mysql_is_mariadb
and self.connection.mysql_version >= (10, 6)
) or (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version >= (8, 0, 30)
):
# utf8 is an alias for utf8mb3 in MariaDB 10.6+ and MySQL 8.0.30+.
charset = "utf8mb3"
return {
"ci": f"{charset}_general_ci",
"non_default": f"{charset}_esperanto_ci",
"swedish_ci": f"{charset}_swedish_ci",
}
test_now_utc_template = "UTC_TIMESTAMP"
@cached_property
def django_test_skips(self):
skips = {
"This doesn't work on MySQL.": {
"db_functions.comparison.test_greatest.GreatestTests."
"test_coalesce_workaround",
"db_functions.comparison.test_least.LeastTests."
"test_coalesce_workaround",
},
"Running on MySQL requires utf8mb4 encoding (#18392).": {
"model_fields.test_textfield.TextFieldTests.test_emoji",
"model_fields.test_charfield.TestCharField.test_emoji",
},
"MySQL doesn't support functional indexes on a function that "
"returns JSON": {
"schema.tests.SchemaTests.test_func_index_json_key_transform",
},
"MySQL supports multiplying and dividing DurationFields by a "
"scalar value but it's not implemented (#25287).": {
"expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
},
"UPDATE ... ORDER BY syntax on MySQL/MariaDB does not support ordering by"
"related fields.": {
"update.tests.AdvancedTests."
"test_update_ordered_by_inline_m2m_annotation",
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
},
}
if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode:
skips.update(
{
"GROUP BY optimization does not work properly when "
"ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.": {
"aggregation.tests.AggregateTestCase."
"test_aggregation_subquery_annotation_multivalued",
"annotations.tests.NonAggregateAnnotationTestCase."
"test_annotation_aggregate_with_m2o",
},
}
)
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (
8,
):
skips.update(
{
"Casting to datetime/time is not supported by MySQL < 8.0. "
"(#30224)": {
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_time_from_python",
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_datetime_from_python",
},
"MySQL < 8.0 returns string type instead of datetime/time. "
"(#30224)": {
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_time_from_database",
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_datetime_from_database",
},
}
)
if self.connection.mysql_is_mariadb and (
10,
4,
3,
) < self.connection.mysql_version < (10, 5, 2):
skips.update(
{
"https://jira.mariadb.org/browse/MDEV-19598": {
"schema.tests.SchemaTests."
"test_alter_not_unique_field_to_primary_key",
},
}
)
if self.connection.mysql_is_mariadb and (
10,
4,
12,
) < self.connection.mysql_version < (10, 5):
skips.update(
{
"https://jira.mariadb.org/browse/MDEV-22775": {
"schema.tests.SchemaTests."
"test_alter_pk_with_self_referential_field",
},
}
)
if (
self.connection.mysql_is_mariadb and self.connection.mysql_version < (10, 4)
) or (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version < (8,)
):
skips.update(
{
"Parenthesized combined queries are not supported on MySQL < 8 and "
"MariaDB < 10.4": {
"queries.test_qs_combinators.QuerySetSetOperationTests."
"test_union_in_subquery",
"queries.test_qs_combinators.QuerySetSetOperationTests."
"test_union_in_subquery_related_outerref",
"queries.test_qs_combinators.QuerySetSetOperationTests."
"test_union_in_with_ordering",
}
}
)
if not self.supports_explain_analyze:
skips.update(
{
"MariaDB and MySQL >= 8.0.18 specific.": {
"queries.test_explain.ExplainTests.test_mysql_analyze",
},
}
)
return skips
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
return self.connection.mysql_server_data["default_storage_engine"]
@cached_property
def allows_auto_pk_0(self):
"""
Autoincrement primary key can be set to 0 if it doesn't generate new
autoincrement values.
"""
return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode
@cached_property
def update_can_self_select(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
10,
3,
2,
)
@cached_property
def can_introspect_foreign_keys(self):
"Confirm support for introspected foreign keys"
return self._mysql_storage_engine != "MyISAM"
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"BinaryField": "TextField",
"BooleanField": "IntegerField",
"DurationField": "BigIntegerField",
"GenericIPAddressField": "CharField",
}
@cached_property
def can_return_columns_from_insert(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
10,
5,
0,
)
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)
@cached_property
def has_zoneinfo_database(self):
return self.connection.mysql_server_data["has_zoneinfo_database"]
@cached_property
def is_sql_auto_is_null_enabled(self):
return self.connection.mysql_server_data["sql_auto_is_null"]
@cached_property
def supports_over_clause(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 2)
supports_frame_range_fixed_distance = property(
operator.attrgetter("supports_over_clause")
)
@cached_property
def supports_column_check_constraints(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 16)
supports_table_check_constraints = property(
operator.attrgetter("supports_column_check_constraints")
)
@cached_property
def can_introspect_check_constraints(self):
if self.connection.mysql_is_mariadb:
version = self.connection.mysql_version
return version >= (10, 3, 10)
return self.connection.mysql_version >= (8, 0, 16)
@cached_property
def has_select_for_update_skip_locked(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 6)
return self.connection.mysql_version >= (8, 0, 1)
@cached_property
def has_select_for_update_nowait(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 1)
@cached_property
def has_select_for_update_of(self):
return (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version >= (8, 0, 1)
)
@cached_property
def supports_explain_analyze(self):
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (
8,
0,
18,
)
@cached_property
def supported_explain_formats(self):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
# backends.
formats = {"JSON", "TEXT", "TRADITIONAL"}
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
8,
0,
16,
):
formats.add("TREE")
return formats
@cached_property
def supports_transactions(self):
"""
All storage engines except MyISAM support transactions.
"""
return self._mysql_storage_engine != "MyISAM"
uses_savepoints = property(operator.attrgetter("supports_transactions"))
can_release_savepoints = property(operator.attrgetter("supports_transactions"))
@cached_property
def ignores_table_name_case(self):
return self.connection.mysql_server_data["lower_case_table_names"]
@cached_property
def supports_default_in_lead_lag(self):
# To be added in https://jira.mariadb.org/browse/MDEV-12981.
return not self.connection.mysql_is_mariadb
@cached_property
def supports_json_field(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (5, 7, 8)
@cached_property
def can_introspect_json_field(self):
if self.connection.mysql_is_mariadb:
return self.supports_json_field and self.can_introspect_check_constraints
return self.supports_json_field
@cached_property
def supports_index_column_ordering(self):
if self._mysql_storage_engine != "InnoDB":
return False
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 8)
return self.connection.mysql_version >= (8, 0, 1)
@cached_property
def supports_expression_indexes(self):
return (
not self.connection.mysql_is_mariadb
and self._mysql_storage_engine != "MyISAM"
and self.connection.mysql_version >= (8, 0, 13)
)
@cached_property
def can_rename_index(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 5, 2)
return True

View File

@@ -0,0 +1,335 @@
from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple(
"FieldInfo", BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint")
)
InfoLine = namedtuple(
"InfoLine",
"col_name data_type max_len num_prec num_scale extra column_default "
"collation is_unsigned",
)
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FIELD_TYPE.BLOB: "TextField",
FIELD_TYPE.CHAR: "CharField",
FIELD_TYPE.DECIMAL: "DecimalField",
FIELD_TYPE.NEWDECIMAL: "DecimalField",
FIELD_TYPE.DATE: "DateField",
FIELD_TYPE.DATETIME: "DateTimeField",
FIELD_TYPE.DOUBLE: "FloatField",
FIELD_TYPE.FLOAT: "FloatField",
FIELD_TYPE.INT24: "IntegerField",
FIELD_TYPE.JSON: "JSONField",
FIELD_TYPE.LONG: "IntegerField",
FIELD_TYPE.LONGLONG: "BigIntegerField",
FIELD_TYPE.SHORT: "SmallIntegerField",
FIELD_TYPE.STRING: "CharField",
FIELD_TYPE.TIME: "TimeField",
FIELD_TYPE.TIMESTAMP: "DateTimeField",
FIELD_TYPE.TINY: "IntegerField",
FIELD_TYPE.TINY_BLOB: "TextField",
FIELD_TYPE.MEDIUM_BLOB: "TextField",
FIELD_TYPE.LONG_BLOB: "TextField",
FIELD_TYPE.VAR_STRING: "CharField",
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if "auto_increment" in description.extra:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
if description.is_unsigned:
if field_type == "BigIntegerField":
return "PositiveBigIntegerField"
elif field_type == "IntegerField":
return "PositiveIntegerField"
elif field_type == "SmallIntegerField":
return "PositiveSmallIntegerField"
# JSON data type is an alias for LONGTEXT in MariaDB, use check
# constraints clauses to introspect JSONField.
if description.has_json_constraint:
return "JSONField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("SHOW FULL TABLES")
return [
TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]))
for row in cursor.fetchall()
]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface."
"""
json_constraints = {}
if (
self.connection.mysql_is_mariadb
and self.connection.features.can_introspect_json_field
):
# JSON data type is an alias for LONGTEXT in MariaDB, select
# JSON_VALID() constraints to introspect JSONField.
cursor.execute(
"""
SELECT c.constraint_name AS column_name
FROM information_schema.check_constraints AS c
WHERE
c.table_name = %s AND
LOWER(c.check_clause) =
'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
c.constraint_schema = DATABASE()
""",
[table_name],
)
json_constraints = {row[0] for row in cursor.fetchall()}
# A default collation for the given table.
cursor.execute(
"""
SELECT table_collation
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
""",
[table_name],
)
row = cursor.fetchone()
default_column_collation = row[0] if row else ""
# information_schema database gives more accurate results for some figures:
# - varchar length returned by cursor.description is an internal length,
# not visible length (#5725)
# - precision and scale (for decimal fields) (#5014)
# - auto_increment is not available in cursor.description
cursor.execute(
"""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
CASE
WHEN collation_name = %s THEN NULL
ELSE collation_name
END AS collation_name,
CASE
WHEN column_type LIKE '%% unsigned' THEN 1
ELSE 0
END AS is_unsigned
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()
""",
[default_column_collation, table_name],
)
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
cursor.execute(
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
)
def to_int(i):
return int(i) if i is not None else i
fields = []
for line in cursor.description:
info = field_info[line[0]]
fields.append(
FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.collation,
info.extra,
info.is_unsigned,
line[0] in json_constraints,
)
)
return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if "auto_increment" in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{"table": table_name, "column": field_info.name}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all foreign keys in the given table.
"""
cursor.execute(
"""
SELECT column_name, referenced_column_name, referenced_table_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL
""",
[table_name],
)
return {
field_name: (other_field, other_table)
for field_name, other_field, other_table in cursor.fetchall()
}
def get_storage_engine(self, cursor, table_name):
"""
Retrieve the storage engine for a given table. Return the default
storage engine if the table doesn't exist.
"""
cursor.execute(
"""
SELECT engine
FROM information_schema.tables
WHERE
table_name = %s AND
table_schema = DATABASE()
""",
[table_name],
)
result = cursor.fetchone()
if not result:
return self.connection.features._mysql_storage_engine
return result[0]
def _parse_constraint_columns(self, check_clause, columns):
check_columns = OrderedSet()
statement = sqlparse.parse(check_clause)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if (
token.ttype == sqlparse.tokens.Name
and self.connection.ops.quote_name(token.value) == token.value
and token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1])
return check_columns
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Get the actual constraint names and columns
name_query = """
SELECT kc.`constraint_name`, kc.`column_name`,
kc.`referenced_table_name`, kc.`referenced_column_name`,
c.`constraint_type`
FROM
information_schema.key_column_usage AS kc,
information_schema.table_constraints AS c
WHERE
kc.table_schema = DATABASE() AND
c.table_schema = kc.table_schema AND
c.constraint_name = kc.constraint_name AND
c.constraint_type != 'CHECK' AND
kc.table_name = %s
ORDER BY kc.`ordinal_position`
"""
cursor.execute(name_query, [table_name])
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
"columns": OrderedSet(),
"primary_key": kind == "PRIMARY KEY",
"unique": kind in {"PRIMARY KEY", "UNIQUE"},
"index": False,
"check": False,
"foreign_key": (ref_table, ref_column) if ref_column else None,
}
if self.connection.features.supports_index_column_ordering:
constraints[constraint]["orders"] = []
constraints[constraint]["columns"].add(column)
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
columns = {
info.name for info in self.get_table_description(cursor, table_name)
}
if self.connection.mysql_is_mariadb:
type_query = """
SELECT c.constraint_name, c.check_clause
FROM information_schema.check_constraints AS c
WHERE
c.constraint_schema = DATABASE() AND
c.table_name = %s
"""
else:
type_query = """
SELECT cc.constraint_name, cc.check_clause
FROM
information_schema.check_constraints AS cc,
information_schema.table_constraints AS tc
WHERE
cc.constraint_schema = DATABASE() AND
tc.table_schema = cc.constraint_schema AND
cc.constraint_name = tc.constraint_name AND
tc.constraint_type = 'CHECK' AND
tc.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
constraint_columns = self._parse_constraint_columns(
check_clause, columns
)
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
constraints[constraint] = {
"columns": constraint_columns,
"primary_key": False,
"unique": False,
"index": False,
"check": True,
"foreign_key": None,
}
# Now add in the indexes
cursor.execute(
"SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
)
for table, non_unique, index, colseq, column, order, type_ in [
x[:6] + (x[10],) for x in cursor.fetchall()
]:
if index not in constraints:
constraints[index] = {
"columns": OrderedSet(),
"primary_key": False,
"unique": not non_unique,
"check": False,
"foreign_key": None,
}
if self.connection.features.supports_index_column_ordering:
constraints[index]["orders"] = []
constraints[index]["index"] = True
constraints[index]["type"] = (
Index.suffix if type_ == "BTREE" else type_.lower()
)
constraints[index]["columns"].add(column)
if self.connection.features.supports_index_column_ordering:
constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint["columns"] = list(constraint["columns"])
return constraints

View File

@@ -0,0 +1,467 @@
import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.mysql.compiler"
# MySQL stores positive fields as UNSIGNED ints.
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
"PositiveSmallIntegerField": (0, 65535),
"PositiveIntegerField": (0, 4294967295),
"PositiveBigIntegerField": (0, 18446744073709551615),
}
cast_data_types = {
"AutoField": "signed integer",
"BigAutoField": "signed integer",
"SmallAutoField": "signed integer",
"CharField": "char(%(max_length)s)",
"DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
"TextField": "char",
"IntegerField": "signed integer",
"BigIntegerField": "signed integer",
"SmallIntegerField": "signed integer",
"PositiveBigIntegerField": "unsigned integer",
"PositiveIntegerField": "unsigned integer",
"PositiveSmallIntegerField": "unsigned integer",
"DurationField": "signed integer",
}
cast_char_field_without_max_length = "char"
explain_prefix = "EXPLAIN"
# EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == "week_day":
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return f"DAYOFWEEK({sql})", params
elif lookup_type == "iso_week_day":
# WEEKDAY() returns an integer, 0-6, Monday=0.
return f"WEEKDAY({sql}) + 1", params
elif lookup_type == "week":
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return f"WEEK({sql}, 3)", params
elif lookup_type == "iso_year":
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
else:
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
return f"EXTRACT({lookup_type} FROM {sql})", params
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
fields = {
"year": "%Y-01-01",
"month": "%Y-%m-01",
}
if lookup_type in fields:
format_str = fields[lookup_type]
return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
elif lookup_type == "quarter":
return (
f"MAKEDATE(YEAR({sql}), 1) + "
f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
(*params, *params),
)
elif lookup_type == "week":
return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
else:
return f"DATE({sql})", params
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname
def _convert_sql_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
return f"CONVERT_TZ({sql}, %s, %s)", (
*params,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return sql, params
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"DATE({sql})", params
def datetime_cast_time_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"TIME({sql})", params
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
fields = ["year", "month", "day", "hour", "minute", "second"]
format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
if lookup_type == "quarter":
return (
f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
f"INTERVAL QUARTER({sql}) QUARTER - "
f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
), (*params, *params, "%Y-%m-01 00:00:00")
if lookup_type == "week":
return (
f"CAST(DATE_FORMAT("
f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
), (*params, *params, "%Y-%m-%d 00:00:00")
try:
i = fields.index(lookup_type) + 1
except ValueError:
pass
else:
format_str = "".join(format[:i] + format_def[i:])
return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
return sql, params
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
fields = {
"hour": "%H:00:00",
"minute": "%H:%i:00",
"second": "%H:%i:%s",
}
if lookup_type in fields:
format_str = fields[lookup_type]
return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
else:
return f"TIME({sql})", params
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
return "INTERVAL %s MICROSECOND" % sql
def force_no_ordering(self):
"""
"ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
columns. If no ordering would otherwise be applied, we don't want any
implicit sorting going on.
"""
return [(None, ("NULL", [], False))]
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
return value
def last_executed_query(self, cursor, sql, params):
# With MySQLdb, cursor objects have an (undocumented) "_executed"
# attribute where the exact query sent to the database is saved.
# See MySQLdb/cursors.py in the source distribution.
# MySQLdb returns string, PyMySQL bytes.
return force_str(getattr(cursor, "_executed", None), errors="replace")
def no_limit_value(self):
# 2**64 - 1, as recommended by the MySQL documentation
return 18446744073709551615
def quote_name(self, name):
if name.startswith("`") and name.endswith("`"):
return name # Quoting once is enough.
return "`%s`" % name
def return_insert_columns(self, fields):
# MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
# statement.
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
if reset_sequences:
# It's faster to TRUNCATE tables that require a sequence reset
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
sql.extend(
"%s %s;"
% (
style.SQL_KEYWORD("TRUNCATE"),
style.SQL_FIELD(self.quote_name(table_name)),
)
for table_name in tables
)
else:
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
# and preserves sequences.
sql.extend(
"%s %s %s;"
% (
style.SQL_KEYWORD("DELETE"),
style.SQL_KEYWORD("FROM"),
style.SQL_FIELD(self.quote_name(table_name)),
)
for table_name in tables
)
sql.append("SET FOREIGN_KEY_CHECKS = 1;")
return sql
def sequence_reset_by_name_sql(self, style, sequences):
return [
"%s %s %s %s = 1;"
% (
style.SQL_KEYWORD("ALTER"),
style.SQL_KEYWORD("TABLE"),
style.SQL_FIELD(self.quote_name(sequence_info["table"])),
style.SQL_FIELD("AUTO_INCREMENT"),
)
for sequence_info in sequences
]
def validate_autopk_value(self, value):
# Zero in AUTO_INCREMENT field does not work without the
# NO_AUTO_VALUE_ON_ZERO SQL mode.
if value == 0 and not self.connection.features.allows_auto_pk_0:
raise ValueError(
"The database backend does not accept 0 as a value for AutoField."
)
return value
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
# MySQL doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"MySQL backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
return str(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")
return value.isoformat(timespec="microseconds")
def max_name_length(self):
return 64
def pk_default_value(self):
return "NULL"
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
if connector == "^":
return "POW(%s)" % ",".join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
# return an unsigned integer.
elif connector in ("&", "|", "<<", "#"):
connector = "^" if connector == "#" else connector
return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
elif connector == ">>":
lhs, rhs = sub_expressions
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
return super().combine_expression(connector, sub_expressions)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
elif internal_type == "DateTimeField":
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
return converters
def convert_booleanfield_value(self, value, expression, connection):
if value in (0, 1):
value = bool(value)
return value
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def binary_placeholder_sql(self, value):
return (
"_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
)
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
if internal_type == "TimeField":
if self.connection.mysql_is_mariadb:
# MariaDB includes the microsecond component in TIME_TO_SEC as
# a decimal. MySQL returns an integer without microseconds.
return (
"CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) "
"* 1000000 AS SIGNED)"
) % {
"lhs": lhs_sql,
"rhs": rhs_sql,
}, (
*lhs_params,
*rhs_params,
)
return (
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
rhs_params
) * 2
params = (*rhs_params, *lhs_params)
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
def explain_query_prefix(self, format=None, **options):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
if format and format.upper() == "TEXT":
format = "TRADITIONAL"
elif (
not format and "TREE" in self.connection.features.supported_explain_formats
):
# Use TREE by default (if supported) as it's more informative.
format = "TREE"
analyze = options.pop("analyze", False)
prefix = super().explain_query_prefix(format, **options)
if analyze and self.connection.features.supports_explain_analyze:
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
prefix = (
"ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
)
if format and not (analyze and not self.connection.mysql_is_mariadb):
# Only MariaDB supports the analyze option with formats.
prefix += " FORMAT=%s" % format
return prefix
def regex_lookup(self, lookup_type):
# REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
# doesn't exist in MySQL 5.x or in MariaDB.
if (
self.connection.mysql_version < (8, 0, 0)
or self.connection.mysql_is_mariadb
):
if lookup_type == "regex":
return "%s REGEXP BINARY %s"
return "%s REGEXP %s"
match_option = "c" if lookup_type == "regex" else "i"
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
return "INSERT IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
if internal_type == "JSONField":
if self.connection.mysql_is_mariadb or lookup_type in (
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"regex",
"iregex",
):
lookup = "JSON_UNQUOTE(%s)"
return lookup
def conditional_expression_supported_in_where_clause(self, expression):
# MySQL ignores indexes with boolean fields unless they're compared
# directly to a boolean value.
if isinstance(expression, (Exists, Lookup)):
return True
if isinstance(expression, ExpressionWrapper) and expression.conditional:
return self.conditional_expression_supported_in_where_clause(
expression.expression
)
if getattr(expression, "conditional", False):
return False
return super().conditional_expression_supported_in_where_clause(expression)
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.UPDATE:
conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
field_sql = "%(field)s = VALUES(%(field)s)"
# The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
# aliases for the new row and its columns available in MySQL
# 8.0.19+.
if not self.connection.mysql_is_mariadb:
if self.connection.mysql_version >= (8, 0, 19):
conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
field_sql = "%(field)s = new.%(field)s"
# VALUES() was renamed to VALUE() in MariaDB 10.3.3+.
elif self.connection.mysql_version >= (10, 3, 3):
field_sql = "%(field)s = VALUE(%(field)s)"
fields = ", ".join(
[
field_sql % {"field": field}
for field in map(self.quote_name, update_fields)
]
)
return conflict_suffix_sql % {"fields": fields}
return super().on_conflict_suffix_sql(
fields,
on_conflict,
update_fields,
unique_fields,
)

View File

@@ -0,0 +1,175 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import NOT_PROVIDED
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"
sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL"
# No 'CASCADE' which works as a no-op in MySQL but is undocumented
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
sql_create_column_inline_fk = (
", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
"REFERENCES %(to_table)s(%(to_column)s)"
)
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
sql_rename_index = "ALTER TABLE %(table)s RENAME INDEX %(old_name)s TO %(new_name)s"
sql_create_pk = (
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
)
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
@property
def sql_delete_check(self):
if self.connection.mysql_is_mariadb:
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding IF EXISTS clause prevents migrations
# crash. Constraint is removed during a "MODIFY" column statement.
return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s"
return "ALTER TABLE %(table)s DROP CHECK %(name)s"
@property
def sql_rename_column(self):
# MariaDB >= 10.5.2 and MySQL >= 8.0.4 support an
# "ALTER TABLE ... RENAME COLUMN" statement.
if self.connection.mysql_is_mariadb:
if self.connection.mysql_version >= (10, 5, 2):
return super().sql_rename_column
elif self.connection.mysql_version >= (8, 0, 4):
return super().sql_rename_column
return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
def quote_value(self, value):
self.connection.ensure_connection()
if isinstance(value, str):
value = value.replace("%", "%%")
# MySQLdb escapes to string, PyMySQL to bytes.
quoted = self.connection.connection.escape(
value, self.connection.connection.encoders
)
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted
def _is_limited_data_type(self, field):
db_type = field.db_type(self.connection)
return (
db_type is not None
and db_type.lower() in self.connection._limited_data_types
)
def skip_default(self, field):
if not self._supports_limited_data_type_defaults:
return self._is_limited_data_type(field)
return False
def skip_default_on_alter(self, field):
if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
# MySQL doesn't support defaults for BLOB and TEXT in the
# ALTER COLUMN statement.
return True
return False
@property
def _supports_limited_data_type_defaults(self):
# MariaDB and MySQL >= 8.0.13 support defaults for BLOB and TEXT.
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 13)
def _column_default_sql(self, field):
if (
not self.connection.mysql_is_mariadb
and self._supports_limited_data_type_defaults
and self._is_limited_data_type(field)
):
# MySQL supports defaults for BLOB and TEXT columns only if the
# default value is written as an expression i.e. in parentheses.
return "(%s)"
return super()._column_default_sql(field)
def add_field(self, model, field):
super().add_field(model, field)
# Simulate the effect of a one-off default.
# field.default may be unhashable, so a set isn't used for "in" check.
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
effective_default = self.effective_default(field)
self.execute(
"UPDATE %(table)s SET %(column)s = %%s"
% {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
},
[effective_default],
)
def _field_should_be_indexed(self, model, field):
if not super()._field_should_be_indexed(model, field):
return False
storage = self.connection.introspection.get_storage_engine(
self.connection.cursor(), model._meta.db_table
)
# No need to create an index for ForeignKey fields except if
# db_constraint=False because the index from that constraint won't be
# created.
if (
storage == "InnoDB"
and field.get_internal_type() == "ForeignKey"
and field.db_constraint
):
return False
return not self._is_limited_data_type(field)
def _delete_composed_index(self, model, fields, *args):
"""
MySQL can remove an implicit FK index on a field when that field is
covered by another index like a unique_together. "covered" here means
that the more complex index starts like the simpler one.
https://bugs.mysql.com/bug.php?id=37910 / Django ticket #24757
We check here before removing the [unique|index]_together if we have to
recreate a FK index.
"""
first_field = model._meta.get_field(fields[0])
if first_field.get_internal_type() == "ForeignKey":
constraint_names = self._constraint_names(
model, [first_field.column], index=True
)
if not constraint_names:
self.execute(
self._create_index_sql(model, fields=[first_field], suffix="")
)
return super()._delete_composed_index(model, fields, *args)
def _set_field_new_type_null_status(self, field, new_type):
"""
Keep the null property of the old field. If it has changed, it will be
handled separately.
"""
if field.null:
new_type += " NULL"
else:
new_type += " NOT NULL"
return new_type
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _rename_field_sql(self, table, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._rename_field_sql(table, old_field, new_field, new_type)

View File

@@ -0,0 +1,77 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
from django.utils.version import get_docs_version
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
issues = super().check(**kwargs)
issues.extend(self._check_sql_mode(**kwargs))
return issues
def _check_sql_mode(self, **kwargs):
if not (
self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
):
return [
checks.Warning(
"%s Strict Mode is not set for database connection '%s'"
% (self.connection.display_name, self.connection.alias),
hint=(
"%s's Strict Mode fixes many data integrity problems in "
"%s, such as data truncation upon insertion, by "
"escalating warnings into errors. It is strongly "
"recommended you activate it. See: "
"https://docs.djangoproject.com/en/%s/ref/databases/"
"#mysql-sql-mode"
% (
self.connection.display_name,
self.connection.display_name,
get_docs_version(),
),
),
id="mysql.W002",
)
]
return []
def check_field_type(self, field, field_type):
"""
MySQL has the following field length restriction:
No character (varchar) fields can have a length exceeding 255
characters if they have a unique index on them.
MySQL doesn't support a database index on some data types.
"""
errors = []
if (
field_type.startswith("varchar")
and field.unique
and (field.max_length is None or int(field.max_length) > 255)
):
errors.append(
checks.Warning(
"%s may not allow unique CharFields to have a max_length "
"> 255." % self.connection.display_name,
obj=field,
hint=(
"See: https://docs.djangoproject.com/en/%s/ref/"
"databases/#mysql-character-fields" % get_docs_version()
),
id="mysql.W003",
)
)
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
"%s does not support a database index on %s columns."
% (self.connection.display_name, field_type),
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id="fields.W162",
)
)
return errors