before send to remote

This commit is contained in:
2022-08-26 16:41:18 +06:00
commit 3814beb3e0
5408 changed files with 652023 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from .migration import Migration, swappable_dependency # NOQA
from .operations import * # NOQA

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,60 @@
from django.db import DatabaseError
class AmbiguityError(Exception):
"""More than one migration matches a name prefix."""
pass
class BadMigrationError(Exception):
"""There's a bad migration (unreadable/bad format/etc.)."""
pass
class CircularDependencyError(Exception):
"""There's an impossible-to-resolve circular dependency."""
pass
class InconsistentMigrationHistory(Exception):
"""An applied migration has some of its dependencies not applied."""
pass
class InvalidBasesError(ValueError):
"""A model's base classes can't be resolved."""
pass
class IrreversibleError(RuntimeError):
"""An irreversible migration is about to be reversed."""
pass
class NodeNotFoundError(LookupError):
"""An attempt on a node is made that is not available in the graph."""
def __init__(self, message, node, origin=None):
self.message = message
self.origin = origin
self.node = node
def __str__(self):
return self.message
def __repr__(self):
return "NodeNotFoundError(%r)" % (self.node,)
class MigrationSchemaMissing(DatabaseError):
pass
class InvalidMigrationPlan(ValueError):
pass

View File

@@ -0,0 +1,410 @@
from django.apps.registry import apps as global_apps
from django.db import migrations, router
from .exceptions import InvalidMigrationPlan
from .loader import MigrationLoader
from .recorder import MigrationRecorder
from .state import ProjectState
class MigrationExecutor:
"""
End-to-end migration execution - load migrations and run them up or down
to a specified set of targets.
"""
def __init__(self, connection, progress_callback=None):
self.connection = connection
self.loader = MigrationLoader(self.connection)
self.recorder = MigrationRecorder(self.connection)
self.progress_callback = progress_callback
def migration_plan(self, targets, clean_start=False):
"""
Given a set of targets, return a list of (Migration instance, backwards?).
"""
plan = []
if clean_start:
applied = {}
else:
applied = dict(self.loader.applied_migrations)
for target in targets:
# If the target is (app_label, None), that means unmigrate everything
if target[1] is None:
for root in self.loader.graph.root_nodes():
if root[0] == target[0]:
for migration in self.loader.graph.backwards_plan(root):
if migration in applied:
plan.append((self.loader.graph.nodes[migration], True))
applied.pop(migration)
# If the migration is already applied, do backwards mode,
# otherwise do forwards mode.
elif target in applied:
# If the target is missing, it's likely a replaced migration.
# Reload the graph without replacements.
if (
self.loader.replace_migrations
and target not in self.loader.graph.node_map
):
self.loader.replace_migrations = False
self.loader.build_graph()
return self.migration_plan(targets, clean_start=clean_start)
# Don't migrate backwards all the way to the target node (that
# may roll back dependencies in other apps that don't need to
# be rolled back); instead roll back through target's immediate
# child(ren) in the same app, and no further.
next_in_app = sorted(
n
for n in self.loader.graph.node_map[target].children
if n[0] == target[0]
)
for node in next_in_app:
for migration in self.loader.graph.backwards_plan(node):
if migration in applied:
plan.append((self.loader.graph.nodes[migration], True))
applied.pop(migration)
else:
for migration in self.loader.graph.forwards_plan(target):
if migration not in applied:
plan.append((self.loader.graph.nodes[migration], False))
applied[migration] = self.loader.graph.nodes[migration]
return plan
def _create_project_state(self, with_applied_migrations=False):
"""
Create a project state including all the applications without
migrations and applied migrations if with_applied_migrations=True.
"""
state = ProjectState(real_apps=self.loader.unmigrated_apps)
if with_applied_migrations:
# Create the forwards plan Django would follow on an empty database
full_plan = self.migration_plan(
self.loader.graph.leaf_nodes(), clean_start=True
)
applied_migrations = {
self.loader.graph.nodes[key]
for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
for migration, _ in full_plan:
if migration in applied_migrations:
migration.mutate_state(state, preserve=False)
return state
def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
"""
Migrate the database up to the given targets.
Django first needs to create all project states before a migration is
(un)applied and in a second step run all the database operations.
"""
# The django_migrations table must be present to record applied
# migrations, but don't create it if there are no migrations to apply.
if plan == []:
if not self.recorder.has_table():
return self._create_project_state(with_applied_migrations=False)
else:
self.recorder.ensure_schema()
if plan is None:
plan = self.migration_plan(targets)
# Create the forwards plan Django would follow on an empty database
full_plan = self.migration_plan(
self.loader.graph.leaf_nodes(), clean_start=True
)
all_forwards = all(not backwards for mig, backwards in plan)
all_backwards = all(backwards for mig, backwards in plan)
if not plan:
if state is None:
# The resulting state should include applied migrations.
state = self._create_project_state(with_applied_migrations=True)
elif all_forwards == all_backwards:
# This should only happen if there's a mixed plan
raise InvalidMigrationPlan(
"Migration plans with both forwards and backwards migrations "
"are not supported. Please split your migration process into "
"separate plans of only forwards OR backwards migrations.",
plan,
)
elif all_forwards:
if state is None:
# The resulting state should still include applied migrations.
state = self._create_project_state(with_applied_migrations=True)
state = self._migrate_all_forwards(
state, plan, full_plan, fake=fake, fake_initial=fake_initial
)
else:
# No need to check for `elif all_backwards` here, as that condition
# would always evaluate to true.
state = self._migrate_all_backwards(plan, full_plan, fake=fake)
self.check_replacements()
return state
def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
"""
Take a list of 2-tuples of the form (migration instance, False) and
apply them in the order they occur in the full_plan.
"""
migrations_to_run = {m[0] for m in plan}
for migration, _ in full_plan:
if not migrations_to_run:
# We remove every migration that we applied from these sets so
# that we can bail out once the last migration has been applied
# and don't always run until the very end of the migration
# process.
break
if migration in migrations_to_run:
if "apps" not in state.__dict__:
if self.progress_callback:
self.progress_callback("render_start")
state.apps # Render all -- performance critical
if self.progress_callback:
self.progress_callback("render_success")
state = self.apply_migration(
state, migration, fake=fake, fake_initial=fake_initial
)
migrations_to_run.remove(migration)
return state
def _migrate_all_backwards(self, plan, full_plan, fake):
"""
Take a list of 2-tuples of the form (migration instance, True) and
unapply them in reverse order they occur in the full_plan.
Since unapplying a migration requires the project state prior to that
migration, Django will compute the migration states before each of them
in a first run over the plan and then unapply them in a second run over
the plan.
"""
migrations_to_run = {m[0] for m in plan}
# Holds all migration states prior to the migrations being unapplied
states = {}
state = self._create_project_state()
applied_migrations = {
self.loader.graph.nodes[key]
for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
if self.progress_callback:
self.progress_callback("render_start")
for migration, _ in full_plan:
if not migrations_to_run:
# We remove every migration that we applied from this set so
# that we can bail out once the last migration has been applied
# and don't always run until the very end of the migration
# process.
break
if migration in migrations_to_run:
if "apps" not in state.__dict__:
state.apps # Render all -- performance critical
# The state before this migration
states[migration] = state
# The old state keeps as-is, we continue with the new state
state = migration.mutate_state(state, preserve=True)
migrations_to_run.remove(migration)
elif migration in applied_migrations:
# Only mutate the state if the migration is actually applied
# to make sure the resulting state doesn't include changes
# from unrelated migrations.
migration.mutate_state(state, preserve=False)
if self.progress_callback:
self.progress_callback("render_success")
for migration, _ in plan:
self.unapply_migration(states[migration], migration, fake=fake)
applied_migrations.remove(migration)
# Generate the post migration state by starting from the state before
# the last migration is unapplied and mutating it to include all the
# remaining applied migrations.
last_unapplied_migration = plan[-1][0]
state = states[last_unapplied_migration]
for index, (migration, _) in enumerate(full_plan):
if migration == last_unapplied_migration:
for migration, _ in full_plan[index:]:
if migration in applied_migrations:
migration.mutate_state(state, preserve=False)
break
return state
def apply_migration(self, state, migration, fake=False, fake_initial=False):
"""Run a migration forwards."""
migration_recorded = False
if self.progress_callback:
self.progress_callback("apply_start", migration, fake)
if not fake:
if fake_initial:
# Test to see if this is an already-applied initial migration
applied, state = self.detect_soft_applied(state, migration)
if applied:
fake = True
if not fake:
# Alright, do it normally
with self.connection.schema_editor(
atomic=migration.atomic
) as schema_editor:
state = migration.apply(state, schema_editor)
if not schema_editor.deferred_sql:
self.record_migration(migration)
migration_recorded = True
if not migration_recorded:
self.record_migration(migration)
# Report progress
if self.progress_callback:
self.progress_callback("apply_success", migration, fake)
return state
def record_migration(self, migration):
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_applied(app_label, name)
else:
self.recorder.record_applied(migration.app_label, migration.name)
def unapply_migration(self, state, migration, fake=False):
"""Run a migration backwards."""
if self.progress_callback:
self.progress_callback("unapply_start", migration, fake)
if not fake:
with self.connection.schema_editor(
atomic=migration.atomic
) as schema_editor:
state = migration.unapply(state, schema_editor)
# For replacement migrations, also record individual statuses.
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_unapplied(app_label, name)
self.recorder.record_unapplied(migration.app_label, migration.name)
# Report progress
if self.progress_callback:
self.progress_callback("unapply_success", migration, fake)
return state
def check_replacements(self):
"""
Mark replacement migrations applied if their replaced set all are.
Do this unconditionally on every migrate, rather than just when
migrations are applied or unapplied, to correctly handle the case
when a new squash migration is pushed to a deployment that already had
all its replaced migrations applied. In this case no new migration will
be applied, but the applied state of the squashed migration must be
maintained.
"""
applied = self.recorder.applied_migrations()
for key, migration in self.loader.replacements.items():
all_applied = all(m in applied for m in migration.replaces)
if all_applied and key not in applied:
self.recorder.record_applied(*key)
def detect_soft_applied(self, project_state, migration):
"""
Test whether a migration has been implicitly applied - that the
tables or columns it would create exist. This is intended only for use
on initial migrations (as it only looks for CreateModel and AddField).
"""
def should_skip_detecting_model(migration, model):
"""
No need to detect tables for proxy models, unmanaged models, or
models that can't be migrated on the current database.
"""
return (
model._meta.proxy
or not model._meta.managed
or not router.allow_migrate(
self.connection.alias,
migration.app_label,
model_name=model._meta.model_name,
)
)
if migration.initial is None:
# Bail if the migration isn't the first one in its app
if any(app == migration.app_label for app, name in migration.dependencies):
return False, project_state
elif migration.initial is False:
# Bail if it's NOT an initial migration
return False, project_state
if project_state is None:
after_state = self.loader.project_state(
(migration.app_label, migration.name), at_end=True
)
else:
after_state = migration.mutate_state(project_state)
apps = after_state.apps
found_create_model_migration = False
found_add_field_migration = False
fold_identifier_case = self.connection.features.ignores_table_name_case
with self.connection.cursor() as cursor:
existing_table_names = set(
self.connection.introspection.table_names(cursor)
)
if fold_identifier_case:
existing_table_names = {
name.casefold() for name in existing_table_names
}
# Make sure all create model and add field operations are done
for operation in migration.operations:
if isinstance(operation, migrations.CreateModel):
model = apps.get_model(migration.app_label, operation.name)
if model._meta.swapped:
# We have to fetch the model to test with from the
# main app cache, as it's not a direct dependency.
model = global_apps.get_model(model._meta.swapped)
if should_skip_detecting_model(migration, model):
continue
db_table = model._meta.db_table
if fold_identifier_case:
db_table = db_table.casefold()
if db_table not in existing_table_names:
return False, project_state
found_create_model_migration = True
elif isinstance(operation, migrations.AddField):
model = apps.get_model(migration.app_label, operation.model_name)
if model._meta.swapped:
# We have to fetch the model to test with from the
# main app cache, as it's not a direct dependency.
model = global_apps.get_model(model._meta.swapped)
if should_skip_detecting_model(migration, model):
continue
table = model._meta.db_table
field = model._meta.get_field(operation.name)
# Handle implicit many-to-many tables created by AddField.
if field.many_to_many:
through_db_table = field.remote_field.through._meta.db_table
if fold_identifier_case:
through_db_table = through_db_table.casefold()
if through_db_table not in existing_table_names:
return False, project_state
else:
found_add_field_migration = True
continue
with self.connection.cursor() as cursor:
columns = self.connection.introspection.get_table_description(
cursor, table
)
for column in columns:
field_column = field.column
column_name = column.name
if fold_identifier_case:
column_name = column_name.casefold()
field_column = field_column.casefold()
if column_name == field_column:
found_add_field_migration = True
break
else:
return False, project_state
# If we get this far and we found at least one CreateModel or AddField
# migration, the migration is considered implicitly applied.
return (found_create_model_migration or found_add_field_migration), after_state

View File

@@ -0,0 +1,333 @@
from functools import total_ordering
from django.db.migrations.state import ProjectState
from .exceptions import CircularDependencyError, NodeNotFoundError
@total_ordering
class Node:
"""
A single node in the migration graph. Contains direct links to adjacent
nodes in either direction.
"""
def __init__(self, key):
self.key = key
self.children = set()
self.parents = set()
def __eq__(self, other):
return self.key == other
def __lt__(self, other):
return self.key < other
def __hash__(self):
return hash(self.key)
def __getitem__(self, item):
return self.key[item]
def __str__(self):
return str(self.key)
def __repr__(self):
return "<%s: (%r, %r)>" % (self.__class__.__name__, self.key[0], self.key[1])
def add_child(self, child):
self.children.add(child)
def add_parent(self, parent):
self.parents.add(parent)
class DummyNode(Node):
"""
A node that doesn't correspond to a migration file on disk.
(A squashed migration that was removed, for example.)
After the migration graph is processed, all dummy nodes should be removed.
If there are any left, a nonexistent dependency error is raised.
"""
def __init__(self, key, origin, error_message):
super().__init__(key)
self.origin = origin
self.error_message = error_message
def raise_error(self):
raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
class MigrationGraph:
"""
Represent the digraph of all migrations in a project.
Each migration is a node, and each dependency is an edge. There are
no implicit dependencies between numbered migrations - the numbering is
merely a convention to aid file listing. Every new numbered migration
has a declared dependency to the previous number, meaning that VCS
branch merges can be detected and resolved.
Migrations files can be marked as replacing another set of migrations -
this is to support the "squash" feature. The graph handler isn't responsible
for these; instead, the code to load them in here should examine the
migration files and if the replaced migrations are all either unapplied
or not present, it should ignore the replaced ones, load in just the
replacing migration, and repoint any dependencies that pointed to the
replaced migrations to point to the replacing one.
A node should be a tuple: (app_path, migration_name). The tree special-cases
things within an app - namely, root nodes and leaf nodes ignore dependencies
to other apps.
"""
def __init__(self):
self.node_map = {}
self.nodes = {}
def add_node(self, key, migration):
assert key not in self.node_map
node = Node(key)
self.node_map[key] = node
self.nodes[key] = migration
def add_dummy_node(self, key, origin, error_message):
node = DummyNode(key, origin, error_message)
self.node_map[key] = node
self.nodes[key] = None
def add_dependency(self, migration, child, parent, skip_validation=False):
"""
This may create dummy nodes if they don't yet exist. If
`skip_validation=True`, validate_consistency() should be called
afterward.
"""
if child not in self.nodes:
error_message = (
"Migration %s dependencies reference nonexistent"
" child node %r" % (migration, child)
)
self.add_dummy_node(child, migration, error_message)
if parent not in self.nodes:
error_message = (
"Migration %s dependencies reference nonexistent"
" parent node %r" % (migration, parent)
)
self.add_dummy_node(parent, migration, error_message)
self.node_map[child].add_parent(self.node_map[parent])
self.node_map[parent].add_child(self.node_map[child])
if not skip_validation:
self.validate_consistency()
def remove_replaced_nodes(self, replacement, replaced):
"""
Remove each of the `replaced` nodes (when they exist). Any
dependencies that were referencing them are changed to reference the
`replacement` node instead.
"""
# Cast list of replaced keys to set to speed up lookup later.
replaced = set(replaced)
try:
replacement_node = self.node_map[replacement]
except KeyError as err:
raise NodeNotFoundError(
"Unable to find replacement node %r. It was either never added"
" to the migration graph, or has been removed." % (replacement,),
replacement,
) from err
for replaced_key in replaced:
self.nodes.pop(replaced_key, None)
replaced_node = self.node_map.pop(replaced_key, None)
if replaced_node:
for child in replaced_node.children:
child.parents.remove(replaced_node)
# We don't want to create dependencies between the replaced
# node and the replacement node as this would lead to
# self-referencing on the replacement node at a later iteration.
if child.key not in replaced:
replacement_node.add_child(child)
child.add_parent(replacement_node)
for parent in replaced_node.parents:
parent.children.remove(replaced_node)
# Again, to avoid self-referencing.
if parent.key not in replaced:
replacement_node.add_parent(parent)
parent.add_child(replacement_node)
def remove_replacement_node(self, replacement, replaced):
"""
The inverse operation to `remove_replaced_nodes`. Almost. Remove the
replacement node `replacement` and remap its child nodes to `replaced`
- the list of nodes it would have replaced. Don't remap its parent
nodes as they are expected to be correct already.
"""
self.nodes.pop(replacement, None)
try:
replacement_node = self.node_map.pop(replacement)
except KeyError as err:
raise NodeNotFoundError(
"Unable to remove replacement node %r. It was either never added"
" to the migration graph, or has been removed already."
% (replacement,),
replacement,
) from err
replaced_nodes = set()
replaced_nodes_parents = set()
for key in replaced:
replaced_node = self.node_map.get(key)
if replaced_node:
replaced_nodes.add(replaced_node)
replaced_nodes_parents |= replaced_node.parents
# We're only interested in the latest replaced node, so filter out
# replaced nodes that are parents of other replaced nodes.
replaced_nodes -= replaced_nodes_parents
for child in replacement_node.children:
child.parents.remove(replacement_node)
for replaced_node in replaced_nodes:
replaced_node.add_child(child)
child.add_parent(replaced_node)
for parent in replacement_node.parents:
parent.children.remove(replacement_node)
# NOTE: There is no need to remap parent dependencies as we can
# assume the replaced nodes already have the correct ancestry.
def validate_consistency(self):
"""Ensure there are no dummy nodes remaining in the graph."""
[n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
def forwards_plan(self, target):
"""
Given a node, return a list of which previous nodes (dependencies) must
be applied, ending with the node itself. This is the list you would
follow if applying the migrations to a database.
"""
if target not in self.nodes:
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
return self.iterative_dfs(self.node_map[target])
def backwards_plan(self, target):
"""
Given a node, return a list of which dependent nodes (dependencies)
must be unapplied, ending with the node itself. This is the list you
would follow if removing the migrations from a database.
"""
if target not in self.nodes:
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
return self.iterative_dfs(self.node_map[target], forwards=False)
def iterative_dfs(self, start, forwards=True):
"""Iterative depth-first search for finding dependencies."""
visited = []
visited_set = set()
stack = [(start, False)]
while stack:
node, processed = stack.pop()
if node in visited_set:
pass
elif processed:
visited_set.add(node)
visited.append(node.key)
else:
stack.append((node, True))
stack += [
(n, False)
for n in sorted(node.parents if forwards else node.children)
]
return visited
def root_nodes(self, app=None):
"""
Return all root nodes - that is, nodes with no dependencies inside
their app. These are the starting point for an app.
"""
roots = set()
for node in self.nodes:
if all(key[0] != node[0] for key in self.node_map[node].parents) and (
not app or app == node[0]
):
roots.add(node)
return sorted(roots)
def leaf_nodes(self, app=None):
"""
Return all leaf nodes - that is, nodes with no dependents in their app.
These are the "most current" version of an app's schema.
Having more than one per app is technically an error, but one that
gets handled further up, in the interactive command - it's usually the
result of a VCS merge and needs some user input.
"""
leaves = set()
for node in self.nodes:
if all(key[0] != node[0] for key in self.node_map[node].children) and (
not app or app == node[0]
):
leaves.add(node)
return sorted(leaves)
def ensure_not_cyclic(self):
# Algo from GvR:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
todo = set(self.nodes)
while todo:
node = todo.pop()
stack = [node]
while stack:
top = stack[-1]
for child in self.node_map[top].children:
# Use child.key instead of child to speed up the frequent
# hashing.
node = child.key
if node in stack:
cycle = stack[stack.index(node) :]
raise CircularDependencyError(
", ".join("%s.%s" % n for n in cycle)
)
if node in todo:
stack.append(node)
todo.remove(node)
break
else:
node = stack.pop()
def __str__(self):
return "Graph: %s nodes, %s edges" % self._nodes_and_edges()
def __repr__(self):
nodes, edges = self._nodes_and_edges()
return "<%s: nodes=%s, edges=%s>" % (self.__class__.__name__, nodes, edges)
def _nodes_and_edges(self):
return len(self.nodes), sum(
len(node.parents) for node in self.node_map.values()
)
def _generate_plan(self, nodes, at_end):
plan = []
for node in nodes:
for migration in self.forwards_plan(node):
if migration not in plan and (at_end or migration not in nodes):
plan.append(migration)
return plan
def make_state(self, nodes=None, at_end=True, real_apps=None):
"""
Given a migration node or nodes, return a complete ProjectState for it.
If at_end is False, return the state before the migration has run.
If nodes is not provided, return the overall most current project state.
"""
if nodes is None:
nodes = list(self.leaf_nodes())
if not nodes:
return ProjectState()
if not isinstance(nodes[0], tuple):
nodes = [nodes]
plan = self._generate_plan(nodes, at_end)
project_state = ProjectState(real_apps=real_apps)
for node in plan:
project_state = self.nodes[node].mutate_state(project_state, preserve=False)
return project_state
def __contains__(self, node):
return node in self.nodes

View File

@@ -0,0 +1,385 @@
import pkgutil
import sys
from importlib import import_module, reload
from django.apps import apps
from django.conf import settings
from django.db.migrations.graph import MigrationGraph
from django.db.migrations.recorder import MigrationRecorder
from .exceptions import (
AmbiguityError,
BadMigrationError,
InconsistentMigrationHistory,
NodeNotFoundError,
)
MIGRATIONS_MODULE_NAME = "migrations"
class MigrationLoader:
"""
Load migration files from disk and their status from the database.
Migration files are expected to live in the "migrations" directory of
an app. Their names are entirely unimportant from a code perspective,
but will probably follow the 1234_name.py convention.
On initialization, this class will scan those directories, and open and
read the Python files, looking for a class called Migration, which should
inherit from django.db.migrations.Migration. See
django.db.migrations.migration for what that looks like.
Some migrations will be marked as "replacing" another set of migrations.
These are loaded into a separate set of migrations away from the main ones.
If all the migrations they replace are either unapplied or missing from
disk, then they are injected into the main set, replacing the named migrations.
Any dependency pointers to the replaced migrations are re-pointed to the
new migration.
This does mean that this class MUST also talk to the database as well as
to disk, but this is probably fine. We're already not just operating
in memory.
"""
def __init__(
self,
connection,
load=True,
ignore_no_migrations=False,
replace_migrations=True,
):
self.connection = connection
self.disk_migrations = None
self.applied_migrations = None
self.ignore_no_migrations = ignore_no_migrations
self.replace_migrations = replace_migrations
if load:
self.build_graph()
@classmethod
def migrations_module(cls, app_label):
"""
Return the path to the migrations module for the specified app_label
and a boolean indicating if the module is specified in
settings.MIGRATION_MODULE.
"""
if app_label in settings.MIGRATION_MODULES:
return settings.MIGRATION_MODULES[app_label], True
else:
app_package_name = apps.get_app_config(app_label).name
return "%s.%s" % (app_package_name, MIGRATIONS_MODULE_NAME), False
def load_disk(self):
"""Load the migrations from all INSTALLED_APPS from disk."""
self.disk_migrations = {}
self.unmigrated_apps = set()
self.migrated_apps = set()
for app_config in apps.get_app_configs():
# Get the migrations module directory
module_name, explicit = self.migrations_module(app_config.label)
if module_name is None:
self.unmigrated_apps.add(app_config.label)
continue
was_loaded = module_name in sys.modules
try:
module = import_module(module_name)
except ModuleNotFoundError as e:
if (explicit and self.ignore_no_migrations) or (
not explicit and MIGRATIONS_MODULE_NAME in e.name.split(".")
):
self.unmigrated_apps.add(app_config.label)
continue
raise
else:
# Module is not a package (e.g. migrations.py).
if not hasattr(module, "__path__"):
self.unmigrated_apps.add(app_config.label)
continue
# Empty directories are namespaces. Namespace packages have no
# __file__ and don't use a list for __path__. See
# https://docs.python.org/3/reference/import.html#namespace-packages
if getattr(module, "__file__", None) is None and not isinstance(
module.__path__, list
):
self.unmigrated_apps.add(app_config.label)
continue
# Force a reload if it's already loaded (tests need this)
if was_loaded:
reload(module)
self.migrated_apps.add(app_config.label)
migration_names = {
name
for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
if not is_pkg and name[0] not in "_~"
}
# Load migrations
for migration_name in migration_names:
migration_path = "%s.%s" % (module_name, migration_name)
try:
migration_module = import_module(migration_path)
except ImportError as e:
if "bad magic number" in str(e):
raise ImportError(
"Couldn't import %r as it appears to be a stale "
".pyc file." % migration_path
) from e
else:
raise
if not hasattr(migration_module, "Migration"):
raise BadMigrationError(
"Migration %s in app %s has no Migration class"
% (migration_name, app_config.label)
)
self.disk_migrations[
app_config.label, migration_name
] = migration_module.Migration(
migration_name,
app_config.label,
)
def get_migration(self, app_label, name_prefix):
"""Return the named migration or raise NodeNotFoundError."""
return self.graph.nodes[app_label, name_prefix]
def get_migration_by_prefix(self, app_label, name_prefix):
"""
Return the migration(s) which match the given app label and name_prefix.
"""
# Do the search
results = []
for migration_app_label, migration_name in self.disk_migrations:
if migration_app_label == app_label and migration_name.startswith(
name_prefix
):
results.append((migration_app_label, migration_name))
if len(results) > 1:
raise AmbiguityError(
"There is more than one migration for '%s' with the prefix '%s'"
% (app_label, name_prefix)
)
elif not results:
raise KeyError(
f"There is no migration for '{app_label}' with the prefix "
f"'{name_prefix}'"
)
else:
return self.disk_migrations[results[0]]
def check_key(self, key, current_app):
if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
return key
# Special-case __first__, which means "the first migration" for
# migrated apps, and is ignored for unmigrated apps. It allows
# makemigrations to declare dependencies on apps before they even have
# migrations.
if key[0] == current_app:
# Ignore __first__ references to the same app (#22325)
return
if key[0] in self.unmigrated_apps:
# This app isn't migrated, but something depends on it.
# The models will get auto-added into the state, though
# so we're fine.
return
if key[0] in self.migrated_apps:
try:
if key[1] == "__first__":
return self.graph.root_nodes(key[0])[0]
else: # "__latest__"
return self.graph.leaf_nodes(key[0])[0]
except IndexError:
if self.ignore_no_migrations:
return None
else:
raise ValueError(
"Dependency on app with no migrations: %s" % key[0]
)
raise ValueError("Dependency on unknown app: %s" % key[0])
def add_internal_dependencies(self, key, migration):
"""
Internal dependencies need to be added first to ensure `__first__`
dependencies find the correct root node.
"""
for parent in migration.dependencies:
# Ignore __first__ references to the same app.
if parent[0] == key[0] and parent[1] != "__first__":
self.graph.add_dependency(migration, key, parent, skip_validation=True)
def add_external_dependencies(self, key, migration):
for parent in migration.dependencies:
# Skip internal dependencies
if key[0] == parent[0]:
continue
parent = self.check_key(parent, key[0])
if parent is not None:
self.graph.add_dependency(migration, key, parent, skip_validation=True)
for child in migration.run_before:
child = self.check_key(child, key[0])
if child is not None:
self.graph.add_dependency(migration, child, key, skip_validation=True)
def build_graph(self):
"""
Build a migration dependency graph using both the disk and database.
You'll need to rebuild the graph if you apply migrations. This isn't
usually a problem as generally migration stuff runs in a one-shot process.
"""
# Load disk data
self.load_disk()
# Load database data
if self.connection is None:
self.applied_migrations = {}
else:
recorder = MigrationRecorder(self.connection)
self.applied_migrations = recorder.applied_migrations()
# To start, populate the migration graph with nodes for ALL migrations
# and their dependencies. Also make note of replacing migrations at this step.
self.graph = MigrationGraph()
self.replacements = {}
for key, migration in self.disk_migrations.items():
self.graph.add_node(key, migration)
# Replacing migrations.
if migration.replaces:
self.replacements[key] = migration
for key, migration in self.disk_migrations.items():
# Internal (same app) dependencies.
self.add_internal_dependencies(key, migration)
# Add external dependencies now that the internal ones have been resolved.
for key, migration in self.disk_migrations.items():
self.add_external_dependencies(key, migration)
# Carry out replacements where possible and if enabled.
if self.replace_migrations:
for key, migration in self.replacements.items():
# Get applied status of each of this migration's replacement
# targets.
applied_statuses = [
(target in self.applied_migrations) for target in migration.replaces
]
# The replacing migration is only marked as applied if all of
# its replacement targets are.
if all(applied_statuses):
self.applied_migrations[key] = migration
else:
self.applied_migrations.pop(key, None)
# A replacing migration can be used if either all or none of
# its replacement targets have been applied.
if all(applied_statuses) or (not any(applied_statuses)):
self.graph.remove_replaced_nodes(key, migration.replaces)
else:
# This replacing migration cannot be used because it is
# partially applied. Remove it from the graph and remap
# dependencies to it (#25945).
self.graph.remove_replacement_node(key, migration.replaces)
# Ensure the graph is consistent.
try:
self.graph.validate_consistency()
except NodeNotFoundError as exc:
# Check if the missing node could have been replaced by any squash
# migration but wasn't because the squash migration was partially
# applied before. In that case raise a more understandable exception
# (#23556).
# Get reverse replacements.
reverse_replacements = {}
for key, migration in self.replacements.items():
for replaced in migration.replaces:
reverse_replacements.setdefault(replaced, set()).add(key)
# Try to reraise exception with more detail.
if exc.node in reverse_replacements:
candidates = reverse_replacements.get(exc.node, set())
is_replaced = any(
candidate in self.graph.nodes for candidate in candidates
)
if not is_replaced:
tries = ", ".join("%s.%s" % c for c in candidates)
raise NodeNotFoundError(
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
"Django tried to replace migration {1}.{2} with any of [{3}] "
"but wasn't able to because some of the replaced migrations "
"are already applied.".format(
exc.origin, exc.node[0], exc.node[1], tries
),
exc.node,
) from exc
raise
self.graph.ensure_not_cyclic()
def check_consistent_history(self, connection):
"""
Raise InconsistentMigrationHistory if any applied migrations have
unapplied dependencies.
"""
recorder = MigrationRecorder(connection)
applied = recorder.applied_migrations()
for migration in applied:
# If the migration is unknown, skip it.
if migration not in self.graph.nodes:
continue
for parent in self.graph.node_map[migration].parents:
if parent not in applied:
# Skip unapplied squashed migrations that have all of their
# `replaces` applied.
if parent in self.replacements:
if all(
m in applied for m in self.replacements[parent].replaces
):
continue
raise InconsistentMigrationHistory(
"Migration {}.{} is applied before its dependency "
"{}.{} on database '{}'.".format(
migration[0],
migration[1],
parent[0],
parent[1],
connection.alias,
)
)
def detect_conflicts(self):
"""
Look through the loaded graph and detect any conflicts - apps
with more than one leaf migration. Return a dict of the app labels
that conflict with the migration names that conflict.
"""
seen_apps = {}
conflicting_apps = set()
for app_label, migration_name in self.graph.leaf_nodes():
if app_label in seen_apps:
conflicting_apps.add(app_label)
seen_apps.setdefault(app_label, set()).add(migration_name)
return {
app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps
}
def project_state(self, nodes=None, at_end=True):
"""
Return a ProjectState object representing the most recent state
that the loaded migrations represent.
See graph.make_state() for the meaning of "nodes" and "at_end".
"""
return self.graph.make_state(
nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps
)
def collect_sql(self, plan):
"""
Take a migration plan and return a list of collected SQL statements
that represent the best-efforts version of that plan.
"""
statements = []
state = None
for migration, backwards in plan:
with self.connection.schema_editor(
collect_sql=True, atomic=migration.atomic
) as schema_editor:
if state is None:
state = self.project_state(
(migration.app_label, migration.name), at_end=False
)
if not backwards:
state = migration.apply(state, schema_editor, collect_sql=True)
else:
state = migration.unapply(state, schema_editor, collect_sql=True)
statements.extend(schema_editor.collected_sql)
return statements

View File

@@ -0,0 +1,237 @@
from django.db.migrations.utils import get_migration_name_timestamp
from django.db.transaction import atomic
from .exceptions import IrreversibleError
class Migration:
"""
The base class for all migrations.
Migration files will import this from django.db.migrations.Migration
and subclass it as a class called Migration. It will have one or more
of the following attributes:
- operations: A list of Operation instances, probably from
django.db.migrations.operations
- dependencies: A list of tuples of (app_path, migration_name)
- run_before: A list of tuples of (app_path, migration_name)
- replaces: A list of migration_names
Note that all migrations come out of migrations and into the Loader or
Graph as instances, having been initialized with their app label and name.
"""
# Operations to apply during this migration, in order.
operations = []
# Other migrations that should be run before this migration.
# Should be a list of (app, migration_name).
dependencies = []
# Other migrations that should be run after this one (i.e. have
# this migration added to their dependencies). Useful to make third-party
# apps' migrations run after your AUTH_USER replacement, for example.
run_before = []
# Migration names in this app that this migration replaces. If this is
# non-empty, this migration will only be applied if all these migrations
# are not applied.
replaces = []
# Is this an initial migration? Initial migrations are skipped on
# --fake-initial if the table or fields already exist. If None, check if
# the migration has any dependencies to determine if there are dependencies
# to tell if db introspection needs to be done. If True, always perform
# introspection. If False, never perform introspection.
initial = None
# Whether to wrap the whole migration in a transaction. Only has an effect
# on database backends which support transactional DDL.
atomic = True
def __init__(self, name, app_label):
self.name = name
self.app_label = app_label
# Copy dependencies & other attrs as we might mutate them at runtime
self.operations = list(self.__class__.operations)
self.dependencies = list(self.__class__.dependencies)
self.run_before = list(self.__class__.run_before)
self.replaces = list(self.__class__.replaces)
def __eq__(self, other):
return (
isinstance(other, Migration)
and self.name == other.name
and self.app_label == other.app_label
)
def __repr__(self):
return "<Migration %s.%s>" % (self.app_label, self.name)
def __str__(self):
return "%s.%s" % (self.app_label, self.name)
def __hash__(self):
return hash("%s.%s" % (self.app_label, self.name))
def mutate_state(self, project_state, preserve=True):
"""
Take a ProjectState and return a new one with the migration's
operations applied to it. Preserve the original object state by
default and return a mutated state from a copy.
"""
new_state = project_state
if preserve:
new_state = project_state.clone()
for operation in self.operations:
operation.state_forwards(self.app_label, new_state)
return new_state
def apply(self, project_state, schema_editor, collect_sql=False):
"""
Take a project_state representing all migrations prior to this one
and a schema_editor for a live database and apply the migration
in a forwards order.
Return the resulting project state for efficient reuse by following
Migrations.
"""
for operation in self.operations:
# If this operation cannot be represented as SQL, place a comment
# there instead
if collect_sql:
schema_editor.collected_sql.append("--")
schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
schema_editor.collected_sql.append(
"-- THIS OPERATION CANNOT BE WRITTEN AS SQL"
)
continue
collected_sql_before = len(schema_editor.collected_sql)
# Save the state before the operation has run
old_state = project_state.clone()
operation.state_forwards(self.app_label, project_state)
# Run the operation
atomic_operation = operation.atomic or (
self.atomic and operation.atomic is not False
)
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
operation.database_forwards(
self.app_label, schema_editor, old_state, project_state
)
else:
# Normal behaviour
operation.database_forwards(
self.app_label, schema_editor, old_state, project_state
)
if collect_sql and collected_sql_before == len(schema_editor.collected_sql):
schema_editor.collected_sql.append("-- (no-op)")
return project_state
def unapply(self, project_state, schema_editor, collect_sql=False):
"""
Take a project_state representing all migrations prior to this one
and a schema_editor for a live database and apply the migration
in a reverse order.
The backwards migration process consists of two phases:
1. The intermediate states from right before the first until right
after the last operation inside this migration are preserved.
2. The operations are applied in reverse order using the states
recorded in step 1.
"""
# Construct all the intermediate states we need for a reverse migration
to_run = []
new_state = project_state
# Phase 1
for operation in self.operations:
# If it's irreversible, error out
if not operation.reversible:
raise IrreversibleError(
"Operation %s in %s is not reversible" % (operation, self)
)
# Preserve new state from previous run to not tamper the same state
# over all operations
new_state = new_state.clone()
old_state = new_state.clone()
operation.state_forwards(self.app_label, new_state)
to_run.insert(0, (operation, old_state, new_state))
# Phase 2
for operation, to_state, from_state in to_run:
if collect_sql:
schema_editor.collected_sql.append("--")
schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
schema_editor.collected_sql.append(
"-- THIS OPERATION CANNOT BE WRITTEN AS SQL"
)
continue
collected_sql_before = len(schema_editor.collected_sql)
atomic_operation = operation.atomic or (
self.atomic and operation.atomic is not False
)
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
operation.database_backwards(
self.app_label, schema_editor, from_state, to_state
)
else:
# Normal behaviour
operation.database_backwards(
self.app_label, schema_editor, from_state, to_state
)
if collect_sql and collected_sql_before == len(schema_editor.collected_sql):
schema_editor.collected_sql.append("-- (no-op)")
return project_state
def suggest_name(self):
"""
Suggest a name for the operations this migration might represent. Names
are not guaranteed to be unique, but put some effort into the fallback
name to avoid VCS conflicts if possible.
"""
if self.initial:
return "initial"
raw_fragments = [op.migration_name_fragment for op in self.operations]
fragments = [name for name in raw_fragments if name]
if not fragments or len(fragments) != len(self.operations):
return "auto_%s" % get_migration_name_timestamp()
name = fragments[0]
for fragment in fragments[1:]:
new_name = f"{name}_{fragment}"
if len(new_name) > 52:
name = f"{name}_and_more"
break
name = new_name
return name
class SwappableTuple(tuple):
"""
Subclass of tuple so Django can tell this was originally a swappable
dependency when it reads the migration file.
"""
def __new__(cls, value, setting):
self = tuple.__new__(cls, value)
self.setting = setting
return self
def swappable_dependency(value):
"""Turn a setting value into a dependency."""
return SwappableTuple((value.split(".", 1)[0], "__first__"), value)

View File

@@ -0,0 +1,42 @@
from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
AddConstraint,
AddIndex,
AlterIndexTogether,
AlterModelManagers,
AlterModelOptions,
AlterModelTable,
AlterOrderWithRespectTo,
AlterUniqueTogether,
CreateModel,
DeleteModel,
RemoveConstraint,
RemoveIndex,
RenameIndex,
RenameModel,
)
from .special import RunPython, RunSQL, SeparateDatabaseAndState
__all__ = [
"CreateModel",
"DeleteModel",
"AlterModelTable",
"AlterUniqueTogether",
"RenameModel",
"AlterIndexTogether",
"AlterModelOptions",
"AddIndex",
"RemoveIndex",
"RenameIndex",
"AddField",
"RemoveField",
"AlterField",
"RenameField",
"AddConstraint",
"RemoveConstraint",
"SeparateDatabaseAndState",
"RunSQL",
"RunPython",
"AlterOrderWithRespectTo",
"AlterModelManagers",
]

View File

@@ -0,0 +1,146 @@
from django.db import router
class Operation:
"""
Base class for migration operations.
It's responsible for both mutating the in-memory model state
(see db/migrations/state.py) to represent what it performs, as well
as actually performing it against a live database.
Note that some operations won't modify memory state at all (e.g. data
copying operations), and some will need their modifications to be
optionally specified by the user (e.g. custom Python code snippets)
Due to the way this class deals with deconstruction, it should be
considered immutable.
"""
# If this migration can be run in reverse.
# Some operations are impossible to reverse, like deleting data.
reversible = True
# Can this migration be represented as SQL? (things like RunPython cannot)
reduces_to_sql = True
# Should this operation be forced as atomic even on backends with no
# DDL transaction support (i.e., does it have no DDL, like RunPython)
atomic = False
# Should this operation be considered safe to elide and optimize across?
elidable = False
serialization_expand_args = []
def __new__(cls, *args, **kwargs):
# We capture the arguments to make returning them trivial
self = object.__new__(cls)
self._constructor_args = (args, kwargs)
return self
def deconstruct(self):
"""
Return a 3-tuple of class import path (or just name if it lives
under django.db.migrations), positional arguments, and keyword
arguments.
"""
return (
self.__class__.__name__,
self._constructor_args[0],
self._constructor_args[1],
)
def state_forwards(self, app_label, state):
"""
Take the state from the previous migration, and mutate it
so that it matches what this migration would perform.
"""
raise NotImplementedError(
"subclasses of Operation must provide a state_forwards() method"
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
"""
Perform the mutation on the database schema in the normal
(forwards) direction.
"""
raise NotImplementedError(
"subclasses of Operation must provide a database_forwards() method"
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
"""
Perform the mutation on the database schema in the reverse
direction - e.g. if this were CreateModel, it would in fact
drop the model's table.
"""
raise NotImplementedError(
"subclasses of Operation must provide a database_backwards() method"
)
def describe(self):
"""
Output a brief summary of what the action does.
"""
return "%s: %s" % (self.__class__.__name__, self._constructor_args)
@property
def migration_name_fragment(self):
"""
A filename part suitable for automatically naming a migration
containing this operation, or None if not applicable.
"""
return None
def references_model(self, name, app_label):
"""
Return True if there is a chance this operation references the given
model name (as a string), with an app label for accuracy.
Used for optimization. If in doubt, return True;
returning a false positive will merely make the optimizer a little
less efficient, while returning a false negative may result in an
unusable optimized migration.
"""
return True
def references_field(self, model_name, name, app_label):
"""
Return True if there is a chance this operation references the given
field name, with an app label for accuracy.
Used for optimization. If in doubt, return True.
"""
return self.references_model(model_name, app_label)
def allow_migrate_model(self, connection_alias, model):
"""
Return whether or not a model may be migrated.
This is a thin wrapper around router.allow_migrate_model() that
preemptively rejects any proxy, swapped out, or unmanaged model.
"""
if not model._meta.can_migrate(connection_alias):
return False
return router.allow_migrate_model(connection_alias, model)
def reduce(self, operation, app_label):
"""
Return either a list of operations the actual operation should be
replaced with or a boolean that indicates whether or not the specified
operation can be optimized across.
"""
if self.elidable:
return [operation]
elif operation.elidable:
return [self]
return False
def __repr__(self):
return "<%s %s%s>" % (
self.__class__.__name__,
", ".join(map(repr, self._constructor_args[0])),
",".join(" %s=%r" % x for x in self._constructor_args[1].items()),
)

View File

@@ -0,0 +1,357 @@
from django.db.migrations.utils import field_references
from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property
from .base import Operation
class FieldOperation(Operation):
def __init__(self, model_name, name, field=None):
self.model_name = model_name
self.name = name
self.field = field
@cached_property
def model_name_lower(self):
return self.model_name.lower()
@cached_property
def name_lower(self):
return self.name.lower()
def is_same_model_operation(self, operation):
return self.model_name_lower == operation.model_name_lower
def is_same_field_operation(self, operation):
return (
self.is_same_model_operation(operation)
and self.name_lower == operation.name_lower
)
def references_model(self, name, app_label):
name_lower = name.lower()
if name_lower == self.model_name_lower:
return True
if self.field:
return bool(
field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, name_lower),
)
)
return False
def references_field(self, model_name, name, app_label):
model_name_lower = model_name.lower()
# Check if this operation locally references the field.
if model_name_lower == self.model_name_lower:
if name == self.name:
return True
elif (
self.field
and hasattr(self.field, "from_fields")
and name in self.field.from_fields
):
return True
# Check if this operation remotely references the field.
if self.field is None:
return False
return bool(
field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, model_name_lower),
name,
)
)
def reduce(self, operation, app_label):
return super().reduce(operation, app_label) or not operation.references_field(
self.model_name, self.name, app_label
)
class AddField(FieldOperation):
"""Add a field to a model."""
def __init__(self, model_name, name, field, preserve_default=True):
self.preserve_default = preserve_default
super().__init__(model_name, name, field)
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
"field": self.field,
}
if self.preserve_default is not True:
kwargs["preserve_default"] = self.preserve_default
return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.add_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
field = to_model._meta.get_field(self.name)
if not self.preserve_default:
field.default = self.field.default
schema_editor.add_field(
from_model,
field,
)
if not self.preserve_default:
field.default = NOT_PROVIDED
def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
schema_editor.remove_field(
from_model, from_model._meta.get_field(self.name)
)
def describe(self):
return "Add field %s to %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return "%s_%s" % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
if isinstance(operation, FieldOperation) and self.is_same_field_operation(
operation
):
if isinstance(operation, AlterField):
return [
AddField(
model_name=self.model_name,
name=operation.name,
field=operation.field,
),
]
elif isinstance(operation, RemoveField):
return []
elif isinstance(operation, RenameField):
return [
AddField(
model_name=self.model_name,
name=operation.new_name,
field=self.field,
),
]
return super().reduce(operation, app_label)
class RemoveField(FieldOperation):
"""Remove a field from a model."""
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
}
return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.remove_field(app_label, self.model_name_lower, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
schema_editor.remove_field(
from_model, from_model._meta.get_field(self.name)
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
def describe(self):
return "Remove field %s from %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return "remove_%s_%s" % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
from .models import DeleteModel
if (
isinstance(operation, DeleteModel)
and operation.name_lower == self.model_name_lower
):
return [operation]
return super().reduce(operation, app_label)
class AlterField(FieldOperation):
"""
Alter a field's database column (e.g. null, max_length) to the provided
new field.
"""
def __init__(self, model_name, name, field, preserve_default=True):
self.preserve_default = preserve_default
super().__init__(model_name, name, field)
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
"field": self.field,
}
if self.preserve_default is not True:
kwargs["preserve_default"] = self.preserve_default
return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.alter_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
from_field = from_model._meta.get_field(self.name)
to_field = to_model._meta.get_field(self.name)
if not self.preserve_default:
to_field.default = self.field.default
schema_editor.alter_field(from_model, from_field, to_field)
if not self.preserve_default:
to_field.default = NOT_PROVIDED
def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.database_forwards(app_label, schema_editor, from_state, to_state)
def describe(self):
return "Alter field %s on %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return "alter_%s_%s" % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
if isinstance(operation, RemoveField) and self.is_same_field_operation(
operation
):
return [operation]
elif (
isinstance(operation, RenameField)
and self.is_same_field_operation(operation)
and self.field.db_column is None
):
return [
operation,
AlterField(
model_name=self.model_name,
name=operation.new_name,
field=self.field,
),
]
return super().reduce(operation, app_label)
class RenameField(FieldOperation):
"""Rename a field on the model. Might affect db_column too."""
def __init__(self, model_name, old_name, new_name):
self.old_name = old_name
self.new_name = new_name
super().__init__(model_name, old_name)
@cached_property
def old_name_lower(self):
return self.old_name.lower()
@cached_property
def new_name_lower(self):
return self.new_name.lower()
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"old_name": self.old_name,
"new_name": self.new_name,
}
return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.rename_field(
app_label, self.model_name_lower, self.old_name, self.new_name
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
schema_editor.alter_field(
from_model,
from_model._meta.get_field(self.old_name),
to_model._meta.get_field(self.new_name),
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
schema_editor.alter_field(
from_model,
from_model._meta.get_field(self.new_name),
to_model._meta.get_field(self.old_name),
)
def describe(self):
return "Rename field %s on %s to %s" % (
self.old_name,
self.model_name,
self.new_name,
)
@property
def migration_name_fragment(self):
return "rename_%s_%s_%s" % (
self.old_name_lower,
self.model_name_lower,
self.new_name_lower,
)
def references_field(self, model_name, name, app_label):
return self.references_model(model_name, app_label) and (
name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
)
def reduce(self, operation, app_label):
if (
isinstance(operation, RenameField)
and self.is_same_model_operation(operation)
and self.new_name_lower == operation.old_name_lower
):
return [
RenameField(
self.model_name,
self.old_name,
operation.new_name,
),
]
# Skip `FieldOperation.reduce` as we want to run `references_field`
# against self.old_name and self.new_name.
return super(FieldOperation, self).reduce(operation, app_label) or not (
operation.references_field(self.model_name, self.old_name, app_label)
or operation.references_field(self.model_name, self.new_name, app_label)
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,208 @@
from django.db import router
from .base import Operation
class SeparateDatabaseAndState(Operation):
"""
Take two lists of operations - ones that will be used for the database,
and ones that will be used for the state change. This allows operations
that don't support state change to have it applied, or have operations
that affect the state or not the database, or so on.
"""
serialization_expand_args = ["database_operations", "state_operations"]
def __init__(self, database_operations=None, state_operations=None):
self.database_operations = database_operations or []
self.state_operations = state_operations or []
def deconstruct(self):
kwargs = {}
if self.database_operations:
kwargs["database_operations"] = self.database_operations
if self.state_operations:
kwargs["state_operations"] = self.state_operations
return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
for state_operation in self.state_operations:
state_operation.state_forwards(app_label, state)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
# We calculate state separately in here since our state functions aren't useful
for database_operation in self.database_operations:
to_state = from_state.clone()
database_operation.state_forwards(app_label, to_state)
database_operation.database_forwards(
app_label, schema_editor, from_state, to_state
)
from_state = to_state
def database_backwards(self, app_label, schema_editor, from_state, to_state):
# We calculate state separately in here since our state functions aren't useful
to_states = {}
for dbop in self.database_operations:
to_states[dbop] = to_state
to_state = to_state.clone()
dbop.state_forwards(app_label, to_state)
# to_state now has the states of all the database_operations applied
# which is the from_state for the backwards migration of the last
# operation.
for database_operation in reversed(self.database_operations):
from_state = to_state
to_state = to_states[database_operation]
database_operation.database_backwards(
app_label, schema_editor, from_state, to_state
)
def describe(self):
return "Custom state/database change combination"
class RunSQL(Operation):
"""
Run some raw SQL. A reverse SQL statement may be provided.
Also accept a list of operations that represent the state change effected
by this SQL change, in case it's custom column/table creation/deletion.
"""
noop = ""
def __init__(
self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False
):
self.sql = sql
self.reverse_sql = reverse_sql
self.state_operations = state_operations or []
self.hints = hints or {}
self.elidable = elidable
def deconstruct(self):
kwargs = {
"sql": self.sql,
}
if self.reverse_sql is not None:
kwargs["reverse_sql"] = self.reverse_sql
if self.state_operations:
kwargs["state_operations"] = self.state_operations
if self.hints:
kwargs["hints"] = self.hints
return (self.__class__.__qualname__, [], kwargs)
@property
def reversible(self):
return self.reverse_sql is not None
def state_forwards(self, app_label, state):
for state_operation in self.state_operations:
state_operation.state_forwards(app_label, state)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
self._run_sql(schema_editor, self.sql)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_sql is None:
raise NotImplementedError("You cannot reverse this operation")
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
self._run_sql(schema_editor, self.reverse_sql)
def describe(self):
return "Raw SQL operation"
def _run_sql(self, schema_editor, sqls):
if isinstance(sqls, (list, tuple)):
for sql in sqls:
params = None
if isinstance(sql, (list, tuple)):
elements = len(sql)
if elements == 2:
sql, params = sql
else:
raise ValueError("Expected a 2-tuple but got %d" % elements)
schema_editor.execute(sql, params=params)
elif sqls != RunSQL.noop:
statements = schema_editor.connection.ops.prepare_sql_script(sqls)
for statement in statements:
schema_editor.execute(statement, params=None)
class RunPython(Operation):
"""
Run Python code in a context suitable for doing versioned ORM operations.
"""
reduces_to_sql = False
def __init__(
self, code, reverse_code=None, atomic=None, hints=None, elidable=False
):
self.atomic = atomic
# Forwards code
if not callable(code):
raise ValueError("RunPython must be supplied with a callable")
self.code = code
# Reverse code
if reverse_code is None:
self.reverse_code = None
else:
if not callable(reverse_code):
raise ValueError("RunPython must be supplied with callable arguments")
self.reverse_code = reverse_code
self.hints = hints or {}
self.elidable = elidable
def deconstruct(self):
kwargs = {
"code": self.code,
}
if self.reverse_code is not None:
kwargs["reverse_code"] = self.reverse_code
if self.atomic is not None:
kwargs["atomic"] = self.atomic
if self.hints:
kwargs["hints"] = self.hints
return (self.__class__.__qualname__, [], kwargs)
@property
def reversible(self):
return self.reverse_code is not None
def state_forwards(self, app_label, state):
# RunPython objects have no state effect. To add some, combine this
# with SeparateDatabaseAndState.
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
# RunPython has access to all models. Ensure that all models are
# reloaded in case any are delayed.
from_state.clear_delayed_apps_cache()
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
# We now execute the Python code in a context that contains a 'models'
# object, representing the versioned models as an app registry.
# We could try to override the global cache, but then people will still
# use direct imports, so we go with a documentation approach instead.
self.code(from_state.apps, schema_editor)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_code is None:
raise NotImplementedError("You cannot reverse this operation")
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
self.reverse_code(from_state.apps, schema_editor)
def describe(self):
return "Raw Python operation"
@staticmethod
def noop(apps, schema_editor):
return None

View File

@@ -0,0 +1,69 @@
class MigrationOptimizer:
"""
Power the optimization process, where you provide a list of Operations
and you are returned a list of equal or shorter length - operations
are merged into one if possible.
For example, a CreateModel and an AddField can be optimized into a
new CreateModel, and CreateModel and DeleteModel can be optimized into
nothing.
"""
def optimize(self, operations, app_label):
"""
Main optimization entry point. Pass in a list of Operation instances,
get out a new list of Operation instances.
Unfortunately, due to the scope of the optimization (two combinable
operations might be separated by several hundred others), this can't be
done as a peephole optimization with checks/output implemented on
the Operations themselves; instead, the optimizer looks at each
individual operation and scans forwards in the list to see if there
are any matches, stopping at boundaries - operations which can't
be optimized over (RunSQL, operations on the same field/model, etc.)
The inner loop is run until the starting list is the same as the result
list, and then the result is returned. This means that operation
optimization must be stable and always return an equal or shorter list.
"""
# Internal tracking variable for test assertions about # of loops
if app_label is None:
raise TypeError("app_label must be a str.")
self._iterations = 0
while True:
result = self.optimize_inner(operations, app_label)
self._iterations += 1
if result == operations:
return result
operations = result
def optimize_inner(self, operations, app_label):
"""Inner optimization loop."""
new_operations = []
for i, operation in enumerate(operations):
right = True # Should we reduce on the right or on the left.
# Compare it to each operation after it
for j, other in enumerate(operations[i + 1 :]):
result = operation.reduce(other, app_label)
if isinstance(result, list):
in_between = operations[i + 1 : i + j + 1]
if right:
new_operations.extend(in_between)
new_operations.extend(result)
elif all(op.reduce(other, app_label) is True for op in in_between):
# Perform a left reduction if all of the in-between
# operations can optimize through other.
new_operations.extend(result)
new_operations.extend(in_between)
else:
# Otherwise keep trying.
new_operations.append(operation)
break
new_operations.extend(operations[i + j + 2 :])
return new_operations
elif not result:
# Can't perform a right reduction.
right = False
else:
new_operations.append(operation)
return new_operations

View File

@@ -0,0 +1,341 @@
import datetime
import importlib
import os
import sys
from django.apps import apps
from django.core.management.base import OutputWrapper
from django.db.models import NOT_PROVIDED
from django.utils import timezone
from django.utils.version import get_docs_version
from .loader import MigrationLoader
class MigrationQuestioner:
"""
Give the autodetector responses to questions it might have.
This base class has a built-in noninteractive mode, but the
interactive subclass is what the command-line arguments will use.
"""
def __init__(self, defaults=None, specified_apps=None, dry_run=None):
self.defaults = defaults or {}
self.specified_apps = specified_apps or set()
self.dry_run = dry_run
def ask_initial(self, app_label):
"""Should we create an initial migration for the app?"""
# If it was specified on the command line, definitely true
if app_label in self.specified_apps:
return True
# Otherwise, we look to see if it has a migrations module
# without any Python files in it, apart from __init__.py.
# Apps from the new app template will have these; the Python
# file check will ensure we skip South ones.
try:
app_config = apps.get_app_config(app_label)
except LookupError: # It's a fake app.
return self.defaults.get("ask_initial", False)
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
if migrations_import_path is None:
# It's an application with migrations disabled.
return self.defaults.get("ask_initial", False)
try:
migrations_module = importlib.import_module(migrations_import_path)
except ImportError:
return self.defaults.get("ask_initial", False)
else:
if getattr(migrations_module, "__file__", None):
filenames = os.listdir(os.path.dirname(migrations_module.__file__))
elif hasattr(migrations_module, "__path__"):
if len(migrations_module.__path__) > 1:
return False
filenames = os.listdir(list(migrations_module.__path__)[0])
return not any(x.endswith(".py") for x in filenames if x != "__init__.py")
def ask_not_null_addition(self, field_name, model_name):
"""Adding a NOT NULL field to a model."""
# None means quit
return None
def ask_not_null_alteration(self, field_name, model_name):
"""Changing a NULL field to NOT NULL."""
# None means quit
return None
def ask_rename(self, model_name, old_name, new_name, field_instance):
"""Was this field really renamed?"""
return self.defaults.get("ask_rename", False)
def ask_rename_model(self, old_model_state, new_model_state):
"""Was this model really renamed?"""
return self.defaults.get("ask_rename_model", False)
def ask_merge(self, app_label):
"""Should these migrations really be merged?"""
return self.defaults.get("ask_merge", False)
def ask_auto_now_add_addition(self, field_name, model_name):
"""Adding an auto_now_add field to a model."""
# None means quit
return None
def ask_unique_callable_default_addition(self, field_name, model_name):
"""Adding a unique field with a callable default."""
# None means continue.
return None
class InteractiveMigrationQuestioner(MigrationQuestioner):
def __init__(
self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None
):
super().__init__(
defaults=defaults, specified_apps=specified_apps, dry_run=dry_run
)
self.prompt_output = prompt_output or OutputWrapper(sys.stdout)
def _boolean_input(self, question, default=None):
self.prompt_output.write(f"{question} ", ending="")
result = input()
if not result and default is not None:
return default
while not result or result[0].lower() not in "yn":
self.prompt_output.write("Please answer yes or no: ", ending="")
result = input()
return result[0].lower() == "y"
def _choice_input(self, question, choices):
self.prompt_output.write(f"{question}")
for i, choice in enumerate(choices):
self.prompt_output.write(" %s) %s" % (i + 1, choice))
self.prompt_output.write("Select an option: ", ending="")
result = input()
while True:
try:
value = int(result)
except ValueError:
pass
else:
if 0 < value <= len(choices):
return value
self.prompt_output.write("Please select a valid option: ", ending="")
result = input()
def _ask_default(self, default=""):
"""
Prompt for a default value.
The ``default`` argument allows providing a custom default value (as a
string) which will be shown to the user and used as the return value
if the user doesn't provide any other input.
"""
self.prompt_output.write("Please enter the default value as valid Python.")
if default:
self.prompt_output.write(
f"Accept the default '{default}' by pressing 'Enter' or "
f"provide another value."
)
self.prompt_output.write(
"The datetime and django.utils.timezone modules are available, so "
"it is possible to provide e.g. timezone.now as a value."
)
self.prompt_output.write("Type 'exit' to exit this prompt")
while True:
if default:
prompt = "[default: {}] >>> ".format(default)
else:
prompt = ">>> "
self.prompt_output.write(prompt, ending="")
code = input()
if not code and default:
code = default
if not code:
self.prompt_output.write(
"Please enter some code, or 'exit' (without quotes) to exit."
)
elif code == "exit":
sys.exit(1)
else:
try:
return eval(code, {}, {"datetime": datetime, "timezone": timezone})
except (SyntaxError, NameError) as e:
self.prompt_output.write("Invalid input: %s" % e)
def ask_not_null_addition(self, field_name, model_name):
"""Adding a NOT NULL field to a model."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to add a non-nullable field '{field_name}' "
f"to {model_name} without specifying a default. This is "
f"because the database needs something to populate existing "
f"rows.\n"
f"Please select a fix:",
[
(
"Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"
),
"Quit and manually define a default value in models.py.",
],
)
if choice == 2:
sys.exit(3)
else:
return self._ask_default()
return None
def ask_not_null_alteration(self, field_name, model_name):
"""Changing a NULL field to NOT NULL."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to change a nullable field '{field_name}' "
f"on {model_name} to non-nullable without providing a "
f"default. This is because the database needs something to "
f"populate existing rows.\n"
f"Please select a fix:",
[
(
"Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"
),
"Ignore for now. Existing rows that contain NULL values "
"will have to be handled manually, for example with a "
"RunPython or RunSQL operation.",
"Quit and manually define a default value in models.py.",
],
)
if choice == 2:
return NOT_PROVIDED
elif choice == 3:
sys.exit(3)
else:
return self._ask_default()
return None
def ask_rename(self, model_name, old_name, new_name, field_instance):
"""Was this field really renamed?"""
msg = "Was %s.%s renamed to %s.%s (a %s)? [y/N]"
return self._boolean_input(
msg
% (
model_name,
old_name,
model_name,
new_name,
field_instance.__class__.__name__,
),
False,
)
def ask_rename_model(self, old_model_state, new_model_state):
"""Was this model really renamed?"""
msg = "Was the model %s.%s renamed to %s? [y/N]"
return self._boolean_input(
msg
% (old_model_state.app_label, old_model_state.name, new_model_state.name),
False,
)
def ask_merge(self, app_label):
return self._boolean_input(
"\nMerging will only work if the operations printed above do not conflict\n"
+ "with each other (working on different fields or models)\n"
+ "Should these migration branches be merged? [y/N]",
False,
)
def ask_auto_now_add_addition(self, field_name, model_name):
"""Adding an auto_now_add field to a model."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to add the field '{field_name}' with "
f"'auto_now_add=True' to {model_name} without providing a "
f"default. This is because the database needs something to "
f"populate existing rows.\n",
[
"Provide a one-off default now which will be set on all "
"existing rows",
"Quit and manually define a default value in models.py.",
],
)
if choice == 2:
sys.exit(3)
else:
return self._ask_default(default="timezone.now")
return None
def ask_unique_callable_default_addition(self, field_name, model_name):
"""Adding a unique field with a callable default."""
if not self.dry_run:
version = get_docs_version()
choice = self._choice_input(
f"Callable default on unique field {model_name}.{field_name} "
f"will not generate unique values upon migrating.\n"
f"Please choose how to proceed:\n",
[
f"Continue making this migration as the first step in "
f"writing a manual migration to generate unique values "
f"described here: "
f"https://docs.djangoproject.com/en/{version}/howto/"
f"writing-migrations/#migrations-that-add-unique-fields.",
"Quit and edit field options in models.py.",
],
)
if choice == 2:
sys.exit(3)
return None
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
def __init__(
self,
defaults=None,
specified_apps=None,
dry_run=None,
verbosity=1,
log=None,
):
self.verbosity = verbosity
self.log = log
super().__init__(
defaults=defaults,
specified_apps=specified_apps,
dry_run=dry_run,
)
def log_lack_of_migration(self, field_name, model_name, reason):
if self.verbosity > 0:
self.log(
f"Field '{field_name}' on model '{model_name}' not migrated: "
f"{reason}."
)
def ask_not_null_addition(self, field_name, model_name):
# We can't ask the user, so act like the user aborted.
self.log_lack_of_migration(
field_name,
model_name,
"it is impossible to add a non-nullable field without specifying "
"a default",
)
sys.exit(3)
def ask_not_null_alteration(self, field_name, model_name):
# We can't ask the user, so set as not provided.
self.log(
f"Field '{field_name}' on model '{model_name}' given a default of "
f"NOT PROVIDED and must be corrected."
)
return NOT_PROVIDED
def ask_auto_now_add_addition(self, field_name, model_name):
# We can't ask the user, so act like the user aborted.
self.log_lack_of_migration(
field_name,
model_name,
"it is impossible to add a field with 'auto_now_add=True' without "
"specifying a default",
)
sys.exit(3)

View File

@@ -0,0 +1,103 @@
from django.apps.registry import Apps
from django.db import DatabaseError, models
from django.utils.functional import classproperty
from django.utils.timezone import now
from .exceptions import MigrationSchemaMissing
class MigrationRecorder:
"""
Deal with storing migration records in the database.
Because this table is actually itself used for dealing with model
creation, it's the one thing we can't do normally via migrations.
We manually handle table creation/schema updating (using schema backend)
and then have a floating model to do queries with.
If a migration is unapplied its row is removed from the table. Having
a row in the table always means a migration is applied.
"""
_migration_class = None
@classproperty
def Migration(cls):
"""
Lazy load to avoid AppRegistryNotReady if installed apps import
MigrationRecorder.
"""
if cls._migration_class is None:
class Migration(models.Model):
app = models.CharField(max_length=255)
name = models.CharField(max_length=255)
applied = models.DateTimeField(default=now)
class Meta:
apps = Apps()
app_label = "migrations"
db_table = "django_migrations"
def __str__(self):
return "Migration %s for %s" % (self.name, self.app)
cls._migration_class = Migration
return cls._migration_class
def __init__(self, connection):
self.connection = connection
@property
def migration_qs(self):
return self.Migration.objects.using(self.connection.alias)
def has_table(self):
"""Return True if the django_migrations table exists."""
with self.connection.cursor() as cursor:
tables = self.connection.introspection.table_names(cursor)
return self.Migration._meta.db_table in tables
def ensure_schema(self):
"""Ensure the table exists and has the correct schema."""
# If the table's there, that's fine - we've never changed its schema
# in the codebase.
if self.has_table():
return
# Make the table
try:
with self.connection.schema_editor() as editor:
editor.create_model(self.Migration)
except DatabaseError as exc:
raise MigrationSchemaMissing(
"Unable to create the django_migrations table (%s)" % exc
)
def applied_migrations(self):
"""
Return a dict mapping (app_name, migration_name) to Migration instances
for all applied migrations.
"""
if self.has_table():
return {
(migration.app, migration.name): migration
for migration in self.migration_qs
}
else:
# If the django_migrations table doesn't exist, then no migrations
# are applied.
return {}
def record_applied(self, app, name):
"""Record that a migration was applied."""
self.ensure_schema()
self.migration_qs.create(app=app, name=name)
def record_unapplied(self, app, name):
"""Record that a migration was unapplied."""
self.ensure_schema()
self.migration_qs.filter(app=app, name=name).delete()
def flush(self):
"""Delete all migration records. Useful for testing migrations."""
self.migration_qs.all().delete()

View File

@@ -0,0 +1,382 @@
import builtins
import collections.abc
import datetime
import decimal
import enum
import functools
import math
import os
import pathlib
import re
import types
import uuid
from django.conf import SettingsReference
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
from django.utils.functional import LazyObject, Promise
from django.utils.version import get_docs_version
class BaseSerializer:
def __init__(self, value):
self.value = value
def serialize(self):
raise NotImplementedError(
"Subclasses of BaseSerializer must implement the serialize() method."
)
class BaseSequenceSerializer(BaseSerializer):
def _format(self):
raise NotImplementedError(
"Subclasses of BaseSequenceSerializer must implement the _format() method."
)
def serialize(self):
imports = set()
strings = []
for item in self.value:
item_string, item_imports = serializer_factory(item).serialize()
imports.update(item_imports)
strings.append(item_string)
value = self._format()
return value % (", ".join(strings)), imports
class BaseSimpleSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), set()
class ChoicesSerializer(BaseSerializer):
def serialize(self):
return serializer_factory(self.value.value).serialize()
class DateTimeSerializer(BaseSerializer):
"""For datetime.*, except datetime.datetime."""
def serialize(self):
return repr(self.value), {"import datetime"}
class DatetimeDatetimeSerializer(BaseSerializer):
"""For datetime.datetime."""
def serialize(self):
if self.value.tzinfo is not None and self.value.tzinfo != datetime.timezone.utc:
self.value = self.value.astimezone(datetime.timezone.utc)
imports = ["import datetime"]
return repr(self.value), set(imports)
class DecimalSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), {"from decimal import Decimal"}
class DeconstructableSerializer(BaseSerializer):
@staticmethod
def serialize_deconstructed(path, args, kwargs):
name, imports = DeconstructableSerializer._serialize_path(path)
strings = []
for arg in args:
arg_string, arg_imports = serializer_factory(arg).serialize()
strings.append(arg_string)
imports.update(arg_imports)
for kw, arg in sorted(kwargs.items()):
arg_string, arg_imports = serializer_factory(arg).serialize()
imports.update(arg_imports)
strings.append("%s=%s" % (kw, arg_string))
return "%s(%s)" % (name, ", ".join(strings)), imports
@staticmethod
def _serialize_path(path):
module, name = path.rsplit(".", 1)
if module == "django.db.models":
imports = {"from django.db import models"}
name = "models.%s" % name
else:
imports = {"import %s" % module}
name = path
return name, imports
def serialize(self):
return self.serialize_deconstructed(*self.value.deconstruct())
class DictionarySerializer(BaseSerializer):
def serialize(self):
imports = set()
strings = []
for k, v in sorted(self.value.items()):
k_string, k_imports = serializer_factory(k).serialize()
v_string, v_imports = serializer_factory(v).serialize()
imports.update(k_imports)
imports.update(v_imports)
strings.append((k_string, v_string))
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
class EnumSerializer(BaseSerializer):
def serialize(self):
enum_class = self.value.__class__
module = enum_class.__module__
return (
"%s.%s[%r]" % (module, enum_class.__qualname__, self.value.name),
{"import %s" % module},
)
class FloatSerializer(BaseSimpleSerializer):
def serialize(self):
if math.isnan(self.value) or math.isinf(self.value):
return 'float("{}")'.format(self.value), set()
return super().serialize()
class FrozensetSerializer(BaseSequenceSerializer):
def _format(self):
return "frozenset([%s])"
class FunctionTypeSerializer(BaseSerializer):
def serialize(self):
if getattr(self.value, "__self__", None) and isinstance(
self.value.__self__, type
):
klass = self.value.__self__
module = klass.__module__
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {
"import %s" % module
}
# Further error checking
if self.value.__name__ == "<lambda>":
raise ValueError("Cannot serialize function: lambda")
if self.value.__module__ is None:
raise ValueError("Cannot serialize function %r: No module" % self.value)
module_name = self.value.__module__
if "<" not in self.value.__qualname__: # Qualname can include <locals>
return "%s.%s" % (module_name, self.value.__qualname__), {
"import %s" % self.value.__module__
}
raise ValueError(
"Could not find function %s in %s.\n" % (self.value.__name__, module_name)
)
class FunctoolsPartialSerializer(BaseSerializer):
def serialize(self):
# Serialize functools.partial() arguments
func_string, func_imports = serializer_factory(self.value.func).serialize()
args_string, args_imports = serializer_factory(self.value.args).serialize()
keywords_string, keywords_imports = serializer_factory(
self.value.keywords
).serialize()
# Add any imports needed by arguments
imports = {"import functools", *func_imports, *args_imports, *keywords_imports}
return (
"functools.%s(%s, *%s, **%s)"
% (
self.value.__class__.__name__,
func_string,
args_string,
keywords_string,
),
imports,
)
class IterableSerializer(BaseSerializer):
def serialize(self):
imports = set()
strings = []
for item in self.value:
item_string, item_imports = serializer_factory(item).serialize()
imports.update(item_imports)
strings.append(item_string)
# When len(strings)==0, the empty iterable should be serialized as
# "()", not "(,)" because (,) is invalid Python syntax.
value = "(%s)" if len(strings) != 1 else "(%s,)"
return value % (", ".join(strings)), imports
class ModelFieldSerializer(DeconstructableSerializer):
def serialize(self):
attr_name, path, args, kwargs = self.value.deconstruct()
return self.serialize_deconstructed(path, args, kwargs)
class ModelManagerSerializer(DeconstructableSerializer):
def serialize(self):
as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
if as_manager:
name, imports = self._serialize_path(qs_path)
return "%s.as_manager()" % name, imports
else:
return self.serialize_deconstructed(manager_path, args, kwargs)
class OperationSerializer(BaseSerializer):
def serialize(self):
from django.db.migrations.writer import OperationWriter
string, imports = OperationWriter(self.value, indentation=0).serialize()
# Nested operation, trailing comma is handled in upper OperationWriter._write()
return string.rstrip(","), imports
class PathLikeSerializer(BaseSerializer):
def serialize(self):
return repr(os.fspath(self.value)), {}
class PathSerializer(BaseSerializer):
def serialize(self):
# Convert concrete paths to pure paths to avoid issues with migrations
# generated on one platform being used on a different platform.
prefix = "Pure" if isinstance(self.value, pathlib.Path) else ""
return "pathlib.%s%r" % (prefix, self.value), {"import pathlib"}
class RegexSerializer(BaseSerializer):
def serialize(self):
regex_pattern, pattern_imports = serializer_factory(
self.value.pattern
).serialize()
# Turn off default implicit flags (e.g. re.U) because regexes with the
# same implicit and explicit flags aren't equal.
flags = self.value.flags ^ re.compile("").flags
regex_flags, flag_imports = serializer_factory(flags).serialize()
imports = {"import re", *pattern_imports, *flag_imports}
args = [regex_pattern]
if flags:
args.append(regex_flags)
return "re.compile(%s)" % ", ".join(args), imports
class SequenceSerializer(BaseSequenceSerializer):
def _format(self):
return "[%s]"
class SetSerializer(BaseSequenceSerializer):
def _format(self):
# Serialize as a set literal except when value is empty because {}
# is an empty dict.
return "{%s}" if self.value else "set(%s)"
class SettingsReferenceSerializer(BaseSerializer):
def serialize(self):
return "settings.%s" % self.value.setting_name, {
"from django.conf import settings"
}
class TupleSerializer(BaseSequenceSerializer):
def _format(self):
# When len(value)==0, the empty tuple should be serialized as "()",
# not "(,)" because (,) is invalid Python syntax.
return "(%s)" if len(self.value) != 1 else "(%s,)"
class TypeSerializer(BaseSerializer):
def serialize(self):
special_cases = [
(models.Model, "models.Model", ["from django.db import models"]),
(type(None), "type(None)", []),
]
for case, string, imports in special_cases:
if case is self.value:
return string, set(imports)
if hasattr(self.value, "__module__"):
module = self.value.__module__
if module == builtins.__name__:
return self.value.__name__, set()
else:
return "%s.%s" % (module, self.value.__qualname__), {
"import %s" % module
}
class UUIDSerializer(BaseSerializer):
def serialize(self):
return "uuid.%s" % repr(self.value), {"import uuid"}
class Serializer:
_registry = {
# Some of these are order-dependent.
frozenset: FrozensetSerializer,
list: SequenceSerializer,
set: SetSerializer,
tuple: TupleSerializer,
dict: DictionarySerializer,
models.Choices: ChoicesSerializer,
enum.Enum: EnumSerializer,
datetime.datetime: DatetimeDatetimeSerializer,
(datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
SettingsReference: SettingsReferenceSerializer,
float: FloatSerializer,
(bool, int, type(None), bytes, str, range): BaseSimpleSerializer,
decimal.Decimal: DecimalSerializer,
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
(
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
): FunctionTypeSerializer,
collections.abc.Iterable: IterableSerializer,
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
uuid.UUID: UUIDSerializer,
pathlib.PurePath: PathSerializer,
os.PathLike: PathLikeSerializer,
}
@classmethod
def register(cls, type_, serializer):
if not issubclass(serializer, BaseSerializer):
raise ValueError(
"'%s' must inherit from 'BaseSerializer'." % serializer.__name__
)
cls._registry[type_] = serializer
@classmethod
def unregister(cls, type_):
cls._registry.pop(type_)
def serializer_factory(value):
if isinstance(value, Promise):
value = str(value)
elif isinstance(value, LazyObject):
# The unwrapped value is returned as the first item of the arguments
# tuple.
value = value.__reduce__()[1][0]
if isinstance(value, models.Field):
return ModelFieldSerializer(value)
if isinstance(value, models.manager.BaseManager):
return ModelManagerSerializer(value)
if isinstance(value, Operation):
return OperationSerializer(value)
if isinstance(value, type):
return TypeSerializer(value)
# Anything that knows how to deconstruct itself.
if hasattr(value, "deconstruct"):
return DeconstructableSerializer(value)
for type_, serializer_cls in Serializer._registry.items():
if isinstance(value, type_):
return serializer_cls(value)
raise ValueError(
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
"topics/migrations/#migration-serializing" % (value, get_docs_version())
)

View File

@@ -0,0 +1,985 @@
import copy
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from django.apps import AppConfig
from django.apps.registry import Apps
from django.apps.registry import apps as global_apps
from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.migrations.utils import field_is_referenced, get_references
from django.db.models import NOT_PROVIDED
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.db.models.options import DEFAULT_NAMES, normalize_together
from django.db.models.utils import make_model_tuple
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
from django.utils.version import get_docs_version
from .exceptions import InvalidBasesError
from .utils import resolve_relation
def _get_app_label_and_model_name(model, app_label=""):
if isinstance(model, str):
split = model.split(".", 1)
return tuple(split) if len(split) == 2 else (app_label, split[0])
else:
return model._meta.app_label, model._meta.model_name
def _get_related_models(m):
"""Return all models that have a direct relationship to the given model."""
related_models = [
subclass
for subclass in m.__subclasses__()
if issubclass(subclass, models.Model)
]
related_fields_models = set()
for f in m._meta.get_fields(include_parents=True, include_hidden=True):
if (
f.is_relation
and f.related_model is not None
and not isinstance(f.related_model, str)
):
related_fields_models.add(f.model)
related_models.append(f.related_model)
# Reverse accessors of foreign keys to proxy models are attached to their
# concrete proxied model.
opts = m._meta
if opts.proxy and m in related_fields_models:
related_models.append(opts.concrete_model)
return related_models
def get_related_models_tuples(model):
"""
Return a list of typical (app_label, model_name) tuples for all related
models for the given model.
"""
return {
(rel_mod._meta.app_label, rel_mod._meta.model_name)
for rel_mod in _get_related_models(model)
}
def get_related_models_recursive(model):
"""
Return all models that have a direct or indirect relationship
to the given model.
Relationships are either defined by explicit relational fields, like
ForeignKey, ManyToManyField or OneToOneField, or by inheriting from another
model (a superclass is related to its subclasses, but not vice versa). Note,
however, that a model inheriting from a concrete model is also related to
its superclass through the implicit *_ptr OneToOneField on the subclass.
"""
seen = set()
queue = _get_related_models(model)
for rel_mod in queue:
rel_app_label, rel_model_name = (
rel_mod._meta.app_label,
rel_mod._meta.model_name,
)
if (rel_app_label, rel_model_name) in seen:
continue
seen.add((rel_app_label, rel_model_name))
queue.extend(_get_related_models(rel_mod))
return seen - {(model._meta.app_label, model._meta.model_name)}
class ProjectState:
"""
Represent the entire project's overall state. This is the item that is
passed around - do it here rather than at the app level so that cross-app
FKs/etc. resolve properly.
"""
def __init__(self, models=None, real_apps=None):
self.models = models or {}
# Apps to include from main registry, usually unmigrated ones
if real_apps is None:
real_apps = set()
else:
assert isinstance(real_apps, set)
self.real_apps = real_apps
self.is_delayed = False
# {remote_model_key: {model_key: {field_name: field}}}
self._relations = None
@property
def relations(self):
if self._relations is None:
self.resolve_fields_and_relations()
return self._relations
def add_model(self, model_state):
model_key = model_state.app_label, model_state.name_lower
self.models[model_key] = model_state
if self._relations is not None:
self.resolve_model_relations(model_key)
if "apps" in self.__dict__: # hasattr would cache the property
self.reload_model(*model_key)
def remove_model(self, app_label, model_name):
model_key = app_label, model_name
del self.models[model_key]
if self._relations is not None:
self._relations.pop(model_key, None)
# Call list() since _relations can change size during iteration.
for related_model_key, model_relations in list(self._relations.items()):
model_relations.pop(model_key, None)
if not model_relations:
del self._relations[related_model_key]
if "apps" in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(*model_key)
# Need to do this explicitly since unregister_model() doesn't clear
# the cache automatically (#24513)
self.apps.clear_cache()
def rename_model(self, app_label, old_name, new_name):
# Add a new model.
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
renamed_model = self.models[app_label, old_name_lower].clone()
renamed_model.name = new_name
self.models[app_label, new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, old_name_lower)
new_remote_model = f"{app_label}.{new_name}"
to_reload = set()
for model_state, name, field, reference in get_references(
self, old_model_tuple
):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
if self._relations is not None:
old_name_key = app_label, old_name_lower
new_name_key = app_label, new_name_lower
if old_name_key in self._relations:
self._relations[new_name_key] = self._relations.pop(old_name_key)
for model_relations in self._relations.values():
if old_name_key in model_relations:
model_relations[new_name_key] = model_relations.pop(old_name_key)
# Reload models related to old model before removing the old model.
self.reload_models(to_reload, delay=True)
# Remove the old model.
self.remove_model(app_label, old_name_lower)
self.reload_model(app_label, new_name_lower, delay=True)
def alter_model_options(self, app_label, model_name, options, option_keys=None):
model_state = self.models[app_label, model_name]
model_state.options = {**model_state.options, **options}
if option_keys:
for key in option_keys:
if key not in options:
model_state.options.pop(key, False)
self.reload_model(app_label, model_name, delay=True)
def remove_model_options(self, app_label, model_name, option_name, value_to_remove):
model_state = self.models[app_label, model_name]
if objs := model_state.options.get(option_name):
model_state.options[option_name] = [
obj for obj in objs if tuple(obj) != tuple(value_to_remove)
]
self.reload_model(app_label, model_name, delay=True)
def alter_model_managers(self, app_label, model_name, managers):
model_state = self.models[app_label, model_name]
model_state.managers = list(managers)
self.reload_model(app_label, model_name, delay=True)
def _append_option(self, app_label, model_name, option_name, obj):
model_state = self.models[app_label, model_name]
model_state.options[option_name] = [*model_state.options[option_name], obj]
self.reload_model(app_label, model_name, delay=True)
def _remove_option(self, app_label, model_name, option_name, obj_name):
model_state = self.models[app_label, model_name]
objs = model_state.options[option_name]
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
self.reload_model(app_label, model_name, delay=True)
def add_index(self, app_label, model_name, index):
self._append_option(app_label, model_name, "indexes", index)
def remove_index(self, app_label, model_name, index_name):
self._remove_option(app_label, model_name, "indexes", index_name)
def rename_index(self, app_label, model_name, old_index_name, new_index_name):
model_state = self.models[app_label, model_name]
objs = model_state.options["indexes"]
new_indexes = []
for obj in objs:
if obj.name == old_index_name:
obj = obj.clone()
obj.name = new_index_name
new_indexes.append(obj)
model_state.options["indexes"] = new_indexes
self.reload_model(app_label, model_name, delay=True)
def add_constraint(self, app_label, model_name, constraint):
self._append_option(app_label, model_name, "constraints", constraint)
def remove_constraint(self, app_label, model_name, constraint_name):
self._remove_option(app_label, model_name, "constraints", constraint_name)
def add_field(self, app_label, model_name, name, field, preserve_default):
# If preserve default is off, don't use the default for future state.
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_key = app_label, model_name
self.models[model_key].fields[name] = field
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, field)
# Delay rendering of relationships if it's not a relational field.
delay = not field.is_relation
self.reload_model(*model_key, delay=delay)
def remove_field(self, app_label, model_name, name):
model_key = app_label, model_name
model_state = self.models[model_key]
old_field = model_state.fields.pop(name)
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, old_field)
# Delay rendering of relationships if it's not a relational field.
delay = not old_field.is_relation
self.reload_model(*model_key, delay=delay)
def alter_field(self, app_label, model_name, name, field, preserve_default):
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_key = app_label, model_name
fields = self.models[model_key].fields
if self._relations is not None:
old_field = fields.pop(name)
if old_field.is_relation:
self.resolve_model_field_relations(model_key, name, old_field)
fields[name] = field
if field.is_relation:
self.resolve_model_field_relations(model_key, name, field)
else:
fields[name] = field
# TODO: investigate if old relational fields must be reloaded or if
# it's sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = not field.is_relation and not field_is_referenced(
self, model_key, (name, field)
)
self.reload_model(*model_key, delay=delay)
def rename_field(self, app_label, model_name, old_name, new_name):
model_key = app_label, model_name
model_state = self.models[model_key]
# Rename the field.
fields = model_state.fields
try:
found = fields.pop(old_name)
except KeyError:
raise FieldDoesNotExist(
f"{app_label}.{model_name} has no field named '{old_name}'"
)
fields[new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, "from_fields", None)
if from_fields:
field.from_fields = tuple(
[
new_name if from_field_name == old_name else from_field_name
for from_field_name in from_fields
]
)
# Fix index/unique_together to refer to the new field.
options = model_state.options
for option in ("index_together", "unique_together"):
if option in options:
options[option] = [
[new_name if n == old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(self, model_key, (old_name, found))
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, "field_name", None) == old_name:
remote_field.field_name = new_name
if to_fields:
field.to_fields = tuple(
[
new_name if to_field_name == old_name else to_field_name
for to_field_name in to_fields
]
)
if self._relations is not None:
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
for to_model in self._relations.values():
if old_name_lower in to_model[model_key]:
field = to_model[model_key].pop(old_name_lower)
field.name = new_name_lower
to_model[model_key][new_name_lower] = field
self.reload_model(*model_key, delay=delay)
def _find_reload_model(self, app_label, model_name, delay=False):
if delay:
self.is_delayed = True
related_models = set()
try:
old_model = self.apps.get_model(app_label, model_name)
except LookupError:
pass
else:
# Get all relations to and from the old model before reloading,
# as _meta.apps may change
if delay:
related_models = get_related_models_tuples(old_model)
else:
related_models = get_related_models_recursive(old_model)
# Get all outgoing references from the model to be rendered
model_state = self.models[(app_label, model_name)]
# Directly related models are the models pointed to by ForeignKeys,
# OneToOneFields, and ManyToManyFields.
direct_related_models = set()
for field in model_state.fields.values():
if field.is_relation:
if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
continue
rel_app_label, rel_model_name = _get_app_label_and_model_name(
field.related_model, app_label
)
direct_related_models.add((rel_app_label, rel_model_name.lower()))
# For all direct related models recursively get all related models.
related_models.update(direct_related_models)
for rel_app_label, rel_model_name in direct_related_models:
try:
rel_model = self.apps.get_model(rel_app_label, rel_model_name)
except LookupError:
pass
else:
if delay:
related_models.update(get_related_models_tuples(rel_model))
else:
related_models.update(get_related_models_recursive(rel_model))
# Include the model itself
related_models.add((app_label, model_name))
return related_models
def reload_model(self, app_label, model_name, delay=False):
if "apps" in self.__dict__: # hasattr would cache the property
related_models = self._find_reload_model(app_label, model_name, delay)
self._reload(related_models)
def reload_models(self, models, delay=True):
if "apps" in self.__dict__: # hasattr would cache the property
related_models = set()
for app_label, model_name in models:
related_models.update(
self._find_reload_model(app_label, model_name, delay)
)
self._reload(related_models)
def _reload(self, related_models):
# Unregister all related models
with self.apps.bulk_update():
for rel_app_label, rel_model_name in related_models:
self.apps.unregister_model(rel_app_label, rel_model_name)
states_to_be_rendered = []
# Gather all models states of those models that will be rerendered.
# This includes:
# 1. All related models of unmigrated apps
for model_state in self.apps.real_models:
if (model_state.app_label, model_state.name_lower) in related_models:
states_to_be_rendered.append(model_state)
# 2. All related models of migrated apps
for rel_app_label, rel_model_name in related_models:
try:
model_state = self.models[rel_app_label, rel_model_name]
except KeyError:
pass
else:
states_to_be_rendered.append(model_state)
# Render all models
self.apps.render_multiple(states_to_be_rendered)
def update_model_field_relation(
self,
model,
model_key,
field_name,
field,
concretes,
):
remote_model_key = resolve_relation(model, *model_key)
if remote_model_key[0] not in self.real_apps and remote_model_key in concretes:
remote_model_key = concretes[remote_model_key]
relations_to_remote_model = self._relations[remote_model_key]
if field_name in self.models[model_key].fields:
# The assert holds because it's a new relation, or an altered
# relation, in which case references have been removed by
# alter_field().
assert field_name not in relations_to_remote_model[model_key]
relations_to_remote_model[model_key][field_name] = field
else:
del relations_to_remote_model[model_key][field_name]
if not relations_to_remote_model[model_key]:
del relations_to_remote_model[model_key]
def resolve_model_field_relations(
self,
model_key,
field_name,
field,
concretes=None,
):
remote_field = field.remote_field
if not remote_field:
return
if concretes is None:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
self.update_model_field_relation(
remote_field.model,
model_key,
field_name,
field,
concretes,
)
through = getattr(remote_field, "through", None)
if not through:
return
self.update_model_field_relation(
through, model_key, field_name, field, concretes
)
def resolve_model_relations(self, model_key, concretes=None):
if concretes is None:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
model_state = self.models[model_key]
for field_name, field in model_state.fields.items():
self.resolve_model_field_relations(model_key, field_name, field, concretes)
def resolve_fields_and_relations(self):
# Resolve fields.
for model_state in self.models.values():
for field_name, field in model_state.fields.items():
field.name = field_name
# Resolve relations.
# {remote_model_key: {model_key: {field_name: field}}}
self._relations = defaultdict(partial(defaultdict, dict))
concretes, proxies = self._get_concrete_models_mapping_and_proxy_models()
for model_key in concretes:
self.resolve_model_relations(model_key, concretes)
for model_key in proxies:
self._relations[model_key] = self._relations[concretes[model_key]]
def get_concrete_model_key(self, model):
(
concrete_models_mapping,
_,
) = self._get_concrete_models_mapping_and_proxy_models()
model_key = make_model_tuple(model)
return concrete_models_mapping[model_key]
def _get_concrete_models_mapping_and_proxy_models(self):
concrete_models_mapping = {}
proxy_models = {}
# Split models to proxy and concrete models.
for model_key, model_state in self.models.items():
if model_state.options.get("proxy"):
proxy_models[model_key] = model_state
# Find a concrete model for the proxy.
concrete_models_mapping[
model_key
] = self._find_concrete_model_from_proxy(
proxy_models,
model_state,
)
else:
concrete_models_mapping[model_key] = model_key
return concrete_models_mapping, proxy_models
def _find_concrete_model_from_proxy(self, proxy_models, model_state):
for base in model_state.bases:
if not (isinstance(base, str) or issubclass(base, models.Model)):
continue
base_key = make_model_tuple(base)
base_state = proxy_models.get(base_key)
if not base_state:
# Concrete model found, stop looking at bases.
return base_key
return self._find_concrete_model_from_proxy(proxy_models, base_state)
def clone(self):
"""Return an exact copy of this ProjectState."""
new_state = ProjectState(
models={k: v.clone() for k, v in self.models.items()},
real_apps=self.real_apps,
)
if "apps" in self.__dict__:
new_state.apps = self.apps.clone()
new_state.is_delayed = self.is_delayed
return new_state
def clear_delayed_apps_cache(self):
if self.is_delayed and "apps" in self.__dict__:
del self.__dict__["apps"]
@cached_property
def apps(self):
return StateApps(self.real_apps, self.models)
@classmethod
def from_apps(cls, apps):
"""Take an Apps and return a ProjectState matching it."""
app_models = {}
for model in apps.get_models(include_swapped=True):
model_state = ModelState.from_model(model)
app_models[(model_state.app_label, model_state.name_lower)] = model_state
return cls(app_models)
def __eq__(self, other):
return self.models == other.models and self.real_apps == other.real_apps
class AppConfigStub(AppConfig):
"""Stub of an AppConfig. Only provides a label and a dict of models."""
def __init__(self, label):
self.apps = None
self.models = {}
# App-label and app-name are not the same thing, so technically passing
# in the label here is wrong. In practice, migrations don't care about
# the app name, but we need something unique, and the label works fine.
self.label = label
self.name = label
def import_models(self):
self.models = self.apps.all_models[self.label]
class StateApps(Apps):
"""
Subclass of the global Apps registry class to better handle dynamic model
additions and removals.
"""
def __init__(self, real_apps, models, ignore_swappable=False):
# Any apps in self.real_apps should have all their models included
# in the render. We don't use the original model instances as there
# are some variables that refer to the Apps object.
# FKs/M2Ms from real apps are also not included as they just
# mess things up with partial states (due to lack of dependencies)
self.real_models = []
for app_label in real_apps:
app = global_apps.get_app_config(app_label)
for model in app.get_models():
self.real_models.append(ModelState.from_model(model, exclude_rels=True))
# Populate the app registry with a stub for each application.
app_labels = {model_state.app_label for model_state in models.values()}
app_configs = [
AppConfigStub(label) for label in sorted([*real_apps, *app_labels])
]
super().__init__(app_configs)
# These locks get in the way of copying as implemented in clone(),
# which is called whenever Django duplicates a StateApps before
# updating it.
self._lock = None
self.ready_event = None
self.render_multiple([*models.values(), *self.real_models])
# There shouldn't be any operations pending at this point.
from django.core.checks.model_checks import _check_lazy_references
ignore = (
{make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
)
errors = _check_lazy_references(self, ignore=ignore)
if errors:
raise ValueError("\n".join(error.msg for error in errors))
@contextmanager
def bulk_update(self):
# Avoid clearing each model's cache for each change. Instead, clear
# all caches when we're finished updating the model instances.
ready = self.ready
self.ready = False
try:
yield
finally:
self.ready = ready
self.clear_cache()
def render_multiple(self, model_states):
# We keep trying to render the models in a loop, ignoring invalid
# base errors, until the size of the unrendered models doesn't
# decrease by at least one, meaning there's a base dependency loop/
# missing base.
if not model_states:
return
# Prevent that all model caches are expired for each render.
with self.bulk_update():
unrendered_models = model_states
while unrendered_models:
new_unrendered_models = []
for model in unrendered_models:
try:
model.render(self)
except InvalidBasesError:
new_unrendered_models.append(model)
if len(new_unrendered_models) == len(unrendered_models):
raise InvalidBasesError(
"Cannot resolve bases for %r\nThis can happen if you are "
"inheriting models from an app with migrations (e.g. "
"contrib.auth)\n in an app with no migrations; see "
"https://docs.djangoproject.com/en/%s/topics/migrations/"
"#dependencies for more"
% (new_unrendered_models, get_docs_version())
)
unrendered_models = new_unrendered_models
def clone(self):
"""Return a clone of this registry."""
clone = StateApps([], {})
clone.all_models = copy.deepcopy(self.all_models)
clone.app_configs = copy.deepcopy(self.app_configs)
# Set the pointer to the correct app registry.
for app_config in clone.app_configs.values():
app_config.apps = clone
# No need to actually clone them, they'll never change
clone.real_models = self.real_models
return clone
def register_model(self, app_label, model):
self.all_models[app_label][model._meta.model_name] = model
if app_label not in self.app_configs:
self.app_configs[app_label] = AppConfigStub(app_label)
self.app_configs[app_label].apps = self
self.app_configs[app_label].models[model._meta.model_name] = model
self.do_pending_operations(model)
self.clear_cache()
def unregister_model(self, app_label, model_name):
try:
del self.all_models[app_label][model_name]
del self.app_configs[app_label].models[model_name]
except KeyError:
pass
class ModelState:
"""
Represent a Django Model. Don't use the actual Model class as it's not
designed to have its options changed - instead, mutate this one and then
render it into a Model as required.
Note that while you are allowed to mutate .fields, you are not allowed
to mutate the Field instances inside there themselves - you must instead
assign new ones, as these are not detached during a clone.
"""
def __init__(
self, app_label, name, fields, options=None, bases=None, managers=None
):
self.app_label = app_label
self.name = name
self.fields = dict(fields)
self.options = options or {}
self.options.setdefault("indexes", [])
self.options.setdefault("constraints", [])
self.bases = bases or (models.Model,)
self.managers = managers or []
for name, field in self.fields.items():
# Sanity-check that fields are NOT already bound to a model.
if hasattr(field, "model"):
raise ValueError(
'ModelState.fields cannot be bound to a model - "%s" is.' % name
)
# Sanity-check that relation fields are NOT referring to a model class.
if field.is_relation and hasattr(field.related_model, "_meta"):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.to" does. '
"Use a string reference instead." % name
)
if field.many_to_many and hasattr(field.remote_field.through, "_meta"):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.through" '
"does. Use a string reference instead." % name
)
# Sanity-check that indexes have their name set.
for index in self.options["indexes"]:
if not index.name:
raise ValueError(
"Indexes passed to ModelState require a name attribute. "
"%r doesn't have one." % index
)
@cached_property
def name_lower(self):
return self.name.lower()
def get_field(self, field_name):
if field_name == "_order":
field_name = self.options.get("order_with_respect_to", field_name)
return self.fields[field_name]
@classmethod
def from_model(cls, model, exclude_rels=False):
"""Given a model, return a ModelState representing it."""
# Deconstruct the fields
fields = []
for field in model._meta.local_fields:
if getattr(field, "remote_field", None) and exclude_rels:
continue
if isinstance(field, models.OrderWrt):
continue
name = field.name
try:
fields.append((name, field.clone()))
except TypeError as e:
raise TypeError(
"Couldn't reconstruct field %s on %s: %s"
% (
name,
model._meta.label,
e,
)
)
if not exclude_rels:
for field in model._meta.local_many_to_many:
name = field.name
try:
fields.append((name, field.clone()))
except TypeError as e:
raise TypeError(
"Couldn't reconstruct m2m field %s on %s: %s"
% (
name,
model._meta.object_name,
e,
)
)
# Extract the options
options = {}
for name in DEFAULT_NAMES:
# Ignore some special options
if name in ["apps", "app_label"]:
continue
elif name in model._meta.original_attrs:
if name == "unique_together":
ut = model._meta.original_attrs["unique_together"]
options[name] = set(normalize_together(ut))
elif name == "index_together":
it = model._meta.original_attrs["index_together"]
options[name] = set(normalize_together(it))
elif name == "indexes":
indexes = [idx.clone() for idx in model._meta.indexes]
for index in indexes:
if not index.name:
index.set_name_with_model(model)
options["indexes"] = indexes
elif name == "constraints":
options["constraints"] = [
con.clone() for con in model._meta.constraints
]
else:
options[name] = model._meta.original_attrs[name]
# If we're ignoring relationships, remove all field-listing model
# options (that option basically just means "make a stub model")
if exclude_rels:
for key in ["unique_together", "index_together", "order_with_respect_to"]:
if key in options:
del options[key]
# Private fields are ignored, so remove options that refer to them.
elif options.get("order_with_respect_to") in {
field.name for field in model._meta.private_fields
}:
del options["order_with_respect_to"]
def flatten_bases(model):
bases = []
for base in model.__bases__:
if hasattr(base, "_meta") and base._meta.abstract:
bases.extend(flatten_bases(base))
else:
bases.append(base)
return bases
# We can't rely on __mro__ directly because we only want to flatten
# abstract models and not the whole tree. However by recursing on
# __bases__ we may end up with duplicates and ordering issues, we
# therefore discard any duplicates and reorder the bases according
# to their index in the MRO.
flattened_bases = sorted(
set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)
)
# Make our record
bases = tuple(
(base._meta.label_lower if hasattr(base, "_meta") else base)
for base in flattened_bases
)
# Ensure at least one base inherits from models.Model
if not any(
(isinstance(base, str) or issubclass(base, models.Model)) for base in bases
):
bases = (models.Model,)
managers = []
manager_names = set()
default_manager_shim = None
for manager in model._meta.managers:
if manager.name in manager_names:
# Skip overridden managers.
continue
elif manager.use_in_migrations:
# Copy managers usable in migrations.
new_manager = copy.copy(manager)
new_manager._set_creation_counter()
elif manager is model._base_manager or manager is model._default_manager:
# Shim custom managers used as default and base managers.
new_manager = models.Manager()
new_manager.model = manager.model
new_manager.name = manager.name
if manager is model._default_manager:
default_manager_shim = new_manager
else:
continue
manager_names.add(manager.name)
managers.append((manager.name, new_manager))
# Ignore a shimmed default manager called objects if it's the only one.
if managers == [("objects", default_manager_shim)]:
managers = []
# Construct the new ModelState
return cls(
model._meta.app_label,
model._meta.object_name,
fields,
options,
bases,
managers,
)
def construct_managers(self):
"""Deep-clone the managers using deconstruction."""
# Sort all managers by their creation counter
sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
for mgr_name, manager in sorted_managers:
as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
if as_manager:
qs_class = import_string(qs_path)
yield mgr_name, qs_class.as_manager()
else:
manager_class = import_string(manager_path)
yield mgr_name, manager_class(*args, **kwargs)
def clone(self):
"""Return an exact copy of this ModelState."""
return self.__class__(
app_label=self.app_label,
name=self.name,
fields=dict(self.fields),
# Since options are shallow-copied here, operations such as
# AddIndex must replace their option (e.g 'indexes') rather
# than mutating it.
options=dict(self.options),
bases=self.bases,
managers=list(self.managers),
)
def render(self, apps):
"""Create a Model object from our current state into the given apps."""
# First, make a Meta object
meta_contents = {"app_label": self.app_label, "apps": apps, **self.options}
meta = type("Meta", (), meta_contents)
# Then, work out our bases
try:
bases = tuple(
(apps.get_model(base) if isinstance(base, str) else base)
for base in self.bases
)
except LookupError:
raise InvalidBasesError(
"Cannot resolve one or more bases from %r" % (self.bases,)
)
# Clone fields for the body, add other bits.
body = {name: field.clone() for name, field in self.fields.items()}
body["Meta"] = meta
body["__module__"] = "__fake__"
# Restore managers
body.update(self.construct_managers())
# Then, make a Model object (apps.register_model is called in __new__)
return type(self.name, bases, body)
def get_index_by_name(self, name):
for index in self.options["indexes"]:
if index.name == name:
return index
raise ValueError("No index named %s on model %s" % (name, self.name))
def get_constraint_by_name(self, name):
for constraint in self.options["constraints"]:
if constraint.name == name:
return constraint
raise ValueError("No constraint named %s on model %s" % (name, self.name))
def __repr__(self):
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
def __eq__(self, other):
return (
(self.app_label == other.app_label)
and (self.name == other.name)
and (len(self.fields) == len(other.fields))
and all(
k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
for (k1, f1), (k2, f2) in zip(
sorted(self.fields.items()),
sorted(other.fields.items()),
)
)
and (self.options == other.options)
and (self.bases == other.bases)
and (self.managers == other.managers)
)

View File

@@ -0,0 +1,129 @@
import datetime
import re
from collections import namedtuple
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
FieldReference = namedtuple("FieldReference", "to through")
COMPILED_REGEX_TYPE = type(re.compile(""))
class RegexObject:
def __init__(self, obj):
self.pattern = obj.pattern
self.flags = obj.flags
def __eq__(self, other):
if not isinstance(other, RegexObject):
return NotImplemented
return self.pattern == other.pattern and self.flags == other.flags
def get_migration_name_timestamp():
return datetime.datetime.now().strftime("%Y%m%d_%H%M")
def resolve_relation(model, app_label=None, model_name=None):
"""
Turn a model class or model reference string and return a model tuple.
app_label and model_name are used to resolve the scope of recursive and
unscoped model relationship.
"""
if isinstance(model, str):
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
if app_label is None or model_name is None:
raise TypeError(
"app_label and model_name must be provided to resolve "
"recursive relationships."
)
return app_label, model_name
if "." in model:
app_label, model_name = model.split(".", 1)
return app_label, model_name.lower()
if app_label is None:
raise TypeError(
"app_label must be provided to resolve unscoped model relationships."
)
return app_label, model.lower()
return model._meta.app_label, model._meta.model_name
def field_references(
model_tuple,
field,
reference_model_tuple,
reference_field_name=None,
reference_field=None,
):
"""
Return either False or a FieldReference if `field` references provided
context.
False positives can be returned if `reference_field_name` is provided
without `reference_field` because of the introspection limitation it
incurs. This should not be an issue when this function is used to determine
whether or not an optimization can take place.
"""
remote_field = field.remote_field
if not remote_field:
return False
references_to = None
references_through = None
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
to_fields = getattr(field, "to_fields", None)
if (
reference_field_name is None
or
# Unspecified to_field(s).
to_fields is None
or
# Reference to primary key.
(
None in to_fields
and (reference_field is None or reference_field.primary_key)
)
or
# Reference to field.
reference_field_name in to_fields
):
references_to = (remote_field, to_fields)
through = getattr(remote_field, "through", None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
through_fields = remote_field.through_fields
if (
reference_field_name is None
or
# Unspecified through_fields.
through_fields is None
or
# Reference to field.
reference_field_name in through_fields
):
references_through = (remote_field, through_fields)
if not (references_to or references_through):
return False
return FieldReference(references_to, references_through)
def get_references(state, model_tuple, field_tuple=()):
"""
Generator of (model_state, name, field, reference) referencing
provided context.
If field_tuple is provided only references to this particular field of
model_tuple will be generated.
"""
for state_model_tuple, model_state in state.models.items():
for name, field in model_state.fields.items():
reference = field_references(
state_model_tuple, field, model_tuple, *field_tuple
)
if reference:
yield model_state, name, field, reference
def field_is_referenced(state, model_tuple, field_tuple):
"""Return whether `field_tuple` is referenced by any state models."""
return next(get_references(state, model_tuple, field_tuple), None) is not None

View File

@@ -0,0 +1,311 @@
import os
import re
from importlib import import_module
from django import get_version
from django.apps import apps
# SettingsReference imported for backwards compatibility in Django 2.2.
from django.conf import SettingsReference # NOQA
from django.db import migrations
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.serializer import Serializer, serializer_factory
from django.utils.inspect import get_func_args
from django.utils.module_loading import module_dir
from django.utils.timezone import now
class OperationWriter:
def __init__(self, operation, indentation=2):
self.operation = operation
self.buff = []
self.indentation = indentation
def serialize(self):
def _write(_arg_name, _arg_value):
if _arg_name in self.operation.serialization_expand_args and isinstance(
_arg_value, (list, tuple, dict)
):
if isinstance(_arg_value, dict):
self.feed("%s={" % _arg_name)
self.indent()
for key, value in _arg_value.items():
key_string, key_imports = MigrationWriter.serialize(key)
arg_string, arg_imports = MigrationWriter.serialize(value)
args = arg_string.splitlines()
if len(args) > 1:
self.feed("%s: %s" % (key_string, args[0]))
for arg in args[1:-1]:
self.feed(arg)
self.feed("%s," % args[-1])
else:
self.feed("%s: %s," % (key_string, arg_string))
imports.update(key_imports)
imports.update(arg_imports)
self.unindent()
self.feed("},")
else:
self.feed("%s=[" % _arg_name)
self.indent()
for item in _arg_value:
arg_string, arg_imports = MigrationWriter.serialize(item)
args = arg_string.splitlines()
if len(args) > 1:
for arg in args[:-1]:
self.feed(arg)
self.feed("%s," % args[-1])
else:
self.feed("%s," % arg_string)
imports.update(arg_imports)
self.unindent()
self.feed("],")
else:
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
args = arg_string.splitlines()
if len(args) > 1:
self.feed("%s=%s" % (_arg_name, args[0]))
for arg in args[1:-1]:
self.feed(arg)
self.feed("%s," % args[-1])
else:
self.feed("%s=%s," % (_arg_name, arg_string))
imports.update(arg_imports)
imports = set()
name, args, kwargs = self.operation.deconstruct()
operation_args = get_func_args(self.operation.__init__)
# See if this operation is in django.db.migrations. If it is,
# We can just use the fact we already have that imported,
# otherwise, we need to add an import for the operation class.
if getattr(migrations, name, None) == self.operation.__class__:
self.feed("migrations.%s(" % name)
else:
imports.add("import %s" % (self.operation.__class__.__module__))
self.feed("%s.%s(" % (self.operation.__class__.__module__, name))
self.indent()
for i, arg in enumerate(args):
arg_value = arg
arg_name = operation_args[i]
_write(arg_name, arg_value)
i = len(args)
# Only iterate over remaining arguments
for arg_name in operation_args[i:]:
if arg_name in kwargs: # Don't sort to maintain signature order
arg_value = kwargs[arg_name]
_write(arg_name, arg_value)
self.unindent()
self.feed("),")
return self.render(), imports
def indent(self):
self.indentation += 1
def unindent(self):
self.indentation -= 1
def feed(self, line):
self.buff.append(" " * (self.indentation * 4) + line)
def render(self):
return "\n".join(self.buff)
class MigrationWriter:
"""
Take a Migration instance and is able to produce the contents
of the migration file from it.
"""
def __init__(self, migration, include_header=True):
self.migration = migration
self.include_header = include_header
self.needs_manual_porting = False
def as_string(self):
"""Return a string of the file contents."""
items = {
"replaces_str": "",
"initial_str": "",
}
imports = set()
# Deconstruct operations
operations = []
for operation in self.migration.operations:
operation_string, operation_imports = OperationWriter(operation).serialize()
imports.update(operation_imports)
operations.append(operation_string)
items["operations"] = "\n".join(operations) + "\n" if operations else ""
# Format dependencies and write out swappable dependencies right
dependencies = []
for dependency in self.migration.dependencies:
if dependency[0] == "__setting__":
dependencies.append(
" migrations.swappable_dependency(settings.%s),"
% dependency[1]
)
imports.add("from django.conf import settings")
else:
dependencies.append(" %s," % self.serialize(dependency)[0])
items["dependencies"] = "\n".join(dependencies) + "\n" if dependencies else ""
# Format imports nicely, swapping imports of functions from migration files
# for comments
migration_imports = set()
for line in list(imports):
if re.match(r"^import (.*)\.\d+[^\s]*$", line):
migration_imports.add(line.split("import")[1].strip())
imports.remove(line)
self.needs_manual_porting = True
# django.db.migrations is always used, but models import may not be.
# If models import exists, merge it with migrations import.
if "from django.db import models" in imports:
imports.discard("from django.db import models")
imports.add("from django.db import migrations, models")
else:
imports.add("from django.db import migrations")
# Sort imports by the package / module to be imported (the part after
# "from" in "from ... import ..." or after "import" in "import ...").
sorted_imports = sorted(imports, key=lambda i: i.split()[1])
items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
if migration_imports:
items["imports"] += (
"\n\n# Functions from the following migrations need manual "
"copying.\n# Move them and any dependencies into this file, "
"then update the\n# RunPython operations to refer to the local "
"versions:\n# %s"
) % "\n# ".join(sorted(migration_imports))
# If there's a replaces, make a string for it
if self.migration.replaces:
items["replaces_str"] = (
"\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
)
# Hinting that goes into comment
if self.include_header:
items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
"version": get_version(),
"timestamp": now().strftime("%Y-%m-%d %H:%M"),
}
else:
items["migration_header"] = ""
if self.migration.initial:
items["initial_str"] = "\n initial = True\n"
return MIGRATION_TEMPLATE % items
@property
def basedir(self):
migrations_package_name, _ = MigrationLoader.migrations_module(
self.migration.app_label
)
if migrations_package_name is None:
raise ValueError(
"Django can't create migrations for app '%s' because "
"migrations have been disabled via the MIGRATION_MODULES "
"setting." % self.migration.app_label
)
# See if we can import the migrations module directly
try:
migrations_module = import_module(migrations_package_name)
except ImportError:
pass
else:
try:
return module_dir(migrations_module)
except ValueError:
pass
# Alright, see if it's a direct submodule of the app
app_config = apps.get_app_config(self.migration.app_label)
(
maybe_app_name,
_,
migrations_package_basename,
) = migrations_package_name.rpartition(".")
if app_config.name == maybe_app_name:
return os.path.join(app_config.path, migrations_package_basename)
# In case of using MIGRATION_MODULES setting and the custom package
# doesn't exist, create one, starting from an existing package
existing_dirs, missing_dirs = migrations_package_name.split("."), []
while existing_dirs:
missing_dirs.insert(0, existing_dirs.pop(-1))
try:
base_module = import_module(".".join(existing_dirs))
except (ImportError, ValueError):
continue
else:
try:
base_dir = module_dir(base_module)
except ValueError:
continue
else:
break
else:
raise ValueError(
"Could not locate an appropriate location to create "
"migrations package %s. Make sure the toplevel "
"package exists and can be imported." % migrations_package_name
)
final_dir = os.path.join(base_dir, *missing_dirs)
os.makedirs(final_dir, exist_ok=True)
for missing_dir in missing_dirs:
base_dir = os.path.join(base_dir, missing_dir)
with open(os.path.join(base_dir, "__init__.py"), "w"):
pass
return final_dir
@property
def filename(self):
return "%s.py" % self.migration.name
@property
def path(self):
return os.path.join(self.basedir, self.filename)
@classmethod
def serialize(cls, value):
return serializer_factory(value).serialize()
@classmethod
def register_serializer(cls, type_, serializer):
Serializer.register(type_, serializer)
@classmethod
def unregister_serializer(cls, type_):
Serializer.unregister(type_)
MIGRATION_HEADER_TEMPLATE = """\
# Generated by Django %(version)s on %(timestamp)s
"""
MIGRATION_TEMPLATE = """\
%(migration_header)s%(imports)s
class Migration(migrations.Migration):
%(replaces_str)s%(initial_str)s
dependencies = [
%(dependencies)s\
]
operations = [
%(operations)s\
]
"""