before send to remote

This commit is contained in:
2022-08-26 16:41:18 +06:00
commit 3814beb3e0
5408 changed files with 652023 additions and 0 deletions

View File

@@ -0,0 +1,512 @@
"""
Implementations of SQL functions for SQLite.
"""
import functools
import random
import statistics
from datetime import timedelta
from hashlib import sha1, sha224, sha256, sha384, sha512
from math import (
acos,
asin,
atan,
atan2,
ceil,
cos,
degrees,
exp,
floor,
fmod,
log,
pi,
radians,
sin,
sqrt,
tan,
)
from re import search as re_search
from django.db.backends.base.base import timezone_constructor
from django.db.backends.utils import (
split_tzname_delta,
typecast_time,
typecast_timestamp,
)
from django.utils import timezone
from django.utils.crypto import md5
from django.utils.duration import duration_microseconds
def register(connection):
create_deterministic_function = functools.partial(
connection.create_function,
deterministic=True,
)
create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract)
create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
create_deterministic_function(
"django_datetime_cast_date", 3, _sqlite_datetime_cast_date
)
create_deterministic_function(
"django_datetime_cast_time", 3, _sqlite_datetime_cast_time
)
create_deterministic_function(
"django_datetime_extract", 4, _sqlite_datetime_extract
)
create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc)
create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
create_deterministic_function("regexp", 2, _sqlite_regexp)
create_deterministic_function("BITXOR", 2, _sqlite_bitxor)
create_deterministic_function("COT", 1, _sqlite_cot)
create_deterministic_function("LPAD", 3, _sqlite_lpad)
create_deterministic_function("MD5", 1, _sqlite_md5)
create_deterministic_function("REPEAT", 2, _sqlite_repeat)
create_deterministic_function("REVERSE", 1, _sqlite_reverse)
create_deterministic_function("RPAD", 3, _sqlite_rpad)
create_deterministic_function("SHA1", 1, _sqlite_sha1)
create_deterministic_function("SHA224", 1, _sqlite_sha224)
create_deterministic_function("SHA256", 1, _sqlite_sha256)
create_deterministic_function("SHA384", 1, _sqlite_sha384)
create_deterministic_function("SHA512", 1, _sqlite_sha512)
create_deterministic_function("SIGN", 1, _sqlite_sign)
# Don't use the built-in RANDOM() function because it returns a value
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
connection.create_function("RAND", 0, random.random)
connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp)
connection.create_aggregate("VAR_POP", 1, VarPop)
connection.create_aggregate("VAR_SAMP", 1, VarSamp)
# Some math functions are enabled by default in SQLite 3.35+.
sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
if not connection.execute(sql).fetchone()[0]:
create_deterministic_function("ACOS", 1, _sqlite_acos)
create_deterministic_function("ASIN", 1, _sqlite_asin)
create_deterministic_function("ATAN", 1, _sqlite_atan)
create_deterministic_function("ATAN2", 2, _sqlite_atan2)
create_deterministic_function("CEILING", 1, _sqlite_ceiling)
create_deterministic_function("COS", 1, _sqlite_cos)
create_deterministic_function("DEGREES", 1, _sqlite_degrees)
create_deterministic_function("EXP", 1, _sqlite_exp)
create_deterministic_function("FLOOR", 1, _sqlite_floor)
create_deterministic_function("LN", 1, _sqlite_ln)
create_deterministic_function("LOG", 2, _sqlite_log)
create_deterministic_function("MOD", 2, _sqlite_mod)
create_deterministic_function("PI", 0, _sqlite_pi)
create_deterministic_function("POWER", 2, _sqlite_power)
create_deterministic_function("RADIANS", 1, _sqlite_radians)
create_deterministic_function("SIN", 1, _sqlite_sin)
create_deterministic_function("SQRT", 1, _sqlite_sqrt)
create_deterministic_function("TAN", 1, _sqlite_tan)
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
if dt is None:
return None
try:
dt = typecast_timestamp(dt)
except (TypeError, ValueError):
return None
if conn_tzname:
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
if tzname is not None and tzname != conn_tzname:
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
hours, minutes = offset.split(":")
offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
dt += offset_delta if sign == "+" else -offset_delta
dt = timezone.localtime(dt, timezone_constructor(tzname))
return dt
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "year":
return f"{dt.year:04d}-01-01"
elif lookup_type == "quarter":
month_in_quarter = dt.month - (dt.month - 1) % 3
return f"{dt.year:04d}-{month_in_quarter:02d}-01"
elif lookup_type == "month":
return f"{dt.year:04d}-{dt.month:02d}-01"
elif lookup_type == "week":
dt = dt - timedelta(days=dt.weekday())
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
elif lookup_type == "day":
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
if dt is None:
return None
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt_parsed is None:
try:
dt = typecast_time(dt)
except (ValueError, TypeError):
return None
else:
dt = dt_parsed
if lookup_type == "hour":
return f"{dt.hour:02d}:00:00"
elif lookup_type == "minute":
return f"{dt.hour:02d}:{dt.minute:02d}:00"
elif lookup_type == "second":
return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
return dt.date().isoformat()
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
return dt.time().isoformat()
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "week_day":
return (dt.isoweekday() % 7) + 1
elif lookup_type == "iso_week_day":
return dt.isoweekday()
elif lookup_type == "week":
return dt.isocalendar()[1]
elif lookup_type == "quarter":
return ceil(dt.month / 3)
elif lookup_type == "iso_year":
return dt.isocalendar()[0]
else:
return getattr(dt, lookup_type)
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "year":
return f"{dt.year:04d}-01-01 00:00:00"
elif lookup_type == "quarter":
month_in_quarter = dt.month - (dt.month - 1) % 3
return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00"
elif lookup_type == "month":
return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00"
elif lookup_type == "week":
dt = dt - timedelta(days=dt.weekday())
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
elif lookup_type == "day":
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
elif lookup_type == "hour":
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00"
elif lookup_type == "minute":
return (
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
f"{dt.hour:02d}:{dt.minute:02d}:00"
)
elif lookup_type == "second":
return (
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
)
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_time_extract(lookup_type, dt):
if dt is None:
return None
try:
dt = typecast_time(dt)
except (ValueError, TypeError):
return None
return getattr(dt, lookup_type)
def _sqlite_prepare_dtdelta_param(conn, param):
if conn in ["+", "-"]:
if isinstance(param, int):
return timedelta(0, 0, param)
else:
return typecast_timestamp(param)
return param
def _sqlite_format_dtdelta(connector, lhs, rhs):
"""
LHS and RHS can be either:
- An integer number of microseconds
- A string representing a datetime
- A scalar value, e.g. float
"""
if connector is None or lhs is None or rhs is None:
return None
connector = connector.strip()
try:
real_lhs = _sqlite_prepare_dtdelta_param(connector, lhs)
real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs)
except (ValueError, TypeError):
return None
if connector == "+":
# typecast_timestamp() returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
out = str(real_lhs + real_rhs)
elif connector == "-":
out = str(real_lhs - real_rhs)
elif connector == "*":
out = real_lhs * real_rhs
else:
out = real_lhs / real_rhs
return out
def _sqlite_time_diff(lhs, rhs):
if lhs is None or rhs is None:
return None
left = typecast_time(lhs)
right = typecast_time(rhs)
return (
(left.hour * 60 * 60 * 1000000)
+ (left.minute * 60 * 1000000)
+ (left.second * 1000000)
+ (left.microsecond)
- (right.hour * 60 * 60 * 1000000)
- (right.minute * 60 * 1000000)
- (right.second * 1000000)
- (right.microsecond)
)
def _sqlite_timestamp_diff(lhs, rhs):
if lhs is None or rhs is None:
return None
left = typecast_timestamp(lhs)
right = typecast_timestamp(rhs)
return duration_microseconds(left - right)
def _sqlite_regexp(pattern, string):
if pattern is None or string is None:
return None
if not isinstance(string, str):
string = str(string)
return bool(re_search(pattern, string))
def _sqlite_acos(x):
if x is None:
return None
return acos(x)
def _sqlite_asin(x):
if x is None:
return None
return asin(x)
def _sqlite_atan(x):
if x is None:
return None
return atan(x)
def _sqlite_atan2(y, x):
if y is None or x is None:
return None
return atan2(y, x)
def _sqlite_bitxor(x, y):
if x is None or y is None:
return None
return x ^ y
def _sqlite_ceiling(x):
if x is None:
return None
return ceil(x)
def _sqlite_cos(x):
if x is None:
return None
return cos(x)
def _sqlite_cot(x):
if x is None:
return None
return 1 / tan(x)
def _sqlite_degrees(x):
if x is None:
return None
return degrees(x)
def _sqlite_exp(x):
if x is None:
return None
return exp(x)
def _sqlite_floor(x):
if x is None:
return None
return floor(x)
def _sqlite_ln(x):
if x is None:
return None
return log(x)
def _sqlite_log(base, x):
if base is None or x is None:
return None
# Arguments reversed to match SQL standard.
return log(x, base)
def _sqlite_lpad(text, length, fill_text):
if text is None or length is None or fill_text is None:
return None
delta = length - len(text)
if delta <= 0:
return text[:length]
return (fill_text * length)[:delta] + text
def _sqlite_md5(text):
if text is None:
return None
return md5(text.encode()).hexdigest()
def _sqlite_mod(x, y):
if x is None or y is None:
return None
return fmod(x, y)
def _sqlite_pi():
return pi
def _sqlite_power(x, y):
if x is None or y is None:
return None
return x**y
def _sqlite_radians(x):
if x is None:
return None
return radians(x)
def _sqlite_repeat(text, count):
if text is None or count is None:
return None
return text * count
def _sqlite_reverse(text):
if text is None:
return None
return text[::-1]
def _sqlite_rpad(text, length, fill_text):
if text is None or length is None or fill_text is None:
return None
return (text + fill_text * length)[:length]
def _sqlite_sha1(text):
if text is None:
return None
return sha1(text.encode()).hexdigest()
def _sqlite_sha224(text):
if text is None:
return None
return sha224(text.encode()).hexdigest()
def _sqlite_sha256(text):
if text is None:
return None
return sha256(text.encode()).hexdigest()
def _sqlite_sha384(text):
if text is None:
return None
return sha384(text.encode()).hexdigest()
def _sqlite_sha512(text):
if text is None:
return None
return sha512(text.encode()).hexdigest()
def _sqlite_sign(x):
if x is None:
return None
return (x > 0) - (x < 0)
def _sqlite_sin(x):
if x is None:
return None
return sin(x)
def _sqlite_sqrt(x):
if x is None:
return None
return sqrt(x)
def _sqlite_tan(x):
if x is None:
return None
return tan(x)
class ListAggregate(list):
step = list.append
class StdDevPop(ListAggregate):
finalize = statistics.pstdev
class StdDevSamp(ListAggregate):
finalize = statistics.stdev
class VarPop(ListAggregate):
finalize = statistics.pvariance
class VarSamp(ListAggregate):
finalize = statistics.variance

View File

@@ -0,0 +1,364 @@
"""
SQLite backend for the sqlite3 module in the standard library.
"""
import decimal
import warnings
from itertools import chain
from sqlite3 import dbapi2 as Database
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_datetime, parse_time
from django.utils.regex_helper import _lazy_re_compile
from ._functions import register as register_functions
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
def decoder(conv_func):
"""
Convert bytestrings from Python's sqlite3 interface to a regular string.
"""
return lambda s: conv_func(s.decode())
Database.register_converter("bool", b"1".__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))
Database.register_adapter(decimal.Decimal, str)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "sqlite"
display_name = "SQLite"
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
data_types = {
"AutoField": "integer",
"BigAutoField": "integer",
"BinaryField": "BLOB",
"BooleanField": "bool",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "datetime",
"DecimalField": "decimal",
"DurationField": "bigint",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "real",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "char(15)",
"GenericIPAddressField": "char(39)",
"JSONField": "text",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint unsigned",
"PositiveIntegerField": "integer unsigned",
"PositiveSmallIntegerField": "smallint unsigned",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "integer",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
"UUIDField": "char(32)",
}
data_type_check_constraints = {
"PositiveBigIntegerField": '"%(column)s" >= 0',
"JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
}
data_types_suffix = {
"AutoField": "AUTOINCREMENT",
"BigAutoField": "AUTOINCREMENT",
"SmallAutoField": "AUTOINCREMENT",
}
# SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it.
# See https://www.sqlite.org/lang_expr.html for an explanation.
operators = {
"exact": "= %s",
"iexact": "LIKE %s ESCAPE '\\'",
"contains": "LIKE %s ESCAPE '\\'",
"icontains": "LIKE %s ESCAPE '\\'",
"regex": "REGEXP %s",
"iregex": "REGEXP '(?i)' || %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE %s ESCAPE '\\'",
"endswith": "LIKE %s ESCAPE '\\'",
"istartswith": "LIKE %s ESCAPE '\\'",
"iendswith": "LIKE %s ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
"contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
"icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
"startswith": r"LIKE {} || '%%' ESCAPE '\'",
"istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
"endswith": r"LIKE '%%' || {} ESCAPE '\'",
"iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
def get_connection_params(self):
settings_dict = self.settings_dict
if not settings_dict["NAME"]:
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value."
)
kwargs = {
"database": settings_dict["NAME"],
"detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
**settings_dict["OPTIONS"],
}
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
# higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
warnings.warn(
"The `check_same_thread` option was provided and set to "
"True. It will be overridden with False. Use the "
"`DatabaseWrapper.allow_thread_sharing` property instead "
"for controlling thread shareability.",
RuntimeWarning,
)
kwargs.update({"check_same_thread": False, "uri": True})
return kwargs
def get_database_version(self):
return self.Database.sqlite_version_info
@async_unsafe
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
register_functions(conn)
conn.execute("PRAGMA foreign_keys = ON")
# The macOS bundled SQLite defaults legacy_alter_table ON, which
# prevents atomic table renames (feature supports_atomic_references_rename)
conn.execute("PRAGMA legacy_alter_table = OFF")
return conn
def create_cursor(self, name=None):
return self.connection.cursor(factory=SQLiteCursorWrapper)
@async_unsafe
def close(self):
self.validate_thread_sharing()
# If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on
# an in-memory db.
if not self.is_in_memory_db():
BaseDatabaseWrapper.close(self)
def _savepoint_allowed(self):
# When 'isolation_level' is not None, sqlite3 commits before each
# savepoint; it's a bug. When it is None, savepoints don't make sense
# because autocommit is enabled. The only exception is inside 'atomic'
# blocks. To work around that bug, on SQLite, 'atomic' starts a
# transaction explicitly rather than simply disable autocommit.
return self.in_atomic_block
def _set_autocommit(self, autocommit):
if autocommit:
level = None
else:
# sqlite3's internal default is ''. It's different from None.
# See Modules/_sqlite/connection.c.
level = ""
# 'isolation_level' is a misleading API.
# SQLite always runs at the SERIALIZABLE isolation level.
with self.wrap_database_errors:
self.connection.isolation_level = level
def disable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute("PRAGMA foreign_keys = OFF")
# Foreign key constraints cannot be turned off while in a multi-
# statement transaction. Fetch the current state of the pragma
# to determine if constraints are effectively disabled.
enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
return not bool(enabled)
def enable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute("PRAGMA foreign_keys = ON")
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor:
if table_names is None:
violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
else:
violations = chain.from_iterable(
cursor.execute(
"PRAGMA foreign_key_check(%s)"
% self.ops.quote_name(table_name)
).fetchall()
for table_name in table_names
)
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
for (
table_name,
rowid,
referenced_table_name,
foreign_key_index,
) in violations:
foreign_key = cursor.execute(
"PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
).fetchall()[foreign_key_index]
column_name, referenced_column_name = foreign_key[3:5]
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
primary_key_value, bad_value = cursor.execute(
"SELECT %s, %s FROM %s WHERE rowid = %%s"
% (
self.ops.quote_name(primary_key_column_name),
self.ops.quote_name(column_name),
self.ops.quote_name(table_name),
),
(rowid,),
).fetchone()
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
primary_key_value,
table_name,
column_name,
bad_value,
referenced_table_name,
referenced_column_name,
)
)
else:
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
if not primary_key_column_name:
continue
relations = self.introspection.get_relations(cursor, table_name)
for column_name, (
referenced_column_name,
referenced_table_name,
) in relations:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name,
column_name,
table_name,
referenced_table_name,
column_name,
referenced_column_name,
column_name,
referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
bad_row[0],
table_name,
column_name,
bad_row[1],
referenced_table_name,
referenced_column_name,
)
)
def is_usable(self):
return True
def _start_transaction_under_autocommit(self):
"""
Start a transaction explicitly in autocommit mode.
Staying in autocommit mode works around a bug of sqlite3 that breaks
savepoints when autocommit is disabled.
"""
self.cursor().execute("BEGIN")
def is_in_memory_db(self):
return self.creation.is_in_memory_db(self.settings_dict["NAME"])
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
class SQLiteCursorWrapper(Database.Cursor):
"""
Django uses "format" style placeholders, but pysqlite2 uses "qmark" style.
This fixes it -- but note that if you want to use a literal "%s" in a query,
you'll need to use "%%s".
"""
def execute(self, query, params=None):
if params is None:
return Database.Cursor.execute(self, query)
query = self.convert_query(query)
return Database.Cursor.execute(self, query, params)
def executemany(self, query, param_list):
query = self.convert_query(query)
return Database.Cursor.executemany(self, query, param_list)
def convert_query(self, query):
return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")

View File

@@ -0,0 +1,10 @@
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "sqlite3"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, settings_dict["NAME"], *parameters]
return args, None

View File

@@ -0,0 +1,158 @@
import multiprocessing
import os
import shutil
import sqlite3
import sys
from pathlib import Path
from django.db import NotSupportedError
from django.db.backends.base.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
@staticmethod
def is_in_memory_db(database_name):
return not isinstance(database_name, Path) and (
database_name == ":memory:" or "mode=memory" in database_name
)
def _get_test_db_name(self):
test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
if test_database_name == ":memory:":
return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
return test_database_name
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
test_database_name = self._get_test_db_name()
if keepdb:
return test_database_name
if not self.is_in_memory_db(test_database_name):
# Erase the old test database
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (self._get_database_display_str(verbosity, test_database_name),)
)
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name
)
if autoclobber or confirm == "yes":
try:
os.remove(test_database_name)
except Exception as e:
self.log("Got an error deleting the old test database: %s" % e)
sys.exit(2)
else:
self.log("Tests cancelled.")
sys.exit(1)
return test_database_name
def get_test_db_clone_settings(self, suffix):
orig_settings_dict = self.connection.settings_dict
source_database_name = orig_settings_dict["NAME"]
if not self.is_in_memory_db(source_database_name):
root, ext = os.path.splitext(source_database_name)
return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"}
start_method = multiprocessing.get_start_method()
if start_method == "fork":
return orig_settings_dict
if start_method == "spawn":
return {
**orig_settings_dict,
"NAME": f"{self.connection.alias}_{suffix}.sqlite3",
}
raise NotSupportedError(
f"Cloning with start method {start_method!r} is not supported."
)
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
if not self.is_in_memory_db(source_database_name):
# Erase the old test database
if os.access(target_database_name, os.F_OK):
if keepdb:
return
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
try:
os.remove(target_database_name)
except Exception as e:
self.log("Got an error deleting the old test database: %s" % e)
sys.exit(2)
try:
shutil.copy(source_database_name, target_database_name)
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
sys.exit(2)
# Forking automatically makes a copy of an in-memory database.
# Spawn requires migrating to disk which will be re-opened in
# setup_worker_connection.
elif multiprocessing.get_start_method() == "spawn":
ondisk_db = sqlite3.connect(target_database_name, uri=True)
self.connection.connection.backup(ondisk_db)
def _destroy_test_db(self, test_database_name, verbosity):
if test_database_name and not self.is_in_memory_db(test_database_name):
# Remove the SQLite database file
os.remove(test_database_name)
def test_db_signature(self):
"""
Return a tuple that uniquely identifies a test database.
This takes into account the special cases of ":memory:" and "" for
SQLite since the databases will be distinct despite having the same
TEST NAME. See https://www.sqlite.org/inmemorydb.html
"""
test_database_name = self._get_test_db_name()
sig = [self.connection.settings_dict["NAME"]]
if self.is_in_memory_db(test_database_name):
sig.append(self.connection.alias)
else:
sig.append(test_database_name)
return tuple(sig)
def setup_worker_connection(self, _worker_id):
settings_dict = self.get_test_db_clone_settings(_worker_id)
# connection.settings_dict must be updated in place for changes to be
# reflected in django.db.connections. Otherwise new threads would
# connect to the default database instead of the appropriate clone.
start_method = multiprocessing.get_start_method()
if start_method == "fork":
# Update settings_dict in place.
self.connection.settings_dict.update(settings_dict)
self.connection.close()
elif start_method == "spawn":
alias = self.connection.alias
connection_str = (
f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
)
source_db = self.connection.Database.connect(
f"file:{alias}_{_worker_id}.sqlite3", uri=True
)
target_db = sqlite3.connect(connection_str, uri=True)
source_db.backup(target_db)
source_db.close()
# Update settings_dict in place.
self.connection.settings_dict.update(settings_dict)
self.connection.settings_dict["NAME"] = connection_str
# Re-open connection to in-memory database before closing copy
# connection.
self.connection.connect()
target_db.close()
if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
self.mark_expected_failures_and_skips()

View File

@@ -0,0 +1,141 @@
import operator
from django.db import transaction
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import OperationalError
from django.utils.functional import cached_property
from .base import Database
class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (3, 9)
test_db_allows_multiple_connections = False
supports_unspecified_pk = True
supports_timezones = False
max_query_params = 999
supports_transactions = True
atomic_transactions = False
can_rollback_ddl = True
can_create_inline_fk = False
supports_paramstyle_pyformat = False
requires_literal_defaults = True
can_clone_databases = True
supports_temporal_subtraction = True
ignores_table_name_case = True
supports_cast_with_precision = False
time_cast_precision = 3
can_release_savepoints = True
has_case_insensitive_like = True
# Is "ALTER TABLE ... RENAME COLUMN" supported?
can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0)
# Is "ALTER TABLE ... DROP COLUMN" supported?
can_alter_table_drop_column = Database.sqlite_version_info >= (3, 35, 5)
supports_parentheses_in_compound = False
# Deferred constraint checks can be emulated on SQLite < 3.20 but not in a
# reasonably performant way.
supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0)
can_defer_constraint_checks = supports_pragma_foreign_key_check
supports_functions_in_partial_indexes = Database.sqlite_version_info >= (3, 15, 0)
supports_over_clause = Database.sqlite_version_info >= (3, 25, 0)
supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
order_by_nulls_first = True
supports_json_field_contains = False
supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
supports_update_conflicts_with_target = supports_update_conflicts
test_collations = {
"ci": "nocase",
"cs": "binary",
"non_default": "nocase",
}
django_test_expected_failures = {
# The django_format_dtdelta() function doesn't properly handle mixed
# Date/DateTime fields and timedeltas.
"expressions.tests.FTimeDeltaTests.test_mixed_comparisons1",
}
@cached_property
def django_test_skips(self):
skips = {
"SQLite stores values rounded to 15 significant digits.": {
"model_fields.test_decimalfield.DecimalFieldTests."
"test_fetch_from_db_without_float_rounding",
},
"SQLite naively remakes the table on field alteration.": {
"schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
"schema.tests.SchemaTests.test_unique_and_reverse_m2m",
"schema.tests.SchemaTests."
"test_alter_field_default_doesnt_perform_queries",
"schema.tests.SchemaTests."
"test_rename_column_renames_deferred_sql_references",
},
"SQLite doesn't support negative precision for ROUND().": {
"db_functions.math.test_round.RoundTests."
"test_null_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_decimal_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_float_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_integer_with_negative_precision",
},
}
if Database.sqlite_version_info < (3, 27):
skips.update(
{
"Nondeterministic failure on SQLite < 3.27.": {
"expressions_window.tests.WindowFunctionTests."
"test_subquery_row_range_rank",
},
}
)
if self.connection.is_in_memory_db():
skips.update(
{
"the sqlite backend's close() method is a no-op when using an "
"in-memory database": {
"servers.test_liveserverthread.LiveServerThreadTest."
"test_closes_connections",
"servers.tests.LiveServerTestCloseConnectionTest."
"test_closes_connections",
},
}
)
return skips
@cached_property
def supports_atomic_references_rename(self):
return Database.sqlite_version_info >= (3, 26, 0)
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"BigAutoField": "AutoField",
"DurationField": "BigIntegerField",
"GenericIPAddressField": "CharField",
"SmallAutoField": "AutoField",
}
@cached_property
def supports_json_field(self):
with self.connection.cursor() as cursor:
try:
with transaction.atomic(self.connection.alias):
cursor.execute('SELECT JSON(\'{"a": "b"}\')')
except OperationalError:
return False
return True
can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
has_json_object_function = property(operator.attrgetter("supports_json_field"))
@cached_property
def can_return_columns_from_insert(self):
return Database.sqlite_version_info >= (3, 35)
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)

View File

@@ -0,0 +1,438 @@
from collections import namedtuple
import sqlparse
from django.db import DatabaseError
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile
FieldInfo = namedtuple(
"FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
)
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
def get_field_size(name):
"""Extract the size number from a "varchar(11)" type name"""
m = field_size_re.search(name)
return int(m[1]) if m else None
# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
"bool": "BooleanField",
"boolean": "BooleanField",
"smallint": "SmallIntegerField",
"smallint unsigned": "PositiveSmallIntegerField",
"smallinteger": "SmallIntegerField",
"int": "IntegerField",
"integer": "IntegerField",
"bigint": "BigIntegerField",
"integer unsigned": "PositiveIntegerField",
"bigint unsigned": "PositiveBigIntegerField",
"decimal": "DecimalField",
"real": "FloatField",
"text": "TextField",
"char": "CharField",
"varchar": "CharField",
"blob": "BinaryField",
"date": "DateField",
"datetime": "DateTimeField",
"time": "TimeField",
}
def __getitem__(self, key):
key = key.lower().split("(", 1)[0].strip()
return self.base_data_types_reverse[key]
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = FlexibleFieldLookupDict()
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.pk and field_type in {
"BigIntegerField",
"IntegerField",
"SmallIntegerField",
}:
# No support for BigAutoField or SmallAutoField as SQLite treats
# all integer primary keys as signed 64-bit integers.
return "AutoField"
if description.has_json_constraint:
return "JSONField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute(
"""
SELECT name, type FROM sqlite_master
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
ORDER BY name"""
)
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
cursor.execute(
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
)
table_info = cursor.fetchall()
if not table_info:
raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
collations = self._get_column_collations(cursor, table_name)
json_columns = set()
if self.connection.features.can_introspect_json_field:
for line in table_info:
column = line[1]
json_constraint_sql = '%%json_valid("%s")%%' % column
has_json_constraint = cursor.execute(
"""
SELECT sql
FROM sqlite_master
WHERE
type = 'table' AND
name = %s AND
sql LIKE %s
""",
[table_name, json_constraint_sql],
).fetchone()
if has_json_constraint:
json_columns.add(column)
return [
FieldInfo(
name,
data_type,
None,
get_field_size(data_type),
None,
None,
not notnull,
default,
collations.get(name),
pk == 1,
name in json_columns,
)
for cid, name, data_type, notnull, default, pk in table_info
]
def get_sequences(self, cursor, table_name, table_fields=()):
pk_col = self.get_primary_key_column(cursor, table_name)
return [{"table": table_name, "column": pk_col}]
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
representing all foreign keys in the given table.
"""
cursor.execute(
"PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
)
return {
column_name: (ref_column_name, ref_table_name)
for (
_,
_,
ref_table_name,
column_name,
ref_column_name,
*_,
) in cursor.fetchall()
}
def get_primary_key_column(self, cursor, table_name):
"""Return the column name of the primary key for the given table."""
cursor.execute(
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
)
for _, name, *_, pk in cursor.fetchall():
if pk:
return name
return None
def _parse_column_or_constraint_definition(self, tokens, columns):
token = None
is_constraint_definition = None
field_name = None
constraint_name = None
unique = False
unique_columns = []
check = False
check_columns = []
braces_deep = 0
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, "("):
braces_deep += 1
elif token.match(sqlparse.tokens.Punctuation, ")"):
braces_deep -= 1
if braces_deep < 0:
# End of columns and constraints for table definition.
break
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
# End of current column or constraint definition.
break
# Detect column or constraint definition by first token.
if is_constraint_definition is None:
is_constraint_definition = token.match(
sqlparse.tokens.Keyword, "CONSTRAINT"
)
if is_constraint_definition:
continue
if is_constraint_definition:
# Detect constraint name by second token.
if constraint_name is None:
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
constraint_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
constraint_name = token.value[1:-1]
# Start constraint columns parsing after UNIQUE keyword.
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
unique = True
unique_braces_deep = braces_deep
elif unique:
if unique_braces_deep == braces_deep:
if unique_columns:
# Stop constraint parsing.
unique = False
continue
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
unique_columns.append(token.value)
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
unique_columns.append(token.value[1:-1])
else:
# Detect field name by first token.
if field_name is None:
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
field_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
field_name = token.value[1:-1]
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
unique_columns = [field_name]
# Start constraint columns parsing after CHECK keyword.
if token.match(sqlparse.tokens.Keyword, "CHECK"):
check = True
check_braces_deep = braces_deep
elif check:
if check_braces_deep == braces_deep:
if check_columns:
# Stop constraint parsing.
check = False
continue
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
if token.value in columns:
check_columns.append(token.value)
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
if token.value[1:-1] in columns:
check_columns.append(token.value[1:-1])
unique_constraint = (
{
"unique": True,
"columns": unique_columns,
"primary_key": False,
"foreign_key": None,
"check": False,
"index": False,
}
if unique_columns
else None
)
check_constraint = (
{
"check": True,
"columns": check_columns,
"primary_key": False,
"unique": False,
"foreign_key": None,
"index": False,
}
if check_columns
else None
)
return constraint_name, unique_constraint, check_constraint, token
def _parse_table_constraints(self, sql, columns):
# Check constraint parsing is based of SQLite syntax diagram.
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
statement = sqlparse.parse(sql)[0]
constraints = {}
unnamed_constrains_index = 0
tokens = (token for token in statement.flatten() if not token.is_whitespace)
# Go to columns and constraint definition
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, "("):
break
# Parse columns and constraint definition
while True:
(
constraint_name,
unique,
check,
end_token,
) = self._parse_column_or_constraint_definition(tokens, columns)
if unique:
if constraint_name:
constraints[constraint_name] = unique
else:
unnamed_constrains_index += 1
constraints[
"__unnamed_constraint_%s__" % unnamed_constrains_index
] = unique
if check:
if constraint_name:
constraints[constraint_name] = check
else:
unnamed_constrains_index += 1
constraints[
"__unnamed_constraint_%s__" % unnamed_constrains_index
] = check
if end_token.match(sqlparse.tokens.Punctuation, ")"):
break
return constraints
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Find inline check constraints.
try:
table_schema = cursor.execute(
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
% (self.connection.ops.quote_name(table_name),)
).fetchone()[0]
except TypeError:
# table_name is a view.
pass
else:
columns = {
info.name for info in self.get_table_description(cursor, table_name)
}
constraints.update(self._parse_table_constraints(table_schema, columns))
# Get the index info
cursor.execute(
"PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
)
for row in cursor.fetchall():
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
# columns. Discard last 2 columns if there.
number, index, unique = row[:3]
cursor.execute(
"SELECT sql FROM sqlite_master "
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
# There's at most one row.
(sql,) = cursor.fetchone() or (None,)
# Inline constraints are already detected in
# _parse_table_constraints(). The reasons to avoid fetching inline
# constraints from `PRAGMA index_list` are:
# - Inline constraints can have a different name and information
# than what `PRAGMA index_list` gives.
# - Not all inline constraints may appear in `PRAGMA index_list`.
if not sql:
# An inline constraint
continue
# Get the index info for that index
cursor.execute(
"PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
)
for index_rank, column_rank, column in cursor.fetchall():
if index not in constraints:
constraints[index] = {
"columns": [],
"primary_key": False,
"unique": bool(unique),
"foreign_key": None,
"check": False,
"index": True,
}
constraints[index]["columns"].append(column)
# Add type and column orders for indexes
if constraints[index]["index"]:
# SQLite doesn't support any index type other than b-tree
constraints[index]["type"] = Index.suffix
orders = self._get_index_columns_orders(sql)
if orders is not None:
constraints[index]["orders"] = orders
# Get the PK
pk_column = self.get_primary_key_column(cursor, table_name)
if pk_column:
# SQLite doesn't actually give a name to the PK constraint,
# so we invent one. This is fine, as the SQLite backend never
# deletes PK constraints by name, as you can't delete constraints
# in SQLite; we remake the table with a new PK instead.
constraints["__primary__"] = {
"columns": [pk_column],
"primary_key": True,
"unique": False, # It's not actually a unique constraint.
"foreign_key": None,
"check": False,
"index": False,
}
relations = enumerate(self.get_relations(cursor, table_name).items())
constraints.update(
{
f"fk_{index}": {
"columns": [column_name],
"primary_key": False,
"unique": False,
"foreign_key": (ref_table_name, ref_column_name),
"check": False,
"index": False,
}
for index, (column_name, (ref_column_name, ref_table_name)) in relations
}
)
return constraints
def _get_index_columns_orders(self, sql):
tokens = sqlparse.parse(sql)[0]
for token in tokens:
if isinstance(token, sqlparse.sql.Parenthesis):
columns = str(token).strip("()").split(", ")
return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
return None
def _get_column_collations(self, cursor, table_name):
row = cursor.execute(
"""
SELECT sql
FROM sqlite_master
WHERE type = 'table' AND name = %s
""",
[table_name],
).fetchone()
if not row:
return {}
sql = row[0]
columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
collations = {}
for column in columns:
tokens = column[1:].split()
column_name = tokens[0].strip('"')
for index, token in enumerate(tokens):
if token == "COLLATE":
collation = tokens[index + 1]
break
else:
collation = None
collations[column_name] = collation
return collations

View File

@@ -0,0 +1,436 @@
import datetime
import decimal
import uuid
from functools import lru_cache
from itertools import chain
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.constants import OnConflict
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "text"
cast_data_types = {
"DateField": "TEXT",
"DateTimeField": "TEXT",
}
explain_prefix = "EXPLAIN QUERY PLAN"
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
jsonfield_datatype_values = frozenset(["null", "false", "true"])
def bulk_batch_size(self, fields, objs):
"""
SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
999 variables per query.
If there's only a single field to insert, the limit is 500
(SQLITE_MAX_COMPOUND_SELECT).
"""
if len(fields) == 1:
return 500
elif len(fields) > 1:
return self.connection.features.max_query_params // len(fields)
else:
return len(objs)
def check_expression_support(self, expression):
bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
if isinstance(expression, bad_aggregates):
for expr in expression.get_source_expressions():
try:
output_field = expr.output_field
except (AttributeError, FieldError):
# Not every subexpression has an output_field which is fine
# to ignore.
pass
else:
if isinstance(output_field, bad_fields):
raise NotSupportedError(
"You cannot use Sum, Avg, StdDev, and Variance "
"aggregations on date/time fields in sqlite3 "
"since date/time is saved as text."
)
if (
isinstance(expression, models.Aggregate)
and expression.distinct
and len(expression.source_expressions) > 1
):
raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
"accepting multiple arguments."
)
def date_extract_sql(self, lookup_type, sql, params):
"""
Support EXTRACT with a user-defined function django_date_extract()
that's registered in connect(). Use single quotes because this is a
string and could otherwise cause a collision with a field name.
"""
return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the list of returned data.
"""
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
"""Do nothing since formatting is handled in the custom function."""
return sql
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
return f"django_date_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(),
*params,
*self._convert_tznames_to_sql(tzname),
)
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
return f"django_time_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(),
*params,
*self._convert_tznames_to_sql(tzname),
)
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return tzname, self.connection.timezone_name
return None, None
def datetime_cast_date_sql(self, sql, params, tzname):
return f"django_datetime_cast_date({sql}, %s, %s)", (
*params,
*self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, sql, params, tzname):
return f"django_datetime_cast_time({sql}, %s, %s)", (
*params,
*self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
return f"django_datetime_extract(%s, {sql}, %s, %s)", (
lookup_type.lower(),
*params,
*self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(),
*params,
*self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, sql, params):
return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)
def pk_default_value(self):
return "NULL"
def _quote_params_for_last_executed_query(self, params):
"""
Only for last_executed_query! Don't use this to execute SQL queries!
"""
# This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
# number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
# number of return values, default = 2000). Since Python's sqlite3
# module doesn't expose the get_limit() C API, assume the default
# limits are in effect and split the work in batches if needed.
BATCH_SIZE = 999
if len(params) > BATCH_SIZE:
results = ()
for index in range(0, len(params), BATCH_SIZE):
chunk = params[index : index + BATCH_SIZE]
results += self._quote_params_for_last_executed_query(chunk)
return results
sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
# Bypass Django's wrappers and use the underlying sqlite3 connection
# to avoid logging this query - it would trigger infinite recursion.
cursor = self.connection.connection.cursor()
# Native sqlite3 cursors cannot be used as context managers.
try:
return cursor.execute(sql, params).fetchone()
finally:
cursor.close()
def last_executed_query(self, cursor, sql, params):
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
# pysqlite_statement_bind_parameters(
# self->statement, parameters, allow_8bit_chars
# );
# Unfortunately there is no way to reach self->statement from Python,
# so we quote and substitute parameters manually.
if params:
if isinstance(params, (list, tuple)):
params = self._quote_params_for_last_executed_query(params)
else:
values = tuple(params.values())
values = self._quote_params_for_last_executed_query(values)
params = dict(zip(params, values))
return sql % params
# For consistency with SQLiteCursorWrapper.execute(), just return sql
# when there are no parameters. See #13648 and #17158.
else:
return sql
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
return '"%s"' % name
def no_limit_value(self):
return -1
def __references_graph(self, table_name):
query = """
WITH tables AS (
SELECT %s name
UNION
SELECT sqlite_master.name
FROM sqlite_master
JOIN tables ON (sql REGEXP %s || tables.name || %s)
) SELECT name FROM tables;
"""
params = (
table_name,
r'(?i)\s+references\s+("|\')?',
r'("|\')?\s*\(',
)
with self.connection.cursor() as cursor:
results = cursor.execute(query, params)
return [row[0] for row in results.fetchall()]
@cached_property
def _references_graph(self):
# 512 is large enough to fit the ~330 tables (as of this writing) in
# Django's test suite.
return lru_cache(maxsize=512)(self.__references_graph)
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
tables = set(
chain.from_iterable(self._references_graph(table) for table in tables)
)
sql = [
"%s %s %s;"
% (
style.SQL_KEYWORD("DELETE"),
style.SQL_KEYWORD("FROM"),
style.SQL_FIELD(self.quote_name(table)),
)
for table in tables
]
if reset_sequences:
sequences = [{"table": table} for table in tables]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
def sequence_reset_by_name_sql(self, style, sequences):
if not sequences:
return []
return [
"%s %s %s %s = 0 %s %s %s (%s);"
% (
style.SQL_KEYWORD("UPDATE"),
style.SQL_TABLE(self.quote_name("sqlite_sequence")),
style.SQL_KEYWORD("SET"),
style.SQL_FIELD(self.quote_name("seq")),
style.SQL_KEYWORD("WHERE"),
style.SQL_FIELD(self.quote_name("name")),
style.SQL_KEYWORD("IN"),
", ".join(
["'%s'" % sequence_info["table"] for sequence_info in sequences]
),
),
]
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
# SQLite doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"SQLite backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
return str(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
# SQLite doesn't support tz-aware datetimes
if timezone.is_aware(value):
raise ValueError("SQLite backend does not support timezone-aware times.")
return str(value)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "DateTimeField":
converters.append(self.convert_datetimefield_value)
elif internal_type == "DateField":
converters.append(self.convert_datefield_value)
elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
elif internal_type == "DecimalField":
converters.append(self.get_decimalfield_converter(expression))
elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
elif internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
return converters
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
if not isinstance(value, datetime.datetime):
value = parse_datetime(value)
if settings.USE_TZ and not timezone.is_aware(value):
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_datefield_value(self, value, expression, connection):
if value is not None:
if not isinstance(value, datetime.date):
value = parse_date(value)
return value
def convert_timefield_value(self, value, expression, connection):
if value is not None:
if not isinstance(value, datetime.time):
value = parse_time(value)
return value
def get_decimalfield_converter(self, expression):
# SQLite stores only 15 significant digits. Digits coming from
# float inaccuracy must be removed.
create_decimal = decimal.Context(prec=15).create_decimal_from_float
if isinstance(expression, Col):
quantize_value = decimal.Decimal(1).scaleb(
-expression.output_field.decimal_places
)
def converter(value, expression, connection):
if value is not None:
return create_decimal(value).quantize(
quantize_value, context=expression.output_field.context
)
else:
def converter(value, expression, connection):
if value is not None:
return create_decimal(value)
return converter
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def convert_booleanfield_value(self, value, expression, connection):
return bool(value) if value in (1, 0) else value
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
return f"VALUES {values_sql}"
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
if connector == "^":
return "POWER(%s)" % ",".join(sub_expressions)
elif connector == "#":
return "BITXOR(%s)" % ",".join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
if connector not in ["+", "-", "*", "/"]:
raise DatabaseError("Invalid connector for timedelta: %s." % connector)
fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3:
raise ValueError("Too many params for timedelta operations.")
return "django_format_dtdelta(%s)" % ", ".join(fn_params)
def integer_field_range(self, internal_type):
# SQLite doesn't enforce any integer constraints
return (None, None)
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
if internal_type == "TimeField":
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
return "INSERT OR IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if (
on_conflict == OnConflict.UPDATE
and self.connection.features.supports_update_conflicts_with_target
):
return "ON CONFLICT(%s) DO UPDATE SET %s" % (
", ".join(map(self.quote_name, unique_fields)),
", ".join(
[
f"{field} = EXCLUDED.{field}"
for field in map(self.quote_name, update_fields)
]
),
)
return super().on_conflict_suffix_sql(
fields,
on_conflict,
update_fields,
unique_fields,
)

View File

@@ -0,0 +1,559 @@
import copy
from decimal import Decimal
from django.apps.registry import Apps
from django.db import NotSupportedError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes
from django.db.models import UniqueConstraint
from django.db.transaction import atomic
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
sql_create_inline_fk = (
"REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
)
sql_create_column_inline_fk = sql_create_inline_fk
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s"
def __enter__(self):
# Some SQLite schema alterations need foreign key constraints to be
# disabled. Enforce it here for the duration of the schema edition.
if not self.connection.disable_constraint_checking():
raise NotSupportedError(
"SQLite schema editor cannot be used while foreign key "
"constraint checks are enabled. Make sure to disable them "
"before entering a transaction.atomic() context because "
"SQLite does not support disabling them in the middle of "
"a multi-statement transaction."
)
return super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
self.connection.check_constraints()
super().__exit__(exc_type, exc_value, traceback)
self.connection.enable_constraint_checking()
def quote_value(self, value):
# The backend "mostly works" without this function and there are use
# cases for compiling Python without the sqlite3 libraries (e.g.
# security hardening).
try:
import sqlite3
value = sqlite3.adapt(value)
except ImportError:
pass
except sqlite3.ProgrammingError:
pass
# Manual emulation of SQLite parameter quoting
if isinstance(value, bool):
return str(int(value))
elif isinstance(value, (Decimal, float, int)):
return str(value)
elif isinstance(value, str):
return "'%s'" % value.replace("'", "''")
elif value is None:
return "NULL"
elif isinstance(value, (bytes, bytearray, memoryview)):
# Bytes are only allowed for BLOB fields, encoded as string
# literals containing hexadecimal data and preceded by a single "X"
# character.
return "X'%s'" % value.hex()
else:
raise ValueError(
"Cannot quote parameter value %r of type %s" % (value, type(value))
)
def prepare_default(self, value):
return self.quote_value(value)
def _is_referenced_by_fk_constraint(
self, table_name, column_name=None, ignore_self=False
):
"""
Return whether or not the provided table name is referenced by another
one. If `column_name` is specified, only references pointing to that
column are considered. If `ignore_self` is True, self-referential
constraints are ignored.
"""
with self.connection.cursor() as cursor:
for other_table in self.connection.introspection.get_table_list(cursor):
if ignore_self and other_table.name == table_name:
continue
relations = self.connection.introspection.get_relations(
cursor, other_table.name
)
for constraint_column, constraint_table in relations.values():
if constraint_table == table_name and (
column_name is None or constraint_column == column_name
):
return True
return False
def alter_db_table(
self, model, old_db_table, new_db_table, disable_constraints=True
):
if (
not self.connection.features.supports_atomic_references_rename
and disable_constraints
and self._is_referenced_by_fk_constraint(old_db_table)
):
if self.connection.in_atomic_block:
raise NotSupportedError(
(
"Renaming the %r table while in a transaction is not "
"supported on SQLite < 3.26 because it would break referential "
"integrity. Try adding `atomic = False` to the Migration class."
)
% old_db_table
)
self.connection.enable_constraint_checking()
super().alter_db_table(model, old_db_table, new_db_table)
self.connection.disable_constraint_checking()
else:
super().alter_db_table(model, old_db_table, new_db_table)
def alter_field(self, model, old_field, new_field, strict=False):
if not self._field_should_be_altered(old_field, new_field):
return
old_field_name = old_field.name
table_name = model._meta.db_table
_, old_column_name = old_field.get_attname_column()
if (
new_field.name != old_field_name
and not self.connection.features.supports_atomic_references_rename
and self._is_referenced_by_fk_constraint(
table_name, old_column_name, ignore_self=True
)
):
if self.connection.in_atomic_block:
raise NotSupportedError(
(
"Renaming the %r.%r column while in a transaction is not "
"supported on SQLite < 3.26 because it would break referential "
"integrity. Try adding `atomic = False` to the Migration class."
)
% (model._meta.db_table, old_field_name)
)
with atomic(self.connection.alias):
super().alter_field(model, old_field, new_field, strict=strict)
# Follow SQLite's documented procedure for performing changes
# that don't affect the on-disk content.
# https://sqlite.org/lang_altertable.html#otheralter
with self.connection.cursor() as cursor:
schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
0
]
cursor.execute("PRAGMA writable_schema = 1")
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
new_column_name = new_field.get_attname_column()[1]
search = references_template % old_column_name
replacement = references_template % new_column_name
cursor.execute(
"UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
(search, replacement),
)
cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
cursor.execute("PRAGMA writable_schema = 0")
# The integrity check will raise an exception and rollback
# the transaction if the sqlite_master updates corrupt the
# database.
cursor.execute("PRAGMA integrity_check")
# Perform a VACUUM to refresh the database representation from
# the sqlite_master table.
with self.connection.cursor() as cursor:
cursor.execute("VACUUM")
else:
super().alter_field(model, old_field, new_field, strict=strict)
def _remake_table(
self, model, create_field=None, delete_field=None, alter_field=None
):
"""
Shortcut to transform a model from old_model into new_model
This follows the correct procedure to perform non-rename or column
addition operations based on SQLite's documentation
https://www.sqlite.org/lang_altertable.html#caution
The essential steps are:
1. Create a table with the updated definition called "new__app_model"
2. Copy the data from the existing "app_model" table to the new table
3. Drop the "app_model" table
4. Rename the "new__app_model" table to "app_model"
5. Restore any index of the previous "app_model" table.
"""
# Self-referential fields must be recreated rather than copied from
# the old model to ensure their remote_field.field_name doesn't refer
# to an altered field.
def is_self_referential(f):
return f.is_relation and f.remote_field.model is model
# Work out the new fields dict / mapping
body = {
f.name: f.clone() if is_self_referential(f) else f
for f in model._meta.local_concrete_fields
}
# Since mapping might mix column names and default values,
# its values must be already quoted.
mapping = {
f.column: self.quote_name(f.column)
for f in model._meta.local_concrete_fields
}
# This maps field names (not columns) for things like unique_together
rename_mapping = {}
# If any of the new or altered fields is introducing a new PK,
# remove the old one
restore_pk_field = None
if getattr(create_field, "primary_key", False) or (
alter_field and getattr(alter_field[1], "primary_key", False)
):
for name, field in list(body.items()):
if field.primary_key and not (
# Do not remove the old primary key when an altered field
# that introduces a primary key is the same field.
alter_field
and name == alter_field[1].name
):
field.primary_key = False
restore_pk_field = field
if field.auto_created:
del body[name]
del mapping[field.column]
# Add in any created fields
if create_field:
body[create_field.name] = create_field
# Choose a default and insert it into the copy map
if not create_field.many_to_many and create_field.concrete:
mapping[create_field.column] = self.prepare_default(
self.effective_default(create_field),
)
# Add in any altered fields
if alter_field:
old_field, new_field = alter_field
body.pop(old_field.name, None)
mapping.pop(old_field.column, None)
body[new_field.name] = new_field
if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
"col": self.quote_name(old_field.column),
"default": self.prepare_default(self.effective_default(new_field)),
}
mapping[new_field.column] = case_sql
else:
mapping[new_field.column] = self.quote_name(old_field.column)
rename_mapping[old_field.name] = new_field.name
# Remove any deleted fields
if delete_field:
del body[delete_field.name]
del mapping[delete_field.column]
# Remove any implicit M2M tables
if (
delete_field.many_to_many
and delete_field.remote_field.through._meta.auto_created
):
return self.delete_model(delete_field.remote_field.through)
# Work inside a new app registry
apps = Apps()
# Work out the new value of unique_together, taking renames into
# account
unique_together = [
[rename_mapping.get(n, n) for n in unique]
for unique in model._meta.unique_together
]
# Work out the new value for index_together, taking renames into
# account
index_together = [
[rename_mapping.get(n, n) for n in index]
for index in model._meta.index_together
]
indexes = model._meta.indexes
if delete_field:
indexes = [
index for index in indexes if delete_field.name not in index.fields
]
constraints = list(model._meta.constraints)
# Provide isolated instances of the fields to the new model body so
# that the existing model's internals aren't interfered with when
# the dummy model is constructed.
body_copy = copy.deepcopy(body)
# Construct a new model with the new fields to allow self referential
# primary key to resolve to. This model won't ever be materialized as a
# table and solely exists for foreign key reference resolution purposes.
# This wouldn't be required if the schema editor was operating on model
# states instead of rendered models.
meta_contents = {
"app_label": model._meta.app_label,
"db_table": model._meta.db_table,
"unique_together": unique_together,
"index_together": index_together,
"indexes": indexes,
"constraints": constraints,
"apps": apps,
}
meta = type("Meta", (), meta_contents)
body_copy["Meta"] = meta
body_copy["__module__"] = model.__module__
type(model._meta.object_name, model.__bases__, body_copy)
# Construct a model with a renamed table name.
body_copy = copy.deepcopy(body)
meta_contents = {
"app_label": model._meta.app_label,
"db_table": "new__%s" % strip_quotes(model._meta.db_table),
"unique_together": unique_together,
"index_together": index_together,
"indexes": indexes,
"constraints": constraints,
"apps": apps,
}
meta = type("Meta", (), meta_contents)
body_copy["Meta"] = meta
body_copy["__module__"] = model.__module__
new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
# Create a new table with the updated schema.
self.create_model(new_model)
# Copy data from the old table into the new table
self.execute(
"INSERT INTO %s (%s) SELECT %s FROM %s"
% (
self.quote_name(new_model._meta.db_table),
", ".join(self.quote_name(x) for x in mapping),
", ".join(mapping.values()),
self.quote_name(model._meta.db_table),
)
)
# Delete the old table to make way for the new
self.delete_model(model, handle_autom2m=False)
# Rename the new table to take way for the old
self.alter_db_table(
new_model,
new_model._meta.db_table,
model._meta.db_table,
disable_constraints=False,
)
# Run deferred SQL on correct table
for sql in self.deferred_sql:
self.execute(sql)
self.deferred_sql = []
# Fix any PK-removed field
if restore_pk_field:
restore_pk_field.primary_key = True
def delete_model(self, model, handle_autom2m=True):
if handle_autom2m:
super().delete_model(model)
else:
# Delete the table (and only that)
self.execute(
self.sql_delete_table
% {
"table": self.quote_name(model._meta.db_table),
}
)
# Remove all deferred statements referencing the deleted table.
for sql in list(self.deferred_sql):
if isinstance(sql, Statement) and sql.references_table(
model._meta.db_table
):
self.deferred_sql.remove(sql)
def add_field(self, model, field):
"""Create a field on a model."""
if (
# Primary keys and unique fields are not supported in ALTER TABLE
# ADD COLUMN.
field.primary_key
or field.unique
or
# Fields with default values cannot by handled by ALTER TABLE ADD
# COLUMN statement because DROP DEFAULT is not supported in
# ALTER TABLE.
not field.null
or self.effective_default(field) is not None
):
self._remake_table(model, create_field=field)
else:
super().add_field(model, field)
def remove_field(self, model, field):
"""
Remove a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table.
"""
# M2M fields are a special case
if field.many_to_many:
# For implicit M2M tables, delete the auto-created table
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
# For explicit "through" M2M fields, do nothing
elif (
self.connection.features.can_alter_table_drop_column
# Primary keys, unique fields, and foreign keys are not
# supported in ALTER TABLE DROP COLUMN.
and not field.primary_key
and not field.unique
and not (field.remote_field and field.db_constraint)
):
super().remove_field(model, field)
# For everything else, remake.
else:
# It might not actually have a column behind it
if field.db_parameters(connection=self.connection)["type"] is None:
return
self._remake_table(model, delete_field=field)
def _alter_field(
self,
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict=False,
):
"""Perform a "physical" (non-ManyToMany) field update."""
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
# changed and there aren't any constraints.
if (
self.connection.features.can_alter_table_rename_column
and old_field.column != new_field.column
and self.column_sql(model, old_field) == self.column_sql(model, new_field)
and not (
old_field.remote_field
and old_field.db_constraint
or new_field.remote_field
and new_field.db_constraint
)
):
return self.execute(
self._rename_field_sql(
model._meta.db_table, old_field, new_field, new_type
)
)
# Alter by remaking table
self._remake_table(model, alter_field=(old_field, new_field))
# Rebuild tables with FKs pointing to this field.
old_collation = old_db_params.get("collation")
new_collation = new_db_params.get("collation")
if new_field.unique and (
old_type != new_type or old_collation != new_collation
):
related_models = set()
opts = new_field.model._meta
for remote_field in opts.related_objects:
# Ignore self-relationship since the table was already rebuilt.
if remote_field.related_model == model:
continue
if not remote_field.many_to_many:
if remote_field.field_name == new_field.name:
related_models.add(remote_field.related_model)
elif new_field.primary_key and remote_field.through._meta.auto_created:
related_models.add(remote_field.through)
if new_field.primary_key:
for many_to_many in opts.many_to_many:
# Ignore self-relationship since the table was already rebuilt.
if many_to_many.related_model == model:
continue
if many_to_many.remote_field.through._meta.auto_created:
related_models.add(many_to_many.remote_field.through)
for related_model in related_models:
self._remake_table(related_model)
def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""
if (
old_field.remote_field.through._meta.db_table
== new_field.remote_field.through._meta.db_table
):
# The field name didn't change, but some options did, so we have to
# propagate this altering.
self._remake_table(
old_field.remote_field.through,
alter_field=(
# The field that points to the target model is needed, so
# we can tell alter_field to change it - this is
# m2m_reverse_field_name() (as opposed to m2m_field_name(),
# which points to our model).
old_field.remote_field.through._meta.get_field(
old_field.m2m_reverse_field_name()
),
new_field.remote_field.through._meta.get_field(
new_field.m2m_reverse_field_name()
),
),
)
return
# Make a new through table
self.create_model(new_field.remote_field.through)
# Copy the data across
self.execute(
"INSERT INTO %s (%s) SELECT %s FROM %s"
% (
self.quote_name(new_field.remote_field.through._meta.db_table),
", ".join(
[
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]
),
", ".join(
[
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]
),
self.quote_name(old_field.remote_field.through._meta.db_table),
)
)
# Delete the old through table
self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition
or constraint.contains_expressions
or constraint.include
or constraint.deferrable
):
super().add_constraint(model, constraint)
else:
self._remake_table(model)
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition
or constraint.contains_expressions
or constraint.include
or constraint.deferrable
):
super().remove_constraint(model, constraint)
else:
self._remake_table(model)
def _collate_sql(self, collation):
return "COLLATE " + collation