+ {% with messages = get_flashed_messages() %}
+ {% if messages %}
+ {% for message in messages %}
+
{{ message }}
+ {% endfor %}
+ {% endif %}
+ {% endwith %}
+ {% if not current_user.is_authenticated %}
+
+ {% else %}
+
Hi, {{ current_user.login }}. Your last login: {{ current_user.last_login.strftime('%Y-%m-%d') }}
+
Logout
+
List of users
+
+ {% for user in users %}
+ -
+ {% if user == current_user %}
+ {{ user.login }}
+ {% else %}
+ {{ user.login }}
+ {% endif %}
+
+ {% endfor %}
+
+ {% endif %}
+
+
+
\ No newline at end of file
diff --git a/pony/flask/example/templates/login.html b/pony/flask/example/templates/login.html
new file mode 100644
index 000000000..562525904
--- /dev/null
+++ b/pony/flask/example/templates/login.html
@@ -0,0 +1,30 @@
+
+
+
+ {% with messages = get_flashed_messages() %}
+ {% if messages %}
+ {% for message in messages %}
+
{{ message }}
+ {% endfor %}
+ {% endif %}
+ {% endwith %}
+
Please login
+
+
+ {% if error %}
+
Error: {{ error }}
+ {% endif %}
+
+
+
\ No newline at end of file
diff --git a/pony/flask/example/templates/reg.html b/pony/flask/example/templates/reg.html
new file mode 100644
index 000000000..ae9a27d91
--- /dev/null
+++ b/pony/flask/example/templates/reg.html
@@ -0,0 +1,30 @@
+
+
+
+ {% with messages = get_flashed_messages() %}
+ {% if messages %}
+ {% for message in messages %}
+
{{ message }}
+ {% endfor %}
+ {% endif %}
+ {% endwith %}
+
Register
+
+
+ {% if error %}
+
Error: {{ error }}
+ {% endif %}
+
+
+
\ No newline at end of file
diff --git a/pony/flask/example/views.py b/pony/flask/example/views.py
new file mode 100644
index 000000000..a8477a025
--- /dev/null
+++ b/pony/flask/example/views.py
@@ -0,0 +1,56 @@
+from .app import app
+from .models import db
+from flask import render_template, request, flash, redirect, abort
+from flask_login import current_user, logout_user, login_user, login_required
+from datetime import datetime
+from pony.orm import flush
+
+@app.route('/')
+def index():
+ users = db.User.select()
+ return render_template('index.html', user=current_user, users=users)
+
+@app.route('/login', methods=['GET', 'POST'])
+def login():
+ if request.method == 'POST':
+ username = request.form['username']
+ password = request.form['password']
+ possible_user = db.User.get(login=username)
+ if not possible_user:
+ flash('Wrong username')
+ return redirect('/login')
+ if possible_user.password == password:
+ possible_user.last_login = datetime.now()
+ login_user(possible_user)
+ return redirect('/')
+
+ flash('Wrong password')
+ return redirect('/login')
+ else:
+ return render_template('login.html')
+
+@app.route('/reg', methods=['GET', 'POST'])
+def reg():
+ if request.method == 'POST':
+ username = request.form['username']
+ password = request.form['password']
+ exist = db.User.get(login=username)
+ if exist:
+ flash('Username %s is already taken, choose another one' % username)
+ return redirect('/reg')
+
+ user = db.User(login=username, password=password)
+ user.last_login = datetime.now()
+ flush()
+ login_user(user)
+ flash('Successfully registered')
+ return redirect('/')
+ else:
+ return render_template('reg.html')
+
+@app.route('/logout')
+@login_required
+def logout():
+ logout_user()
+ flash('Logged out')
+ return redirect('/')
\ No newline at end of file
diff --git a/pony/options.py b/pony/options.py
index 8e26fcad6..6c31ab487 100644
--- a/pony/options.py
+++ b/pony/options.py
@@ -59,7 +59,6 @@
CONSOLE_ENCODING = None
# db options
-PREFETCHING = True
MAX_FETCH_COUNT = None
# used for select(...).show()
diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py
index f989d18c5..6c9068b60 100644
--- a/pony/orm/asttranslation.py
+++ b/pony/orm/asttranslation.py
@@ -1,39 +1,47 @@
from __future__ import absolute_import, print_function, division
+from pony.py23compat import basestring, iteritems
from functools import update_wrapper
from pony.thirdparty.compiler import ast
-from pony.utils import throw
+from pony.utils import HashableDict, throw, copy_ast
class TranslationError(Exception): pass
+pre_method_caches = {}
+post_method_caches = {}
+
class ASTTranslator(object):
def __init__(translator, tree):
translator.tree = tree
- translator.pre_methods = {}
- translator.post_methods = {}
+ translator_cls = translator.__class__
+ pre_method_caches.setdefault(translator_cls, {})
+ post_method_caches.setdefault(translator_cls, {})
def dispatch(translator, node):
- cls = node.__class__
+ translator_cls = translator.__class__
+ pre_methods = pre_method_caches[translator_cls]
+ post_methods = post_method_caches[translator_cls]
+ node_cls = node.__class__
- try: pre_method = translator.pre_methods[cls]
+ try: pre_method = pre_methods[node_cls]
except KeyError:
- pre_method = getattr(translator, 'pre' + cls.__name__, translator.default_pre)
- translator.pre_methods[cls] = pre_method
- stop = translator.call(pre_method, node)
+ pre_method = getattr(translator_cls, 'pre' + node_cls.__name__, translator_cls.default_pre)
+ pre_methods[node_cls] = pre_method
+ stop = translator.call(pre_method, node)
if stop: return
for child in node.getChildNodes():
translator.dispatch(child)
- try: post_method = translator.post_methods[cls]
+ try: post_method = post_methods[node_cls]
except KeyError:
- post_method = getattr(translator, 'post' + cls.__name__, translator.default_post)
- translator.post_methods[cls] = post_method
+ post_method = getattr(translator_cls, 'post' + node_cls.__name__, translator_cls.default_post)
+ post_methods[node_cls] = post_method
translator.call(post_method, node)
def call(translator, method, node):
- return method(node)
+ return method(translator, node)
def default_pre(translator, node):
pass
def default_post(translator, node):
@@ -53,15 +61,22 @@ def binop_src(op, node):
return op.join((node.left.src, node.right.src))
def ast2src(tree):
+ src = getattr(tree, 'src', None)
+ if src is not None:
+ return src
PythonTranslator(tree)
return tree.src
class PythonTranslator(ASTTranslator):
def __init__(translator, tree):
ASTTranslator.__init__(translator, tree)
+ translator.top_level_f_str = None
translator.dispatch(tree)
def call(translator, method, node):
- node.src = method(node)
+ node.src = method(translator, node)
+ def default_pre(translator, node):
+ if getattr(node, 'src', None) is not None:
+ return True # node.src is already calculated, stop dispatching
def default_post(translator, node):
throw(NotImplementedError, node)
def postGenExpr(translator, node):
@@ -76,6 +91,20 @@ def postGenExprFor(translator, node):
return src
def postGenExprIf(translator, node):
return 'if %s' % node.test.src
+ def postIfExp(translator, node):
+ return '%s if %s else %s' % (node.then.src, node.test.src, node.else_.src)
+ def postLambda(translator, node):
+ argnames = list(node.argnames)
+ kwargs_name = argnames.pop() if node.kwargs else None
+ varargs_name = argnames.pop() if node.varargs else None
+ def_argnames = argnames[-len(node.defaults):] if node.defaults else []
+ nodef_argnames = argnames[:-len(node.defaults)] if node.defaults else argnames
+ args = ', '.join(nodef_argnames)
+ d_args = ', '.join('%s=%s' % (argname, default.src) for argname, default in zip(def_argnames, node.defaults))
+ v_arg = '*%s' % varargs_name if varargs_name else None
+ kw_arg = '**%s' % kwargs_name if kwargs_name else None
+ args = ', '.join(x for x in [args, d_args, v_arg, kw_arg] if x)
+ return 'lambda %s: %s' % (args, node.code.src)
@priority(14)
def postOr(translator, node):
return ' or '.join(expr.src for expr in node.nodes)
@@ -171,6 +200,8 @@ def postConst(translator, node):
s = str(value)
if float(s) == value: return s
return repr(value)
+ def postEllipsis(translator, node):
+ return '...'
def postList(translator, node):
node.priority = 1
return '[%s]' % ', '.join(item.src for item in node.nodes)
@@ -199,20 +230,48 @@ def postAssName(translator, node):
return node.name
def postKeyword(translator, node):
return '='.join((node.name, node.expr.src))
+ def preStr(self, node):
+ if self.top_level_f_str is None:
+ self.top_level_f_str = node
+ def postStr(self, node):
+ if self.top_level_f_str is node:
+ self.top_level_f_str = None
+ return "f%r" % ('{%s}' % node.value.src)
+ return '{%s}' % node.value.src
+ def preJoinedStr(self, node):
+ if self.top_level_f_str is None:
+ self.top_level_f_str = node
+ def postJoinedStr(self, node):
+ result = ''.join(
+ value.value if isinstance(value, ast.Const) else value.src
+ for value in node.values)
+ if self.top_level_f_str is node:
+ self.top_level_f_str = None
+ return "f%r" % result
+ return result
+ def preFormattedValue(self, node):
+ if self.top_level_f_str is None:
+ self.top_level_f_str = node
+ def postFormattedValue(self, node):
+ res = '{%s:%s}' % (node.value.src, node.fmt_spec.src)
+ if self.top_level_f_str is node:
+ self.top_level_f_str = None
+ return "f%r" % res
+ return res
nonexternalizable_types = (ast.Keyword, ast.Sliceobj, ast.List, ast.Tuple)
class PreTranslator(ASTTranslator):
def __init__(translator, tree, globals, locals,
- special_functions, const_functions, additional_internal_names=()):
+ special_functions, const_functions, outer_names=()):
ASTTranslator.__init__(translator, tree)
translator.globals = globals
translator.locals = locals
translator.special_functions = special_functions
translator.const_functions = const_functions
translator.contexts = []
- if additional_internal_names:
- translator.contexts.append(additional_internal_names)
+ if outer_names:
+ translator.contexts.append(outer_names)
translator.externals = externals = set()
translator.dispatch(tree)
for node in externals.copy():
@@ -224,13 +283,13 @@ def __init__(translator, tree, globals, locals,
def dispatch(translator, node):
node.external = node.constant = None
ASTTranslator.dispatch(translator, node)
- childs = node.getChildNodes()
- if node.external is None and childs and all(
- getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in childs):
+ children = node.getChildNodes()
+ if node.external is None and children and all(
+ getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in children):
node.external = True
if node.external and not node.constant:
externals = translator.externals
- externals.difference_update(childs)
+ externals.difference_update(children)
externals.add(node)
def preGenExprInner(translator, node):
translator.contexts.append(set())
@@ -260,6 +319,10 @@ def postName(translator, node):
node.external = True
def postConst(translator, node):
node.external = node.constant = True
+ def postDict(translator, node):
+ node.external = True
+ def postList(translator, node):
+ node.external = True
def postKeyword(translator, node):
node.constant = node.expr.constant
def postCallFunc(translator, node):
@@ -274,32 +337,37 @@ def postCallFunc(translator, node):
expr = '.'.join(reversed(attrs))
x = eval(expr, translator.globals, translator.locals)
try: hash(x)
- except TypeError: x = None
- if x in translator.special_functions:
- if x.__name__ == 'raw_sql': node.raw_sql = True
- else: node.external = False
- elif x in translator.const_functions:
- for arg in node.args:
- if not arg.constant: return
- if node.star_args is not None and not node.star_args.constant: return
- if node.dstar_args is not None and not node.dstar_args.constant: return
- node.constant = True
+ except TypeError: pass
+ else:
+ if x in translator.special_functions:
+ if x.__name__ == 'raw_sql': node.raw_sql = True
+ elif x is getattr:
+ attr_node = node.args[1]
+ attr_node.parent_node = node
+ else: node.external = False
+ elif x in translator.const_functions:
+ for arg in node.args:
+ if not arg.constant: return
+ if node.star_args is not None and not node.star_args.constant: return
+ if node.dstar_args is not None and not node.dstar_args.constant: return
+ node.constant = True
extractors_cache = {}
-def create_extractors(code_key, tree, filter_num, globals, locals,
- special_functions, const_functions, additional_internal_names=()):
- cache_key = code_key, filter_num
- result = extractors_cache.get(cache_key)
- if result is None:
- pretranslator = PreTranslator(
- tree, globals, locals, special_functions, const_functions, additional_internal_names)
+def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, outer_names=()):
+ result = extractors_cache.get(code_key)
+ if not result:
+ pretranslator = PreTranslator(tree, globals, locals, special_functions, const_functions, outer_names)
extractors = {}
for node in pretranslator.externals:
src = node.src = ast2src(node)
- if src == '.0': code = None
- else: code = compile(src, src, 'eval')
- extractors[filter_num, src] = code
- varnames = list(sorted(extractors))
- result = extractors_cache[cache_key] = extractors, varnames, tree
+ if src == '.0':
+ def extractor(globals, locals):
+ return locals['.0']
+ else:
+ code = compile(src, src, 'eval')
+ def extractor(globals, locals, code=code):
+ return eval(code, globals, locals)
+ extractors[src] = extractor
+ result = extractors_cache[code_key] = tree, extractors
return result
diff --git a/pony/orm/core.py b/pony/orm/core.py
index 326563b6e..fc54cb498 100644
--- a/pony/orm/core.py
+++ b/pony/orm/core.py
@@ -1,8 +1,8 @@
from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \
- basestring, unicode, buffer, int_types, builtins, pickle, with_metaclass
+ basestring, unicode, buffer, int_types, builtins, with_metaclass
-import json, re, sys, types, datetime, logging, itertools
+import json, re, sys, types, datetime, logging, itertools, warnings, inspect
from operator import attrgetter, itemgetter
from itertools import chain, starmap, repeat
from time import time
@@ -13,75 +13,97 @@
from collections import defaultdict
from hashlib import md5
from inspect import isgeneratorfunction
+from functools import wraps
from pony.thirdparty.compiler import ast, parse
import pony
from pony import options
from pony.orm.decompiling import decompile
-from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of
+from pony.orm.ormtypes import (
+ LongStr, LongUnicode, numeric_types, raw_sql, RawSQL, normalize, Json, TrackedValue, QueryType,
+ Array, IntArray, StrArray, FloatArray
+ )
from pony.orm.asttranslation import ast2src, create_extractors, TranslationError
from pony.orm.dbapiprovider import (
DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError,
OperationalError, IntegrityError, InternalError, ProgrammingError, NotSupportedError
)
from pony import utils
-from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \
- deprecated, import_module, parse_expr, is_ident, tostring, strjoin, concat
-
-__all__ = '''
- pony
+from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \
+ get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \
+ between, concat, coalesce, HashableDict, deref_proxy, deduplicate
- DBException RowNotFound MultipleRowsFound TooManyRowsFound
+__all__ = [
+ 'pony',
- Warning Error InterfaceError DatabaseError DataError OperationalError
- IntegrityError InternalError ProgrammingError NotSupportedError
+ 'DBException', 'RowNotFound', 'MultipleRowsFound', 'TooManyRowsFound',
- OrmError ERDiagramError DBSchemaError MappingError
- TableDoesNotExist TableIsNotEmpty ConstraintError CacheIndexError PermissionError
- ObjectNotFound MultipleObjectsFoundError TooManyObjectsFoundError OperationWithDeletedObjectError
- TransactionError ConnectionClosedError TransactionIntegrityError IsolationError CommitException RollbackException
- UnrepeatableReadError OptimisticCheckError UnresolvableCyclicDependency UnexpectedError DatabaseSessionIsOver
+ 'Warning', 'Error', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError',
+ 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError',
- TranslationError ExprEvalError
+ 'OrmError', 'ERDiagramError', 'DBSchemaError', 'MappingError', 'BindingError',
+ 'TableDoesNotExist', 'TableIsNotEmpty', 'ConstraintError', 'CacheIndexError',
+ 'ObjectNotFound', 'MultipleObjectsFoundError', 'TooManyObjectsFoundError', 'OperationWithDeletedObjectError',
+ 'TransactionError', 'ConnectionClosedError', 'TransactionIntegrityError', 'IsolationError',
+ 'CommitException', 'RollbackException', 'UnrepeatableReadError', 'OptimisticCheckError',
+ 'UnresolvableCyclicDependency', 'UnexpectedError', 'DatabaseSessionIsOver',
+ 'PonyRuntimeWarning', 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue',
+ 'TranslationError', 'ExprEvalError', 'PermissionError',
- RowNotFound MultipleRowsFound TooManyRowsFound
+ 'Database', 'sql_debug', 'set_sql_debug', 'sql_debugging', 'show',
- Database sql_debug show
+ 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator',
+ 'composite_key', 'composite_index',
+ 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', 'make_proxy',
- PrimaryKey Required Optional Set Discriminator
- composite_key composite_index
- flush commit rollback db_session with_transaction
+ 'LongStr', 'LongUnicode', 'Json', 'IntArray', 'StrArray', 'FloatArray',
- LongStr LongUnicode
+ 'select', 'left_join', 'get', 'exists', 'delete',
- select left_join get exists delete
+ 'count', 'sum', 'min', 'max', 'avg', 'group_concat', 'distinct',
- count sum min max avg distinct
+ 'JOIN', 'desc', 'between', 'concat', 'coalesce', 'raw_sql',
- JOIN desc concat raw_sql
+ 'buffer', 'unicode',
- buffer unicode
+ 'get_current_user', 'set_current_user', 'perm', 'has_perm',
+ 'get_user_groups', 'get_user_roles', 'get_object_labels',
+ 'user_groups_getter', 'user_roles_getter', 'obj_labels_getter'
+]
- get_current_user set_current_user perm has_perm
- get_user_groups get_user_roles get_object_labels
- user_groups_getter user_roles_getter obj_labels_getter
- '''.split()
-
-debug = False
suppress_debug_change = False
def sql_debug(value):
- global debug
- if not suppress_debug_change: debug = value
+ # todo: make sql_debug deprecated
+ if not suppress_debug_change:
+ local.debug = value
+
+
+def set_sql_debug(debug=True, show_values=None):
+ if not suppress_debug_change:
+ local.debug = debug
+ local.show_values = show_values
+
orm_logger = logging.getLogger('pony.orm')
sql_logger = logging.getLogger('pony.orm.sql')
orm_log_level = logging.INFO
+def has_handlers(logger):
+ if not PY2:
+ return logger.hasHandlers()
+ while logger:
+ if logger.handlers:
+ return True
+ elif not logger.propagate:
+ return False
+ logger = logger.parent
+ return False
+
def log_orm(msg):
- if logging.root.handlers:
+ if has_handlers(orm_logger):
orm_logger.log(orm_log_level, msg)
else:
print(msg)
@@ -89,15 +111,18 @@ def log_orm(msg):
def log_sql(sql, arguments=None):
if type(arguments) is list:
sql = 'EXECUTEMANY (%d)\n%s' % (len(arguments), sql)
- if logging.root.handlers:
- sql_logger.log(orm_log_level, sql) # arguments can hold sensitive information
+ if has_handlers(sql_logger):
+ if local.show_values and arguments:
+ sql = '%s\n%s' % (sql, format_arguments(arguments))
+ sql_logger.log(orm_log_level, sql)
else:
- print(sql)
- if not arguments: pass
- elif type(arguments) is list:
- for args in arguments: print(args2str(args))
- else: print(args2str(arguments))
- print()
+ if (local.show_values is None or local.show_values) and arguments:
+ sql = '%s\n%s' % (sql, format_arguments(arguments))
+ print(sql, end='\n\n')
+
+def format_arguments(arguments):
+ if type(arguments) is not list: return args2str(arguments)
+ return '\n'.join(args2str(args) for args in arguments)
def args2str(args):
if isinstance(args, (tuple, list)):
@@ -113,6 +138,7 @@ class OrmError(Exception): pass
class ERDiagramError(OrmError): pass
class DBSchemaError(OrmError): pass
class MappingError(OrmError): pass
+class BindingError(OrmError): pass
class TableDoesNotExist(OrmError): pass
class TableIsNotEmpty(OrmError): pass
@@ -180,13 +206,30 @@ def __init__(exc, msg, original_exc):
class ExprEvalError(TranslationError):
def __init__(exc, src, cause):
assert isinstance(cause, Exception)
- msg = '%s raises %s: %s' % (src, type(cause).__name__, str(cause))
+ msg = '`%s` raises %s: %s' % (src, type(cause).__name__, str(cause))
TranslationError.__init__(exc, msg)
exc.cause = cause
-class OptimizationFailed(Exception):
+class PonyInternalException(Exception):
+ pass
+
+class OptimizationFailed(PonyInternalException):
pass # Internal exception, cannot be encountered in user code
+class UseAnotherTranslator(PonyInternalException):
+ def __init__(self, translator):
+        Exception.__init__(self, 'This exception should be caught internally by PonyORM')
+ self.translator = translator
+
+class PonyRuntimeWarning(RuntimeWarning):
+ pass
+
+class DatabaseContainsIncorrectValue(PonyRuntimeWarning):
+ pass
+
+class DatabaseContainsIncorrectEmptyValue(DatabaseContainsIncorrectValue):
+ pass
+
def adapt_sql(sql, paramstyle):
result = adapted_sql_cache.get((sql, paramstyle))
if result is not None: return result
@@ -194,6 +237,7 @@ def adapt_sql(sql, paramstyle):
result = []
args = []
kwargs = {}
+ original_sql = sql
if paramstyle in ('format', 'pyformat'): sql = sql.replace('%', '%%')
while True:
try: i = sql.index('$', pos)
@@ -229,31 +273,74 @@ def adapt_sql(sql, paramstyle):
kwargs[key] = expr
result.append('%%(%s)s' % key)
else: throw(NotImplementedError)
- adapted_sql = ''.join(result)
- if args:
- source = '(%s,)' % ', '.join(args)
- code = compile(source, '>', 'eval')
- elif kwargs:
- source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items())
+ if args or kwargs:
+ adapted_sql = ''.join(result)
+ if args: source = '(%s,)' % ', '.join(args)
+ else: source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items())
code = compile(source, '>', 'eval')
else:
+ adapted_sql = original_sql.replace('$$', '$')
code = compile('None', '>', 'eval')
- if paramstyle in ('format', 'pyformat'): sql = sql.replace('%%', '%')
result = adapted_sql, code
adapted_sql_cache[(sql, paramstyle)] = result
return result
-num_counter = itertools.count()
+
+class PrefetchContext(object):
+ def __init__(self, database=None):
+ self.database = database
+ self.attrs_to_prefetch_dict = defaultdict(set)
+ self.entities_to_prefetch = set()
+ self.relations_to_prefetch_cache = {}
+ def copy(self):
+ result = PrefetchContext(self.database)
+ result.attrs_to_prefetch_dict = self.attrs_to_prefetch_dict.copy()
+ result.entities_to_prefetch = self.entities_to_prefetch.copy()
+ return result
+ def __enter__(self):
+ assert local.prefetch_context is None
+ local.prefetch_context = self
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ assert local.prefetch_context is self
+ local.prefetch_context = None
+ def get_frozen_attrs_to_prefetch(self, entity):
+ attrs_to_prefetch = self.attrs_to_prefetch_dict.get(entity, ())
+ if type(attrs_to_prefetch) is set:
+ attrs_to_prefetch = frozenset(attrs_to_prefetch)
+ self.attrs_to_prefetch_dict[entity] = attrs_to_prefetch
+ return attrs_to_prefetch
+ def get_relations_to_prefetch(self, entity):
+ result = self.relations_to_prefetch_cache.get(entity)
+ if result is None:
+ attrs_to_prefetch = self.attrs_to_prefetch_dict[entity]
+ result = tuple(attr for attr in entity._attrs_
+ if attr.is_relation and (
+ attr in attrs_to_prefetch or
+ attr.py_type in self.entities_to_prefetch and not attr.is_collection))
+ self.relations_to_prefetch_cache[entity] = result
+ return result
+
class Local(localbase):
def __init__(local):
+ local.debug = False
+ local.show_values = None
+ local.debug_stack = []
local.db2cache = {}
local.db_context_counter = 0
local.db_session = None
+ local.prefetch_context = None
local.current_user = None
local.perms_context = None
local.user_groups_cache = {}
local.user_roles_cache = defaultdict(dict)
+ def push_debug_state(local, debug, show_values):
+ local.debug_stack.append((local.debug, local.show_values))
+ if not suppress_debug_change:
+ local.debug = debug
+ local.show_values = show_values
+ def pop_debug_state(local):
+ local.debug, local.show_values = local.debug_stack.pop()
local = Local()
@@ -276,14 +363,28 @@ def transact_reraise(exc_class, exceptions):
reraise(exc_class, new_exc, tb)
finally: del exceptions, exc, tb, new_exc
+def rollback_and_reraise(exc_info):
+ try:
+ rollback()
+ finally:
+ reraise(*exc_info)
+
@cut_traceback
def commit():
caches = _get_caches()
if not caches: return
+
+ try:
+ for cache in caches:
+ cache.flush()
+ except:
+ rollback_and_reraise(sys.exc_info())
+
primary_cache = caches[0]
other_caches = caches[1:]
exceptions = []
- try: primary_cache.commit()
+ try:
+ primary_cache.commit()
except:
exceptions.append(sys.exc_info())
for cache in other_caches:
@@ -315,10 +416,12 @@ def rollback():
select_re = re.compile(r'\s*select\b', re.IGNORECASE)
class DBSessionContextManager(object):
- __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', 'immediate', 'ddl', 'serializable', 'strict'
- def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False,
- retry_exceptions=(TransactionError,), allowed_exceptions=()):
- if retry is not 0:
+ __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', \
+ 'immediate', 'ddl', 'serializable', 'strict', 'optimistic', \
+ 'sql_debug', 'show_values'
+ def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, optimistic=True,
+ retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None):
+ if retry != 0:
if type(retry) is not int: throw(TypeError,
"'retry' parameter of db_session must be of integer type. Got: %s" % type(retry))
if retry < 0: throw(TypeError,
@@ -332,10 +435,13 @@ def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False
db_session.retry = retry
db_session.ddl = ddl
db_session.serializable = serializable
- db_session.immediate = immediate or ddl or serializable
+ db_session.immediate = immediate or ddl or serializable or not optimistic
db_session.strict = strict
+ db_session.optimistic = optimistic and not serializable
db_session.retry_exceptions = retry_exceptions
db_session.allowed_exceptions = allowed_exceptions
+ db_session.sql_debug = sql_debug
+ db_session.show_values = show_values
def __call__(db_session, *args, **kwargs):
if not args and not kwargs: return db_session
if len(args) > 1: throw(TypeError,
@@ -344,41 +450,50 @@ def __call__(db_session, *args, **kwargs):
if kwargs: throw(TypeError,
'Pass only keyword arguments to db_session or use db_session as decorator')
func = args[0]
- if not isgeneratorfunction(func):
- return db_session._wrap_function(func)
- return db_session._wrap_generator_function(func)
+ if isgeneratorfunction(func) or hasattr(inspect, 'iscoroutinefunction') and inspect.iscoroutinefunction(func):
+ return db_session._wrap_coroutine_or_generator_function(func)
+ return db_session._wrap_function(func)
def __enter__(db_session):
- if db_session.retry is not 0: throw(TypeError,
+ if db_session.retry != 0: throw(TypeError,
"@db_session can accept 'retry' parameter only when used as decorator and not as context manager")
- if db_session.ddl: throw(TypeError,
- "@db_session can accept 'ddl' parameter only when used as decorator and not as context manager")
db_session._enter()
def _enter(db_session):
if local.db_session is None:
assert not local.db_context_counter
local.db_session = db_session
+ elif db_session.ddl and not local.db_session.ddl: throw(TransactionError,
+ 'Cannot start ddl transaction inside non-ddl transaction')
elif db_session.serializable and not local.db_session.serializable: throw(TransactionError,
'Cannot start serializable transaction inside non-serializable transaction')
local.db_context_counter += 1
+ if db_session.sql_debug is not None:
+ local.push_debug_state(db_session.sql_debug, db_session.show_values)
def __exit__(db_session, exc_type=None, exc=None, tb=None):
local.db_context_counter -= 1
- if local.db_context_counter: return
- assert local.db_session is db_session
+ try:
+ if not local.db_context_counter:
+ assert local.db_session is db_session
+ db_session._commit_or_rollback(exc_type, exc, tb)
+ finally:
+ if db_session.sql_debug is not None:
+ local.pop_debug_state()
+ def _commit_or_rollback(db_session, exc_type, exc, tb):
try:
if exc_type is None: can_commit = True
elif not callable(db_session.allowed_exceptions):
can_commit = issubclass(exc_type, tuple(db_session.allowed_exceptions))
else:
- # exc can be None in Python 2.6 even if exc_type is not None
- try: can_commit = exc is not None and db_session.allowed_exceptions(exc)
- except:
- rollback()
- raise
+ assert exc is not None # exc can be None in Python 2.6 even if exc_type is not None
+ try: can_commit = db_session.allowed_exceptions(exc)
+ except: rollback_and_reraise(sys.exc_info())
if can_commit:
commit()
for cache in _get_caches(): cache.release()
assert not local.db2cache
- else: rollback()
+ else:
+ try: rollback()
+ except:
+ if exc_type is None: raise # if exc_type is not None it will be reraised outside of __exit__
finally:
del exc, tb
local.db_session = None
@@ -386,28 +501,54 @@ def __exit__(db_session, exc_type=None, exc=None, tb=None):
local.user_roles_cache.clear()
def _wrap_function(db_session, func):
def new_func(func, *args, **kwargs):
- if db_session.ddl and local.db_context_counter:
- if isinstance(func, types.FunctionType): func = func.__name__ + '()'
- throw(TransactionError, '%s cannot be called inside of db_session' % func)
+ if local.db_context_counter:
+ if db_session.ddl:
+ fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func
+ throw(TransactionError, '@db_session-decorated %s function with `ddl` option '
+ 'cannot be called inside of another db_session' % fname)
+ if db_session.retry:
+ fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func
+ message = '@db_session decorator with `retry=%d` option is ignored for %s function ' \
+ 'because it is called inside another db_session' % (db_session.retry, fname)
+ warnings.warn(message, PonyRuntimeWarning, stacklevel=3)
+ if db_session.sql_debug is None:
+ return func(*args, **kwargs)
+ local.push_debug_state(db_session.sql_debug, db_session.show_values)
+ try:
+ return func(*args, **kwargs)
+ finally:
+ local.pop_debug_state()
+
exc = tb = None
try:
for i in xrange(db_session.retry+1):
db_session._enter()
exc_type = exc = tb = None
- try: return func(*args, **kwargs)
+ try:
+ result = func(*args, **kwargs)
+ commit()
+ return result
except:
- exc_type, exc, tb = sys.exc_info() # exc can be None in Python 2.6
- retry_exceptions = db_session.retry_exceptions
- if not callable(retry_exceptions):
- do_retry = issubclass(exc_type, tuple(retry_exceptions))
+ exc_type, exc, tb = sys.exc_info()
+ if getattr(exc, 'should_retry', False):
+ do_retry = True
else:
- do_retry = exc is not None and retry_exceptions(exc)
- if not do_retry: raise
- finally: db_session.__exit__(exc_type, exc, tb)
+ retry_exceptions = db_session.retry_exceptions
+ if not callable(retry_exceptions):
+ do_retry = issubclass(exc_type, tuple(retry_exceptions))
+ else:
+ assert exc is not None # exc can be None in Python 2.6
+ do_retry = retry_exceptions(exc)
+ if not do_retry:
+ raise
+ rollback()
+ finally:
+ db_session.__exit__(exc_type, exc, tb)
reraise(exc_type, exc, tb)
- finally: del exc, tb
+ finally:
+ del exc, tb
return decorator(new_func, func)
- def _wrap_generator_function(db_session, gen_func):
+ def _wrap_coroutine_or_generator_function(db_session, gen_func):
for option in ('ddl', 'retry', 'serializable'):
if getattr(db_session, option, None): throw(TypeError,
"db_session with `%s` option cannot be applied to generator function" % option)
@@ -425,7 +566,8 @@ def interact(iterator, input=None, exc_info=None):
if throw_ is None: reraise(*exc_info)
return throw_(*exc_info)
- def new_gen_func(gen_func, *args, **kwargs):
+ @wraps(gen_func)
+ def new_gen_func(*args, **kwargs):
db2cache_copy = {}
def wrapped_interact(iterator, input=None, exc_info=None):
@@ -436,28 +578,99 @@ def wrapped_interact(iterator, input=None, exc_info=None):
local.db_session = db_session
local.db2cache.update(db2cache_copy)
db2cache_copy.clear()
+ if db_session.sql_debug is not None:
+ local.push_debug_state(db_session.sql_debug, db_session.show_values)
try:
try:
output = interact(iterator, input, exc_info)
except StopIteration as e:
+ commit()
for cache in _get_caches():
- if cache.modified or cache.in_transaction: throw(TransactionError,
- 'You need to manually commit() changes before exiting from the generator')
- raise
+ cache.release()
+ assert not local.db2cache
+ raise e
for cache in _get_caches():
if cache.modified or cache.in_transaction: throw(TransactionError,
- 'You need to manually commit() changes before yielding from the generator')
+ 'You need to manually commit() changes before suspending the generator')
except:
- rollback()
- raise
+ rollback_and_reraise(sys.exc_info())
else:
return output
finally:
+ if db_session.sql_debug is not None:
+ local.pop_debug_state()
db2cache_copy.update(local.db2cache)
local.db2cache.clear()
local.db_context_counter = 0
local.db_session = None
+ gen = gen_func(*args, **kwargs)
+ iterator = gen.__await__() if hasattr(gen, '__await__') else iter(gen)
+ try:
+ output = wrapped_interact(iterator)
+ while True:
+ try:
+ input = yield output
+ except:
+ output = wrapped_interact(iterator, exc_info=sys.exc_info())
+ else:
+ output = wrapped_interact(iterator, input)
+ except StopIteration:
+ assert not db2cache_copy and not local.db2cache
+ return
+
+ if hasattr(types, 'coroutine'):
+ new_gen_func = types.coroutine(new_gen_func)
+ return new_gen_func
+
+db_session = DBSessionContextManager()
+
+
+class SQLDebuggingContextManager(object):
+ def __init__(self, debug=True, show_values=None):
+ self.debug = debug
+ self.show_values = show_values
+ def __call__(self, *args, **kwargs):
+ if not kwargs and len(args) == 1 and callable(args[0]):
+ arg = args[0]
+ if not isgeneratorfunction(arg):
+ return self._wrap_function(arg)
+ return self._wrap_generator_function(arg)
+ return self.__class__(*args, **kwargs)
+ def __enter__(self):
+ local.push_debug_state(self.debug, self.show_values)
+ def __exit__(self, exc_type=None, exc=None, tb=None):
+ local.pop_debug_state()
+ def _wrap_function(self, func):
+ def new_func(func, *args, **kwargs):
+ self.__enter__()
+ try:
+ return func(*args, **kwargs)
+ finally:
+ self.__exit__()
+ return decorator(new_func, func)
+ def _wrap_generator_function(self, gen_func):
+ def interact(iterator, input=None, exc_info=None):
+ if exc_info is None:
+ return next(iterator) if input is None else iterator.send(input)
+
+ if exc_info[0] is GeneratorExit:
+ close = getattr(iterator, 'close', None)
+ if close is not None: close()
+ reraise(*exc_info)
+
+ throw_ = getattr(iterator, 'throw', None)
+ if throw_ is None: reraise(*exc_info)
+ return throw_(*exc_info)
+
+ def new_gen_func(gen_func, *args, **kwargs):
+ def wrapped_interact(iterator, input=None, exc_info=None):
+ self.__enter__()
+ try:
+ return interact(iterator, input, exc_info)
+ finally:
+ self.__exit__()
+
gen = gen_func(*args, **kwargs)
iterator = iter(gen)
output = wrapped_interact(iterator)
@@ -473,11 +686,12 @@ def wrapped_interact(iterator, input=None, exc_info=None):
return
return decorator(new_gen_func, gen_func)
-db_session = DBSessionContextManager()
+sql_debugging = SQLDebuggingContextManager()
-def throw_db_session_is_over(obj, attr):
- throw(DatabaseSessionIsOver, 'Cannot read value of %s.%s: the database session is over'
- % (safe_repr(obj), attr.name))
+
+def throw_db_session_is_over(action, obj, attr=None):
+ msg = 'Cannot %s %s%s: the database session is over'
+ throw(DatabaseSessionIsOver, msg % (action, safe_repr(obj), '.%s' % attr.name if attr else ''))
def with_transaction(*args, **kwargs):
deprecated(3, "@with_transaction decorator is deprecated, use @db_session decorator instead")
@@ -494,6 +708,31 @@ def db_decorator(func, *args, **kwargs):
if web: throw(web.Http404NotFound)
raise
+known_providers = ('sqlite', 'postgres', 'mysql', 'oracle')
+
+class OnConnectDecorator(object):
+
+    @staticmethod
+    def check_provider(provider):
+        if provider:
+            if not isinstance(provider, basestring):
+                throw(TypeError, "'provider' option should be type of 'string', got %r" % type(provider).__name__)
+            if provider not in known_providers:
+                throw(BindingError, 'Unknown provider %s' % provider)
+
+    def __init__(self, database, provider):
+        OnConnectDecorator.check_provider(provider)
+        self.provider = provider
+        self.database = database
+
+    def __call__(self, func=None, provider=None):
+        if isinstance(func, types.FunctionType):
+            self.database._on_connect_funcs.append((func, provider or self.provider))
+        if not provider and isinstance(func, basestring):
+            provider = func
+        OnConnectDecorator.check_provider(provider)
+        return OnConnectDecorator(self.database, provider)
+
class Database(object):
def __deepcopy__(self, memo):
return self # Database cannot be cloned by deepcopy()
@@ -516,27 +755,40 @@ def __init__(self, *args, **kwargs):
self._global_stats_lock = RLock()
self._dblocal = DbLocal()
- self.provider = None
+ self.on_connect = OnConnectDecorator(self, None)
+ self._on_connect_funcs = []
+ self.provider = self.provider_name = None
if args or kwargs: self._bind(*args, **kwargs)
+ def call_on_connect(database, con):
+ for func, provider in database._on_connect_funcs:
+ if not provider or provider == database.provider_name:
+ func(database, con)
+ con.commit()
@cut_traceback
def bind(self, *args, **kwargs):
self._bind(*args, **kwargs)
def _bind(self, *args, **kwargs):
# argument 'self' cannot be named 'database', because 'database' can be in kwargs
if self.provider is not None:
- throw(TypeError, 'Database object was already bound to %s provider' % self.provider.dialect)
- if not args:
- throw(TypeError, 'Database provider should be specified as a first positional argument')
- provider, args = args[0], args[1:]
+ throw(BindingError, 'Database object was already bound to %s provider' % self.provider.dialect)
+ if len(args) == 1 and not kwargs and hasattr(args[0], 'keys'):
+ args, kwargs = (), args[0]
+ provider = None
+ if args: provider, args = args[0], args[1:]
+ elif 'provider' not in kwargs: throw(TypeError, 'Database provider is not specified')
+ else: provider = kwargs.pop('provider')
if isinstance(provider, type) and issubclass(provider, DBAPIProvider):
provider_cls = provider
else:
- if not isinstance(provider, basestring): throw(TypeError)
+ if not isinstance(provider, basestring):
+ throw(TypeError, 'Provider name should be string. Got: %r' % type(provider).__name__)
if provider == 'pygresql': throw(TypeError,
'Pony no longer supports PyGreSQL module. Please use psycopg2 instead.')
+ self.provider_name = provider
provider_module = import_module('pony.orm.dbproviders.' + provider)
provider_cls = provider_module.provider_cls
- self.provider = provider = provider_cls(*args, **kwargs)
+ kwargs['pony_call_on_connect'] = self.call_on_connect
+ self.provider = provider_cls(*args, **kwargs)
@property
def last_sql(database):
return database._dblocal.last_sql
@@ -547,20 +799,31 @@ def _update_local_stat(database, sql, query_start_time):
dblocal = database._dblocal
dblocal.last_sql = sql
stats = dblocal.stats
+ query_end_time = time()
+ duration = query_end_time - query_start_time
+
stat = stats.get(sql)
- if stat is not None: stat.query_executed(query_start_time)
- else: stats[sql] = QueryStat(sql, query_start_time)
+ if stat is not None:
+ stat.query_executed(duration)
+ else:
+ stats[sql] = QueryStat(sql, duration)
+
+ total_stat = stats.get(None)
+ if total_stat is not None:
+ total_stat.query_executed(duration)
+ else:
+ stats[None] = QueryStat(None, duration)
def merge_local_stats(database):
setdefault = database._global_stats.setdefault
with database._global_stats_lock:
for sql, stat in iteritems(database._dblocal.stats):
global_stat = setdefault(sql, stat)
if global_stat is not stat: global_stat.merge(stat)
- database._dblocal.stats.clear()
+ database._dblocal.stats = {None: QueryStat(None)}
@property
def global_stats(database):
with database._global_stats_lock:
- return dict((sql, stat.copy()) for sql, stat in iteritems(database._global_stats))
+ return {sql: stat.copy() for sql, stat in iteritems(database._global_stats)}
@property
def global_stats_lock(database):
deprecated(3, "global_stats_lock is deprecated, just use global_stats property without any locking")
@@ -579,7 +842,7 @@ def get_connection(database):
def disconnect(database):
provider = database.provider
if provider is None: return
- if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_sesison')
+ if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_session')
cache = local.db2cache.get(database)
if cache is not None: cache.rollback()
provider.disconnect()
@@ -598,14 +861,17 @@ def flush(database):
@cut_traceback
def commit(database):
cache = local.db2cache.get(database)
- if cache is not None: cache.commit()
+ if cache is not None:
+ cache.flush_and_commit()
@cut_traceback
def rollback(database):
cache = local.db2cache.get(database)
- if cache is not None: cache.rollback()
+ if cache is not None:
+ try: cache.rollback()
+ except: transact_reraise(RollbackException, [sys.exc_info()])
@cut_traceback
def execute(database, sql, globals=None, locals=None):
- return database._exec_raw_sql(sql, globals, locals, frame_depth=3, start_transaction=True)
+ return database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1, start_transaction=True)
def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction=False):
provider = database.provider
if provider is None: throw(MappingError, 'Database object is not bound with a provider yet')
@@ -621,7 +887,7 @@ def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction
@cut_traceback
def select(database, sql, globals=None, locals=None, frame_depth=0):
if not select_re.match(sql): sql = 'select ' + sql
- cursor = database._exec_raw_sql(sql, globals, locals, frame_depth + 3)
+ cursor = database._exec_raw_sql(sql, globals, locals, frame_depth+cut_traceback_depth+1)
max_fetch_count = options.MAX_FETCH_COUNT
if max_fetch_count is not None:
result = cursor.fetchmany(max_fetch_count)
@@ -637,7 +903,7 @@ def select(database, sql, globals=None, locals=None, frame_depth=0):
return [ row_class(row) for row in result ]
@cut_traceback
def get(database, sql, globals=None, locals=None):
- rows = database.select(sql, globals, locals, frame_depth=3)
+ rows = database.select(sql, globals, locals, frame_depth=cut_traceback_depth+1)
if not rows: throw(RowNotFound)
if len(rows) > 1: throw(MultipleRowsFound)
row = rows[0]
@@ -645,7 +911,7 @@ def get(database, sql, globals=None, locals=None):
@cut_traceback
def exists(database, sql, globals=None, locals=None):
if not select_re.match(sql): sql = 'select ' + sql
- cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=3)
+ cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1)
result = cursor.fetchone()
return bool(result)
@cut_traceback
@@ -675,17 +941,18 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti
if start_transaction: cache.immediate = True
connection = cache.prepare_connection_for_query_execution()
cursor = connection.cursor()
- if debug: log_sql(sql, arguments)
+ if local.debug: log_sql(sql, arguments)
provider = database.provider
t = time()
try: new_id = provider.execute(cursor, sql, arguments, returning_id)
except Exception as e:
connection = cache.reconnect(e)
cursor = connection.cursor()
- if debug: log_sql(sql, arguments)
+ if local.debug: log_sql(sql, arguments)
t = time()
new_id = provider.execute(cursor, sql, arguments, returning_id)
- if cache.immediate: cache.in_transaction = True
+ if cache.immediate:
+ cache.in_transaction = True
database._update_local_stat(sql, t)
if not returning_id: return cursor
if PY2 and type(new_id) is long: new_id = int(new_id)
@@ -694,7 +961,7 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti
def generate_mapping(database, filename=None, check_tables=True, create_tables=False):
provider = database.provider
if provider is None: throw(MappingError, 'Database object is not bound with a provider yet')
- if database.schema: throw(MappingError, 'Mapping was already generated')
+ if database.schema: throw(BindingError, 'Mapping was already generated')
if filename is not None: throw(NotImplementedError)
schema = database.schema = provider.dbschema_cls(provider)
entities = list(sorted(database.entities.values(), key=attrgetter('_id_')))
@@ -702,6 +969,8 @@ def generate_mapping(database, filename=None, check_tables=True, create_tables=F
entity._resolve_attr_types_()
for entity in entities:
entity._link_reverse_attrs_()
+ for entity in entities:
+ entity._check_table_options_()
def get_columns(table, column_names):
column_dict = table.column_dict
@@ -713,7 +982,8 @@ def get_columns(table, column_names):
is_subclass = entity._root_ is not entity
if is_subclass:
- if table_name is not None: throw(NotImplementedError)
+ if table_name is not None: throw(NotImplementedError,
+ 'Cannot specify table name for entity %r which is subclass of %r' % (entity.__name__, entity._root_.__name__))
table_name = entity._root_._table_
entity._table_ = table_name
elif table_name is None:
@@ -722,14 +992,8 @@ def get_columns(table, column_names):
else: assert isinstance(table_name, (basestring, tuple))
table = schema.tables.get(table_name)
- if table is None: table = schema.add_table(table_name)
- elif table.entities:
- for e in table.entities:
- if e._root_ is not entity._root_:
- throw(MappingError, "Entities %s and %s cannot be mapped to table %s "
- "because they don't belong to the same hierarchy"
- % (e, entity, table_name))
- table.entities.add(entity)
+ if table is None: table = schema.add_table(table_name, entity)
+ else: table.add_entity(entity)
for attr in entity._new_attrs_:
if attr.is_collection:
@@ -760,12 +1024,15 @@ def get_columns(table, column_names):
if not attr.table:
seq_counter = itertools.count(2)
while m2m_table is not None:
- new_table_name = table_name + '_%d' % next(seq_counter)
+ if isinstance(table_name, basestring):
+ new_table_name = table_name + '_%d' % next(seq_counter)
+ else:
+ schema_name, base_name = provider.split_table_name(table_name)
+ new_table_name = schema_name, base_name + '_%d' % next(seq_counter)
m2m_table = schema.tables.get(new_table_name)
table_name = new_table_name
- elif m2m_table.entities or m2m_table.m2m:
- if isinstance(table_name, tuple): table_name = '.'.join(table_name)
- throw(MappingError, "Table name '%s' is already in use" % table_name)
+ elif m2m_table.entities or m2m_table.m2m: throw(MappingError,
+ "Table name %s is already in use" % provider.format_table_name(table_name))
else: throw(NotImplementedError)
attr.table = reverse.table = table_name
m2m_table = schema.add_table(table_name)
@@ -782,7 +1049,7 @@ def get_columns(table, column_names):
m2m_table.m2m.add(reverse)
else:
if attr.is_required: pass
- elif not attr.is_string:
+ elif not attr.type_has_empty_value:
if attr.nullable is False:
throw(TypeError, 'Optional attribute with non-string type %s must be nullable' % attr)
attr.nullable = True
@@ -844,19 +1111,32 @@ def get_columns(table, column_names):
m2m_table = schema.tables[attr.table]
parent_columns = get_columns(table, entity._pk_columns_)
child_columns = get_columns(m2m_table, reverse.columns)
- m2m_table.add_foreign_key(None, child_columns, table, parent_columns, attr.index)
+ on_delete = 'CASCADE'
+ m2m_table.add_foreign_key(reverse.fk_name, child_columns, table, parent_columns,
+ attr.index, on_delete)
if attr.symmetric:
- child_columns = get_columns(m2m_table, attr.reverse_columns)
- m2m_table.add_foreign_key(None, child_columns, table, parent_columns)
+ reverse_child_columns = get_columns(m2m_table, attr.reverse_columns)
+ m2m_table.add_foreign_key(attr.reverse_fk_name, reverse_child_columns, table, parent_columns,
+ attr.reverse_index, on_delete)
elif attr.reverse and attr.columns:
rentity = attr.reverse.entity
parent_table = schema.tables[rentity._table_]
parent_columns = get_columns(parent_table, rentity._pk_columns_)
child_columns = get_columns(table, attr.columns)
- table.add_foreign_key(None, child_columns, parent_table, parent_columns, attr.index)
+ if attr.reverse.cascade_delete:
+ on_delete = 'CASCADE'
+ elif isinstance(attr, Optional) and attr.nullable:
+ on_delete = 'SET NULL'
+ else:
+ on_delete = None
+ table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index,
+ on_delete, interleave=attr.interleave)
elif attr.index and attr.columns:
- columns = tuple(imap(table.column_dict.__getitem__, attr.columns))
- table.add_index(attr.index, columns, is_unique=attr.is_unique)
+ if isinstance(attr.py_type, Array) and provider.dialect != 'PostgreSQL':
+ pass # GIN indexes are supported only in PostgreSQL
+ else:
+ columns = tuple(imap(table.column_dict.__getitem__, attr.columns))
+ table.add_index(attr.index, columns, is_unique=attr.is_unique)
entity._initialize_bits_()
if create_tables: database.create_tables(check_tables)
@@ -864,7 +1144,6 @@ def get_columns(table, column_names):
@cut_traceback
@db_session(ddl=True)
def drop_table(database, table_name, if_exists=False, with_all_data=False):
- table_name = database._get_table_name(table_name)
database._drop_tables([ table_name ], if_exists, with_all_data, try_normalized=True)
def _get_table_name(database, table_name):
if isinstance(table_name, EntityMeta):
@@ -878,9 +1157,13 @@ def _get_table_name(database, table_name):
elif table_name is None:
if database.schema is None: throw(MappingError, 'No mapping was generated for the database')
else: throw(TypeError, 'Table name cannot be None')
- elif not isinstance(table_name, basestring):
- throw(TypeError, 'Invalid table name: %r' % table_name)
- table_name = table_name[:] # table_name = templating.plainstr(table_name)
+ elif isinstance(table_name, tuple):
+ for component in table_name:
+ if not isinstance(component, basestring):
+ throw(TypeError, 'Invalid table name component: {}'.format(component))
+ elif isinstance(table_name, basestring):
+ table_name = table_name[:] # table_name = templating.plainstr(table_name)
+ else: throw(TypeError, 'Invalid table name: {}'.format(table_name))
return table_name
@cut_traceback
@db_session(ddl=True)
@@ -897,19 +1180,24 @@ def _drop_tables(database, table_names, if_exists, with_all_data, try_normalized
if provider.table_exists(connection, table_name): existed_tables.append(table_name)
elif not if_exists:
if try_normalized:
- normalized_table_name = provider.normalize_name(table_name)
- if normalized_table_name != table_name \
- and provider.table_exists(connection, normalized_table_name):
- throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)'
- % (table_name, normalized_table_name))
- throw(TableDoesNotExist, 'Table %s does not exist' % table_name)
+ if isinstance(table_name, basestring):
+ normalized_table_name = provider.normalize_name(table_name)
+ else:
+ schema_name, base_name = provider.split_table_name(table_name)
+ normalized_table_name = schema_name, provider.normalize_name(base_name)
+ if normalized_table_name != table_name and provider.table_exists(connection, normalized_table_name):
+ throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)' % (
+ provider.format_table_name(table_name),
+ provider.format_table_name(normalized_table_name)))
+ throw(TableDoesNotExist, 'Table %s does not exist' % provider.format_table_name(table_name))
if not with_all_data:
for table_name in existed_tables:
if provider.table_has_data(connection, table_name): throw(TableIsNotEmpty,
'Cannot drop table %s because it is not empty. Specify option '
- 'with_all_data=True if you want to drop table with all data' % table_name)
+ 'with_all_data=True if you want to drop table with all data'
+ % provider.format_table_name(table_name))
for table_name in existed_tables:
- if debug: log_orm('DROPPING TABLE %s' % table_name)
+ if local.debug: log_orm('DROPPING TABLE %s' % provider.format_table_name(table_name))
provider.drop_table(connection, table_name)
@cut_traceback
@db_session(ddl=True)
@@ -971,7 +1259,7 @@ def _get_schema_dict(database):
return result
def _get_schema_json(database):
schema_json = json.dumps(database._get_schema_dict(), default=basic_converter, sort_keys=True)
- schema_hash = md5(schema_json).hexdigest()
+ schema_hash = md5(schema_json.encode('utf-8')).hexdigest()
return schema_json, schema_hash
@cut_traceback
def to_json(database, data, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None):
@@ -994,7 +1282,8 @@ def user_has_no_rights_to_see(obj, attr=None):
caches = set()
def obj_converter(obj):
if not isinstance(obj, Entity): return converter(obj)
- caches.add(obj._session_cache_)
+ cache = obj._session_cache_
+ if cache is not None: caches.add(cache)
if len(caches) > 1: throw(TransactionError,
'An attempt to serialize objects belonging to different transactions')
if not can_view(user, obj):
@@ -1160,7 +1449,7 @@ def deserialize(x):
if t is list: return list(imap(deserialize, x))
if t is dict:
if '_id_' not in x:
- return dict((key, deserialize(val)) for key, val in iteritems(x))
+ return {key: deserialize(val) for key, val in iteritems(x)}
obj = objmap.get(x['_id_'])
if obj is None:
entity_name = x['class']
@@ -1324,7 +1613,7 @@ def get_user_groups(user):
result = local.user_groups_cache.get(user)
if result is not None: return result
if user is None: return anybody_frozenset
- result = set(['anybody'])
+ result = {'anybody'}
for cls, func in usergroup_functions:
if cls is None or isinstance(user, cls):
groups = func(user)
@@ -1400,14 +1689,12 @@ def decorator(func):
class DbLocal(localbase):
def __init__(dblocal):
- dblocal.stats = {}
+ dblocal.stats = {None: QueryStat(None)}
dblocal.last_sql = None
class QueryStat(object):
- def __init__(stat, sql, query_start_time=None):
- if query_start_time is not None:
- query_end_time = time()
- duration = query_end_time - query_start_time
+ def __init__(stat, sql, duration=None):
+ if duration is not None:
stat.min_time = stat.max_time = stat.sum_time = duration
stat.db_count = 1
stat.cache_count = 0
@@ -1420,9 +1707,7 @@ def copy(stat):
result = object.__new__(QueryStat)
result.__dict__.update(stat.__dict__)
return result
- def query_executed(stat, query_start_time):
- query_end_time = time()
- duration = query_end_time - query_start_time
+ def query_executed(stat, duration):
if stat.db_count:
stat.min_time = builtins.min(stat.min_time, duration)
stat.max_time = builtins.max(stat.max_time, duration)
@@ -1447,6 +1732,8 @@ def avg_time(stat):
if not stat.db_count: return None
return stat.sum_time / stat.db_count
+num_counter = itertools.count()
+
class SessionCache(object):
def __init__(cache, database):
cache.is_alive = True
@@ -1463,6 +1750,7 @@ def __init__(cache, database):
cache.objects_to_save = []
cache.saved_objects = []
cache.query_results = {}
+ cache.dbvals_deduplication_cache = defaultdict(dict)
cache.modified = False
cache.db_session = db_session = local.db_session
cache.immediate = db_session is not None and db_session.immediate
@@ -1476,12 +1764,17 @@ def connect(cache):
assert cache.connection is None
if cache.in_transaction: throw(ConnectionClosedError,
'Transaction cannot be continued because database connection failed')
- provider = cache.database.provider
- connection = provider.connect()
- try: provider.set_transaction_mode(connection, cache) # can set cache.in_transaction
+ database = cache.database
+ provider = database.provider
+ connection, is_new_connection = provider.connect()
+ if is_new_connection:
+ database.call_on_connect(connection)
+ try:
+ provider.set_transaction_mode(connection, cache) # can set cache.in_transaction
except:
provider.drop(connection, cache)
raise
+
cache.connection = connection
return connection
def reconnect(cache, exc):
@@ -1489,7 +1782,7 @@ def reconnect(cache, exc):
if exc is not None:
exc = getattr(exc, 'original_exc', exc)
if not provider.should_reconnect(exc): reraise(*sys.exc_info())
- if debug: log_orm('CONNECTION FAILED: %s' % exc)
+ if local.debug: log_orm('CONNECTION FAILED: %s' % exc)
connection = cache.connection
assert connection is not None
cache.connection = None
@@ -1503,7 +1796,7 @@ def prepare_connection_for_query_execution(cache):
# in the interactive mode, outside of the db_session
if cache.in_transaction or cache.modified:
local.db_session = None
- try: cache.commit()
+ try: cache.flush_and_commit()
finally: local.db_session = db_session
cache.db_session = db_session
cache.immediate = cache.immediate or db_session.immediate
@@ -1516,16 +1809,23 @@ def prepare_connection_for_query_execution(cache):
except Exception as e: connection = cache.reconnect(e)
if not cache.noflush_counter and cache.modified: cache.flush()
return connection
+ def flush_and_commit(cache):
+ try: cache.flush()
+ except:
+ cache.rollback()
+ raise
+ try: cache.commit()
+ except: transact_reraise(CommitException, [sys.exc_info()])
def commit(cache):
assert cache.is_alive
- database = cache.database
- provider = database.provider
try:
if cache.modified: cache.flush()
if cache.in_transaction:
assert cache.connection is not None
- provider.commit(cache.connection, cache)
+ cache.database.provider.commit(cache.connection, cache)
cache.for_update.clear()
+ cache.query_results.clear()
+ cache.max_id_cache.clear()
cache.immediate = True
except:
cache.rollback()
@@ -1544,21 +1844,30 @@ def close(cache, rollback=True):
connection = cache.connection
if connection is None: return
cache.connection = None
- if rollback:
- try: provider.rollback(connection, cache)
- except:
- provider.drop(connection, cache)
- raise
- provider.release(connection, cache)
- db_session = cache.db_session or local.db_session
- if db_session and db_session.strict:
- cache.clear()
- def clear(cache):
- for obj in cache.objects:
- obj._vals_ = obj._dbvals_ = obj._session_cache_ = None
- cache.objects = cache.indexes = cache.seeds = cache.for_update = cache.modified_collections \
- = cache.objects_to_save = cache.saved_objects = cache.query_results \
- = cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None
+
+ try:
+ if rollback:
+ try: provider.rollback(connection, cache)
+ except:
+ provider.drop(connection, cache)
+ raise
+ provider.release(connection, cache)
+ finally:
+ db_session = cache.db_session or local.db_session
+ if db_session and db_session.strict:
+ for obj in cache.objects:
+ obj._vals_ = obj._dbvals_ = obj._session_cache_ = None
+ cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None
+ else:
+ for obj in cache.objects:
+ obj._dbvals_ = obj._session_cache_ = None
+ for attr, setdata in iteritems(obj._vals_):
+ if attr.is_collection:
+ if not setdata.is_fully_loaded: obj._vals_[attr] = None
+
+ cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \
+ = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \
+ = cache.modified_collections = cache.collection_statistics = cache.dbvals_deduplication_cache = None
@contextmanager
def flush_disabled(cache):
cache.noflush_counter += 1
@@ -1568,34 +1877,39 @@ def flush(cache):
if cache.noflush_counter: return
assert cache.is_alive
assert not cache.saved_objects
- if not cache.immediate: cache.immediate = True
- for i in xrange(50):
- if not cache.modified: return
-
- with cache.flush_disabled():
- for obj in cache.objects_to_save: # can grow during iteration
- if obj is not None: obj._before_save_()
-
- cache.query_results.clear()
- modified_m2m = cache._calc_modified_m2m()
- for attr, (added, removed) in iteritems(modified_m2m):
- if not removed: continue
- attr.remove_m2m(removed)
- for obj in cache.objects_to_save:
- if obj is not None: obj._save_()
- for attr, (added, removed) in iteritems(modified_m2m):
- if not added: continue
- attr.add_m2m(added)
-
- cache.max_id_cache.clear()
- cache.modified_collections.clear()
- cache.objects_to_save[:] = ()
- cache.modified = False
-
- cache.call_after_save_hooks()
- else:
- if cache.modified: throw(TransactionError,
- 'Recursion depth limit reached in obj._after_save_() call')
+ prev_immediate = cache.immediate
+ cache.immediate = True
+ try:
+ for i in xrange(50):
+ if not cache.modified: return
+
+ with cache.flush_disabled():
+ for obj in cache.objects_to_save: # can grow during iteration
+ if obj is not None: obj._before_save_()
+
+ cache.query_results.clear()
+ modified_m2m = cache._calc_modified_m2m()
+ for attr, (added, removed) in iteritems(modified_m2m):
+ if not removed: continue
+ attr.remove_m2m(removed)
+ for obj in cache.objects_to_save:
+ if obj is not None: obj._save_()
+ for attr, (added, removed) in iteritems(modified_m2m):
+ if not added: continue
+ attr.add_m2m(added)
+
+ cache.max_id_cache.clear()
+ cache.modified_collections.clear()
+ cache.objects_to_save[:] = ()
+ cache.modified = False
+
+ cache.call_after_save_hooks()
+ else:
+ if cache.modified: throw(TransactionError,
+ 'Recursion depth limit reached in obj._after_save_() call')
+ finally:
+ if not cache.in_transaction:
+ cache.immediate = prev_immediate
def call_after_save_hooks(cache):
saved_objects = cache.saved_objects
cache.saved_objects = []
@@ -1627,7 +1941,7 @@ def _calc_modified_m2m(cache):
cache.modified_collections.clear()
return modified_m2m
def update_simple_index(cache, obj, attr, old_val, new_val, undo):
- assert old_val != new_val
+ if old_val == new_val: return
cache_index = cache.indexes[attr]
if new_val is not None:
obj2 = cache_index.setdefault(new_val, obj)
@@ -1636,7 +1950,7 @@ def update_simple_index(cache, obj, attr, old_val, new_val, undo):
if old_val is not None: del cache_index[old_val]
undo.append((cache_index, old_val, new_val))
def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval):
- assert old_dbval != new_dbval
+ if old_dbval == new_dbval: return
cache_index = cache.indexes[attr]
if new_dbval is not None:
obj2 = cache_index.setdefault(new_dbval, obj)
@@ -1649,6 +1963,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo):
if None in prev_vals: prev_vals = None
if None in new_vals: new_vals = None
if prev_vals is None and new_vals is None: return
+ if prev_vals == new_vals: return
cache_index = cache.indexes[attrs]
if new_vals is not None:
obj2 = cache_index.setdefault(new_vals, obj)
@@ -1659,6 +1974,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo):
if prev_vals is not None: del cache_index[prev_vals]
undo.append((cache_index, prev_vals, new_vals))
def db_update_composite_index(cache, obj, attrs, prev_vals, new_vals):
+ if prev_vals == new_vals: return
cache_index = cache.indexes[attrs]
if None not in new_vals:
obj2 = cache_index.setdefault(new_vals, obj)
@@ -1700,7 +2016,8 @@ class Attribute(object):
'id', 'pk_offset', 'pk_columns_offset', 'py_type', 'sql_type', 'entity', 'name', \
'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \
'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \
- 'cascade_delete', 'index', 'original_default', 'sql_default', 'py_check', 'hidden'
+ 'cascade_delete', 'index', 'reverse_index', 'original_default', 'sql_default', 'py_check', 'hidden', \
+ 'optimistic', 'fk_name', 'type_has_empty_value', 'interleave'
def __deepcopy__(attr, memo):
return attr # Attribute cannot be cloned by deepcopy()
@cut_traceback
@@ -1720,12 +2037,13 @@ def __init__(attr, py_type, *args, **kwargs):
if attr.is_pk: attr.pk_offset = 0
else: attr.pk_offset = None
attr.id = next(attr_id_counter)
- if not isinstance(py_type, (type, basestring, types.FunctionType)):
+ if not isinstance(py_type, (type, basestring, types.FunctionType, Array)):
if py_type is datetime: throw(TypeError,
'datetime is the module and cannot be used as attribute type. Use datetime.datetime instead')
throw(TypeError, 'Incorrect type of attribute: %r' % py_type)
attr.py_type = py_type
attr.is_string = type(py_type) is type and issubclass(py_type, basestring)
+ attr.type_has_empty_value = attr.is_string or hasattr(attr.py_type, 'default_empty_value')
attr.is_collection = isinstance(attr, Collection)
attr.is_relation = isinstance(attr.py_type, (EntityMeta, basestring, types.FunctionType))
attr.is_basic = not attr.is_collection and not attr.is_relation
@@ -1759,15 +2077,19 @@ def __init__(attr, py_type, *args, **kwargs):
if len(attr.columns) == 1: attr.column = attr.columns[0]
else: attr.columns = []
attr.index = kwargs.pop('index', None)
+ attr.reverse_index = kwargs.pop('reverse_index', None)
+ attr.fk_name = kwargs.pop('fk_name', None)
attr.col_paths = []
attr._columns_checked = False
attr.composite_keys = []
attr.lazy = kwargs.pop('lazy', getattr(py_type, 'lazy', False))
attr.lazy_sql_cache = None
attr.is_volatile = kwargs.pop('volatile', False)
+ attr.optimistic = kwargs.pop('optimistic', None)
attr.sql_default = kwargs.pop('sql_default', None)
attr.py_check = kwargs.pop('py_check', None)
attr.hidden = kwargs.pop('hidden', False)
+ attr.interleave = kwargs.pop('interleave', None)
attr.kwargs = kwargs
attr.converters = []
def _init_(attr, entity, name):
@@ -1796,8 +2118,8 @@ def _init_(attr, entity, name):
'Default value for required attribute %s cannot be empty string' % attr)
elif attr.default is None and not attr.nullable: throw(TypeError,
'Default value for non-nullable attribute %s cannot be set to None' % attr)
- elif attr.is_string and not attr.is_required and not attr.nullable:
- attr.default = ''
+ elif attr.type_has_empty_value and not attr.is_required and not attr.nullable:
+ attr.default = '' if attr.is_string else attr.py_type.default_empty_value()
else:
attr.default = None
@@ -1820,6 +2142,12 @@ def _init_(attr, entity, name):
elif attr.is_unique: throw(TypeError, 'Unique attribute %s cannot be of type float' % attr)
if attr.is_volatile and (attr.is_pk or attr.is_collection): throw(TypeError,
'%s attribute %s cannot be volatile' % (attr.__class__.__name__, attr))
+
+ if attr.interleave is not None:
+ if attr.is_collection: throw(TypeError,
+ '`interleave` option cannot be specified for %s attribute %r' % (attr.__class__.__name__, attr))
+ if attr.interleave not in (True, False): throw(TypeError,
+ '`interleave` option value should be True, False or None. Got: %r' % attr.interleave)
def linked(attr):
reverse = attr.reverse
if attr.cascade_delete is None:
@@ -1831,13 +2159,25 @@ def linked(attr):
if reverse.is_collection: throw(TypeError,
"'cascade_delete' option cannot be set for attribute %s, "
"because reverse attribute %s is collection" % (attr, reverse))
+ if attr.is_collection and not reverse.is_collection:
+ if attr.fk_name is not None:
+ throw(TypeError, 'You should specify fk_name in %s instead of %s' % (reverse, attr))
+ for option in attr.kwargs:
+ throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option))
@cut_traceback
def __repr__(attr):
owner_name = attr.entity.__name__ if attr.entity else '?'
return '%s.%s' % (owner_name, attr.name or '?')
def __lt__(attr, other):
return attr.id < other.id
+ def _get_entity(attr, obj, entity):
+ if entity is not None:
+ return entity
+ if obj is not None:
+ return obj.__class__
+ return attr.entity
def validate(attr, val, obj=None, entity=None, from_db=False):
+ val = deref_proxy(val)
if val is None:
if not attr.nullable and not from_db and not attr.is_required:
# for required attribute the exception will be thrown later with another message
@@ -1850,10 +2190,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False):
if callable(default): val = default()
else: val = default
- if entity is not None: pass
- elif obj is not None: entity = obj.__class__
- else: entity = attr.entity
-
+ entity = attr._get_entity(obj, entity)
reverse = attr.reverse
if not reverse:
if isinstance(val, Entity): throw(TypeError, 'Attribute %s must be of %s type. Got: %s'
@@ -1865,7 +2202,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False):
if converter is not None:
try:
if from_db: return converter.sql2py(val)
- val = converter.validate(val)
+ val = converter.validate(val, obj)
except UnicodeDecodeError as e:
throw(ValueError, 'Value for attribute %s cannot be converted to %s: %s'
% (attr, unicode.__name__, truncate_repr(val)))
@@ -1880,29 +2217,30 @@ def validate(attr, val, obj=None, entity=None, from_db=False):
except TypeError: throw(TypeError, 'Attribute %s must be of %s type. Got: %r'
% (attr, rentity.__name__, val))
else:
- if obj is not None: cache = obj._session_cache_
+ if obj is not None and obj._status_ is not None: cache = obj._session_cache_
else: cache = entity._database_._get_cache()
if cache is not val._session_cache_:
throw(TransactionError, 'An attempt to mix objects belonging to different transactions')
if attr.py_check is not None and not attr.py_check(val):
throw(ValueError, 'Check for attribute %s failed. Value: %s' % (attr, truncate_repr(val)))
return val
- def parse_value(attr, row, offsets):
+ def parse_value(attr, row, offsets, dbvals_deduplication_cache):
assert len(attr.columns) == len(offsets)
if not attr.reverse:
if len(offsets) > 1: throw(NotImplementedError)
offset = offsets[0]
- val = attr.validate(row[offset], None, attr.entity, from_db=True)
+ dbval = attr.validate(row[offset], None, attr.entity, from_db=True)
+ dbval = deduplicate(dbval, dbvals_deduplication_cache)
else:
- vals = [ row[offset] for offset in offsets ]
- if None in vals:
- assert len(set(vals)) == 1
- val = None
- else: val = attr.py_type._get_by_raw_pkval_(vals)
- return val
+ dbvals = [ row[offset] for offset in offsets ]
+ if None in dbvals:
+ assert len(set(dbvals)) == 1
+ dbval = None
+ else: dbval = attr.py_type._get_by_raw_pkval_(dbvals)
+ return dbval
def load(attr, obj):
- if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot load attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name))
+ cache = obj._session_cache_
+ if cache is None or not cache.is_alive: throw_db_session_is_over('load attribute', obj, attr)
if not attr.columns:
reverse = attr.reverse
assert reverse is not None and reverse.columns
@@ -1919,7 +2257,7 @@ def load(attr, obj):
from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ]
pk_columns = entity._pk_columns_
pk_converters = entity._pk_converters_
- criteria_list = [ [ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]
+ criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]
for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ]
sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ]
sql, adapter = database._ast2sql(sql_ast)
@@ -1929,7 +2267,7 @@ def load(attr, obj):
arguments = adapter(obj._get_raw_pkval_())
cursor = database._exec_sql(sql, arguments)
row = cursor.fetchone()
- dbval = attr.parse_value(row, offsets)
+ dbval = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache)
attr.db_set(obj, dbval)
else: obj._load_()
return obj._vals_[attr]
@@ -1937,26 +2275,26 @@ def load(attr, obj):
def __get__(attr, obj, cls=None):
if obj is None: return attr
if attr.pk_offset is not None: return attr.get(obj)
- result = attr.get(obj)
+ value = attr.get(obj)
bit = obj._bits_except_volatile_[attr]
wbits = obj._wbits_
if wbits is not None and not wbits & bit: obj._rbits_ |= bit
- return result
+ return value
def get(attr, obj):
if attr.pk_offset is None and obj._status_ in ('deleted', 'cancelled'):
throw_object_was_deleted(obj)
vals = obj._vals_
- if vals is None: throw_db_session_is_over(obj, attr)
+ if vals is None: throw_db_session_is_over('read value of', obj, attr)
val = vals[attr] if attr in vals else attr.load(obj)
if val is not None and attr.reverse and val._subclasses_ and val._status_ not in ('deleted', 'cancelled'):
- seeds = obj._session_cache_.seeds[val._pk_attrs_]
- if val in seeds: val._load_()
+ cache = obj._session_cache_
+ if cache is not None and val in cache.seeds[val._pk_attrs_]:
+ val._load_()
return val
@cut_traceback
def __set__(attr, obj, new_val, undo_funcs=None):
cache = obj._session_cache_
- if not cache.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot assign new value to attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name))
+ if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr)
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
reverse = attr.reverse
new_val = attr.validate(new_val, obj, from_db=False)
@@ -2039,25 +2377,28 @@ def undo_func():
raise
def db_set(attr, obj, new_dbval, is_reverse_call=False):
cache = obj._session_cache_
- assert cache.is_alive
+ assert cache is not None and cache.is_alive
assert obj._status_ not in created_or_deleted_statuses
assert attr.pk_offset is None
if new_dbval is NOT_LOADED: assert is_reverse_call
old_dbval = obj._dbvals_.get(attr, NOT_LOADED)
+ if old_dbval is not NOT_LOADED:
+ if old_dbval == new_dbval or (
+ not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)):
+ return
- if attr.py_type is float:
- if old_dbval is NOT_LOADED: pass
- elif attr.converters[0].equals(old_dbval, new_dbval): return
- elif old_dbval == new_dbval: return
-
- bit = obj._bits_[attr]
+ bit = obj._bits_except_volatile_[attr]
if obj._rbits_ & bit:
assert old_dbval is not NOT_LOADED
- if new_dbval is NOT_LOADED: diff = ''
- else: diff = ' (was: %s, now: %s)' % (old_dbval, new_dbval)
- throw(UnrepeatableReadError,
- 'Value of %s.%s for %s was updated outside of current transaction%s'
- % (obj.__class__.__name__, attr.name, obj, diff))
+ msg = 'Value of %s for %s was updated outside of current transaction' % (attr, obj)
+ if new_dbval is not NOT_LOADED:
+ msg = '%s (was: %s, now: %s)' % (msg, old_dbval, new_dbval)
+ elif isinstance(attr.reverse, Optional):
+ assert old_dbval is not None
+ msg = "Multiple %s objects linked with the same %s object. " \
+ "Maybe %s attribute should be Set instead of Optional" \
+ % (attr.entity.__name__, old_dbval, attr.reverse)
+ throw(UnrepeatableReadError, msg)
if new_dbval is NOT_LOADED: obj._dbvals_.pop(attr, None)
else: obj._dbvals_[attr] = new_dbval
@@ -2065,9 +2406,8 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False):
wbit = bool(obj._wbits_ & bit)
if not wbit:
old_val = obj._vals_.get(attr, NOT_LOADED)
- assert old_val == old_dbval
+ assert old_val == old_dbval, (old_val, old_dbval)
if attr.is_part_of_unique_index:
- cache = obj._session_cache_
if attr.is_unique: cache.db_update_simple_index(obj, attr, old_val, new_dbval)
get_val = obj._vals_.get
for attrs, i in attr.composite_keys:
@@ -2076,8 +2416,13 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False):
vals[i] = new_dbval
new_vals = tuple(vals)
cache.db_update_composite_index(obj, attrs, old_vals, new_vals)
- if new_dbval is NOT_LOADED: obj._vals_.pop(attr, None)
- else: obj._vals_[attr] = new_dbval
+ if new_dbval is NOT_LOADED:
+ obj._vals_.pop(attr, None)
+ elif attr.reverse:
+ obj._vals_[attr] = new_dbval
+ else:
+ assert len(attr.converters) == 1
+ obj._vals_[attr] = attr.converters[0].dbval2val(new_dbval, obj)
reverse = attr.reverse
if not reverse: pass
@@ -2171,6 +2516,8 @@ def describe(attr):
options = []
if attr.args: options.append(', '.join(imap(str, attr.args)))
if attr.auto: options.append('auto=True')
+ for k, v in sorted(attr.kwargs.items()):
+ options.append('%s=%r' % (k, v))
if not isinstance(attr, PrimaryKey) and attr.is_unique: options.append('unique=True')
if attr.default is not None: options.append('default=%r' % attr.default)
if not options: options = ''
@@ -2186,7 +2533,13 @@ class Required(Attribute):
def validate(attr, val, obj=None, entity=None, from_db=False):
val = Attribute.validate(attr, val, obj, entity, from_db)
if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)):
- throw(ValueError, 'Attribute %s is required' % (attr if obj is None else '%r.%s' % (obj, attr.name)))
+ if not from_db:
+ throw(ValueError, 'Attribute %s is required' % (
+ attr if obj is None or obj._status_ is None else '%r.%s' % (obj, attr.name)))
+ else:
+ warnings.warn('Database contains %s for required attribute %s'
+ % ('NULL' if val is None else 'empty string', attr),
+ DatabaseContainsIncorrectEmptyValue)
return val
class Discriminator(Required):
@@ -2217,7 +2570,7 @@ def process_entity_inheritance(attr, entity):
entity._discriminator_ = entity.__name__
discr_value = entity._discriminator_
if discr_value is not None:
- try: entity._discriminator_ = discr_value = attr.validate(discr_value)
+ try: entity._discriminator_ = discr_value = attr.validate(discr_value, None, entity)
except ValueError: throw(TypeError,
"Incorrect discriminator value is set for %s attribute '%s' of '%s' type: %r"
% (entity.__name__, attr.name, attr.py_type.__name__, discr_value))
@@ -2228,10 +2581,18 @@ def process_entity_inheritance(attr, entity):
% (entity.__name__, attr.name, attr.py_type.__name__))
attr.code2cls[discr_value] = entity
def validate(attr, val, obj=None, entity=None, from_db=False):
- if from_db: return val
- elif val is DEFAULT:
+ if from_db:
+ return val
+ entity = attr._get_entity(obj, entity)
+ if val is DEFAULT:
assert entity is not None
return entity._discriminator_
+ if val != entity._discriminator_:
+ for cls in entity._subclasses_:
+ if val == cls._discriminator_:
+ break
+ else: throw(TypeError, 'Invalid discriminator attribute value for %s. Expected: %r, got: %r'
+ % (entity.__name__, entity._discriminator_, val))
return Attribute.validate(attr, val, obj, entity)
def load(attr, obj):
assert False # pragma: no cover
@@ -2332,7 +2693,7 @@ def __new__(cls, *args, **kwargs):
class Collection(Attribute):
__slots__ = 'table', 'wrapper_class', 'symmetric', 'reverse_column', 'reverse_columns', \
'nplus1_threshold', 'cached_load_sql', 'cached_add_m2m_sql', 'cached_remove_m2m_sql', \
- 'cached_count_sql', 'cached_empty_sql'
+ 'cached_count_sql', 'cached_empty_sql', 'reverse_fk_name'
def __init__(attr, py_type, *args, **kwargs):
if attr.__class__ is Collection: throw(TypeError, "'Collection' is abstract type")
table = kwargs.pop('table', None) # TODO: rename table to link_table or m2m_table
@@ -2365,8 +2726,9 @@ def __init__(attr, py_type, *args, **kwargs):
if len(attr.reverse_columns) == 1: attr.reverse_column = attr.reverse_columns[0]
else: attr.reverse_columns = []
+ attr.reverse_fk_name = kwargs.pop('reverse_fk_name', None)
+
attr.nplus1_threshold = kwargs.pop('nplus1_threshold', 1)
- for option in attr.kwargs: throw(TypeError, 'Unknown option %r' % option)
attr.cached_load_sql = {}
attr.cached_add_m2m_sql = None
attr.cached_remove_m2m_sql = None
@@ -2379,8 +2741,11 @@ def _init_(attr, entity, name):
if attr.default is not None:
throw(TypeError, 'Default value could not be set for collection attribute')
attr.symmetric = (attr.py_type == entity.__name__ and attr.reverse == name)
- if not attr.symmetric and attr.reverse_columns: throw(TypeError,
- "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only")
+ if not attr.symmetric:
+ if attr.reverse_columns:
+ throw(TypeError, "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only")
+ if attr.reverse_index:
+ throw(TypeError, "'reverse_index' option can be set for symmetric relations only")
if attr.py_check is not None:
throw(NotImplementedError, "'py_check' parameter is not supported for collection attributes")
def load(attr, obj):
@@ -2411,7 +2776,7 @@ def param(i, j, converter):
else:
return [ 'PARAM', (i, j, None), converter ]
if batch_size == 1:
- return [ [ 'EQ', [ 'COLUMN', alias, column ], param(start, j, converter) ]
+ return [ [ converter.EQ, [ 'COLUMN', alias, column ], param(start, j, converter) ]
for j, (column, converter) in enumerate(izip(columns, converters)) ]
if len(columns) == 1:
column = columns[0]
@@ -2426,7 +2791,7 @@ def param(i, j, converter):
condition = [ 'IN', row, param_list ]
return [ condition ]
else:
- conditions = [ [ 'AND' ] + [ [ 'EQ', [ 'COLUMN', alias, column ], param(i+start, j, converter) ]
+ conditions = [ [ 'AND' ] + [ [ converter.EQ, [ 'COLUMN', alias, column ], param(i+start, j, converter) ]
for j, (column, converter) in enumerate(izip(columns, converters)) ]
for i in xrange(batch_size) ]
return [ [ 'OR' ] + conditions ]
@@ -2434,6 +2799,7 @@ def param(i, j, converter):
class Set(Collection):
__slots__ = []
def validate(attr, val, obj=None, entity=None, from_db=False):
+ val = deref_proxy(val)
assert val is not NOT_LOADED
if val is DEFAULT: return set()
reverse = attr.reverse
@@ -2450,19 +2816,76 @@ def validate(attr, val, obj=None, entity=None, from_db=False):
except TypeError: throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r'
% (entity.__name__, attr.name, rentity.__name__, val))
for item in items:
+ item = deref_proxy(item)
if not isinstance(item, rentity):
throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r'
% (entity.__name__, attr.name, rentity.__name__, item))
- if obj is not None: cache = obj._session_cache_
+ if obj is not None and obj._status_ is not None: cache = obj._session_cache_
else: cache = entity._database_._get_cache()
for item in items:
if item._session_cache_ is not cache:
throw(TransactionError, 'An attempt to mix objects belonging to different transactions')
return items
+ def prefetch_load_all(attr, objects):
+ entity = attr.entity
+ database = entity._database_
+ cache = database._get_cache()
+ if cache is None or not cache.is_alive:
+ throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over')
+ reverse = attr.reverse
+ rentity = reverse.entity
+ objects = sorted(objects, key=entity._get_raw_pkval_)
+ max_batch_size = database.provider.max_params_count // len(entity._pk_columns_)
+ result = set()
+ if not reverse.is_collection:
+ for i in xrange(0, len(objects), max_batch_size):
+ batch = objects[i:i+max_batch_size]
+ sql, adapter, attr_offsets = rentity._construct_batchload_sql_(len(batch), reverse)
+ arguments = adapter(batch)
+ cursor = database._exec_sql(sql, arguments)
+ result.update(rentity._fetch_objects(cursor, attr_offsets))
+ else:
+ pk_len = len(entity._pk_columns_)
+ m2m_dict = defaultdict(set)
+ for i in xrange(0, len(objects), max_batch_size):
+ batch = objects[i:i+max_batch_size]
+ sql, adapter = attr.construct_sql_m2m(len(batch))
+ arguments = adapter(batch)
+ cursor = database._exec_sql(sql, arguments)
+ if len(batch) > 1:
+ for row in cursor.fetchall():
+ obj = entity._get_by_raw_pkval_(row[:pk_len])
+ item = rentity._get_by_raw_pkval_(row[pk_len:])
+ m2m_dict[obj].add(item)
+ else:
+ obj = batch[0]
+ m2m_dict[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()}
+
+ for obj2, items in iteritems(m2m_dict):
+ setdata2 = obj2._vals_.get(attr)
+ if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData()
+ else:
+ phantoms = setdata2 - items
+ if setdata2.added: phantoms -= setdata2.added
+ if phantoms: throw(UnrepeatableReadError,
+ 'Phantom object %s disappeared from collection %s.%s'
+ % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name))
+ items -= setdata2
+ if setdata2.removed: items -= setdata2.removed
+ setdata2 |= items
+ reverse.db_reverse_add(items, obj2)
+ result.update(items)
+ for obj in objects:
+ setdata = obj._vals_.get(attr)
+ if setdata is None:
+ setdata = obj._vals_[attr] = SetData()
+ setdata.is_fully_loaded = True
+ setdata.absent = None
+ setdata.count = len(setdata)
+ return result
def load(attr, obj, items=None):
cache = obj._session_cache_
- if not cache.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot load collection %s.%s: the database session is over' % (safe_repr(obj), attr.name))
+ if cache is None or not cache.is_alive: throw_db_session_is_over('load collection', obj, attr)
assert obj._status_ not in del_statuses
setdata = obj._vals_.get(attr)
if setdata is None: setdata = obj._vals_[attr] = SetData()
@@ -2470,14 +2893,13 @@ def load(attr, obj, items=None):
entity = attr.entity
reverse = attr.reverse
rentity = reverse.entity
- if not reverse: throw(NotImplementedError)
database = obj._database_
if cache is not database._get_cache():
throw(TransactionError, "Transaction of object %s belongs to different thread")
if items:
if not reverse.is_collection:
- items = set(item for item in items if reverse not in item._vals_)
+ items = {item for item in items if reverse not in item._vals_}
else:
items = set(items)
items -= setdata
@@ -2497,15 +2919,14 @@ def load(attr, obj, items=None):
items.append(obj)
arguments = adapter(items)
cursor = database._exec_sql(sql, arguments)
- loaded_items = set(imap(rentity._get_by_raw_pkval_, cursor.fetchall()))
+ loaded_items = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()}
setdata |= loaded_items
reverse.db_reverse_add(loaded_items, obj)
return setdata
counter = cache.collection_statistics.setdefault(attr, 0)
nplus1_threshold = attr.nplus1_threshold
- prefetching = options.PREFETCHING and not attr.lazy and nplus1_threshold is not None \
- and (counter >= nplus1_threshold or cache.noflush_counter)
+ prefetching = not attr.lazy and nplus1_threshold is not None and counter >= nplus1_threshold
objects = [ obj ]
setdata_list = [ setdata ]
@@ -2540,16 +2961,16 @@ def load(attr, obj, items=None):
items = d.get(obj2)
if items is None: items = d[obj2] = set()
items.add(item)
- else: d[obj] = set(imap(rentity._get_by_raw_pkval_, cursor.fetchall()))
+ else: d[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()}
for obj2, items in iteritems(d):
setdata2 = obj2._vals_.get(attr)
- if setdata2 is None: setdata2 = obj._vals_[attr] = SetData()
+ if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData()
else:
phantoms = setdata2 - items
if setdata2.added: phantoms -= setdata2.added
if phantoms: throw(UnrepeatableReadError,
'Phantom object %s disappeared from collection %s.%s'
- % (safe_repr(phantoms.pop()), safe_repr(obj), attr.name))
+ % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name))
items -= setdata2
if setdata2.removed: items -= setdata2.removed
setdata2 |= items
@@ -2599,7 +3020,7 @@ def construct_sql_m2m(attr, batch_size=1, items_count=0):
return sql, adapter
def copy(attr, obj):
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
- if obj._vals_ is None: throw_db_session_is_over(obj, attr)
+ if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr)
setdata = obj._vals_.get(attr)
if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj)
reverse = attr.reverse
@@ -2623,8 +3044,7 @@ def __set__(attr, obj, new_items, undo_funcs=None):
if isinstance(new_items, SetInstance) and new_items._obj_ is obj and new_items._attr_ is attr:
return # after += or -=
cache = obj._session_cache_
- if not cache.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr))
+ if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr)
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
with cache.flush_disabled():
new_items = attr.validate(new_items, obj)
@@ -2794,7 +3214,7 @@ def remove_m2m(attr, removed):
columns = reverse.columns + attr.columns
converters = reverse.converters + attr.converters
for i, (column, converter) in enumerate(izip(columns, converters)):
- where_list.append([ 'EQ', ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ])
+ where_list.append([ converter.EQ, ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ])
from_ast = [ 'FROM', [ None, 'TABLE', attr.table ] ]
sql_ast = [ 'DELETE', None, from_ast, where_list ]
sql, adapter = database._ast2sql(sql_ast)
@@ -2842,6 +3262,35 @@ def unpickle_setwrapper(obj, attrname, items):
setdata.count = len(setdata)
return wrapper
+
+class SetIterator(object):
+ def __init__(self, wrapper):
+ self._wrapper = wrapper
+ self._query = None
+ self._iter = None
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ if self._iter is None:
+ self._iter = iter(self._wrapper.copy())
+ return next(self._iter)
+
+ __next__ = next
+
+ def _get_query(self):
+ if self._query is None:
+ self._query = self._wrapper.select()
+ return self._query
+
+ def _get_type_(self):
+ return QueryType(self._get_query())
+
+ def _normalize_var(self, query_type):
+ return query_type, self._get_query()
+
+
class SetInstance(object):
__slots__ = '_obj_', '_attr_', '_attrnames_'
_parent_ = None
@@ -2859,7 +3308,8 @@ def __repr__(wrapper):
return '<%s %r.%s>' % (wrapper.__class__.__name__, wrapper._obj_, wrapper._attr_.name)
@cut_traceback
def __str__(wrapper):
- if not wrapper._obj_._session_cache_.is_alive: content = '...'
+ cache = wrapper._obj_._session_cache_
+ if cache is None or not cache.is_alive: content = '...'
else: content = ', '.join(imap(str, wrapper))
return '%s([%s])' % (wrapper.__class__.__name__, content)
@cut_traceback
@@ -2867,7 +3317,7 @@ def __nonzero__(wrapper):
attr = wrapper._attr_
obj = wrapper._obj_
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
- if obj._vals_ is None: throw_db_session_is_over(obj, attr)
+ if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr)
setdata = obj._vals_.get(attr)
if setdata is None: setdata = attr.load(obj)
if setdata: return True
@@ -2878,7 +3328,7 @@ def is_empty(wrapper):
attr = wrapper._attr_
obj = wrapper._obj_
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
- if obj._vals_ is None: throw_db_session_is_over(obj, attr)
+ if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr)
setdata = obj._vals_.get(attr)
if setdata is None: setdata = obj._vals_[attr] = SetData()
elif setdata.is_fully_loaded: return not setdata
@@ -2892,7 +3342,7 @@ def is_empty(wrapper):
if cached_sql is None:
where_list = [ 'WHERE' ]
for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)):
- where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ])
+ where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ])
if not reverse.is_collection:
table_name = rentity._table_
select_list, attr_offsets = rentity._construct_select_clause_()
@@ -2901,7 +3351,7 @@ def is_empty(wrapper):
select_list = [ 'ALL' ] + [ [ 'COLUMN', None, column ] for column in attr.columns ]
attr_offsets = None
sql_ast = [ 'SELECT', select_list, [ 'FROM', [ None, 'TABLE', table_name ] ],
- where_list, [ 'LIMIT', [ 'VALUE', 1 ] ] ]
+ where_list, [ 'LIMIT', 1 ] ]
sql, adapter = database._ast2sql(sql_ast)
attr.cached_empty_sql = sql, adapter, attr_offsets
else: sql, adapter, attr_offsets = cached_sql
@@ -2924,7 +3374,7 @@ def __len__(wrapper):
attr = wrapper._attr_
obj = wrapper._obj_
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
- if obj._vals_ is None: throw_db_session_is_over(obj, attr)
+ if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr)
setdata = obj._vals_.get(attr)
if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj)
return len(setdata)
@@ -2934,10 +3384,11 @@ def count(wrapper):
obj = wrapper._obj_
cache = obj._session_cache_
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
- if obj._vals_ is None: throw_db_session_is_over(obj, attr)
+ if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr)
setdata = obj._vals_.get(attr)
if setdata is None: setdata = obj._vals_[attr] = SetData()
elif setdata.count is not None: return setdata.count
+ if cache is None or not cache.is_alive: throw_db_session_is_over('read value of', obj, attr)
entity = attr.entity
reverse = attr.reverse
database = entity._database_
@@ -2945,10 +3396,10 @@ def count(wrapper):
if cached_sql is None:
where_list = [ 'WHERE' ]
for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)):
- where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ])
+ where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ])
if not reverse.is_collection: table_name = reverse.entity._table_
else: table_name = attr.table
- sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ],
+ sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ],
[ 'FROM', [ None, 'TABLE', table_name ] ], where_list ]
sql, adapter = database._ast2sql(sql_ast)
attr.cached_count_sql = sql, adapter
@@ -2962,7 +3413,7 @@ def count(wrapper):
return setdata.count
@cut_traceback
def __iter__(wrapper):
- return iter(wrapper.copy())
+ return SetIterator(wrapper)
@cut_traceback
def __eq__(wrapper, other):
if isinstance(other, SetInstance):
@@ -2985,8 +3436,10 @@ def __contains__(wrapper, item):
attr = wrapper._attr_
obj = wrapper._obj_
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
- if obj._vals_ is None: throw_db_session_is_over(obj, attr)
+ if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr)
if not isinstance(item, attr.py_type): return False
+ if item._session_cache_ is not obj._session_cache_:
+ throw(TransactionError, 'An attempt to mix objects belonging to different transactions')
reverse = attr.reverse
if not reverse.is_collection:
@@ -3021,15 +3474,13 @@ def create(wrapper, **kwargs):
kwargs[reverse.name] = wrapper._obj_
item_type = attr.py_type
item = item_type(**kwargs)
- wrapper.add(item)
return item
@cut_traceback
def add(wrapper, new_items):
obj = wrapper._obj_
attr = wrapper._attr_
cache = obj._session_cache_
- if not cache.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr))
+ if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr)
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
with cache.flush_disabled():
reverse = attr.reverse
@@ -3068,8 +3519,7 @@ def remove(wrapper, items):
obj = wrapper._obj_
attr = wrapper._attr_
cache = obj._session_cache_
- if not cache.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr))
+ if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr)
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
with cache.flush_disabled():
reverse = attr.reverse
@@ -3111,8 +3561,8 @@ def __isub__(wrapper, items):
def clear(wrapper):
obj = wrapper._obj_
attr = wrapper._attr_
- if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver,
- 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr))
+ cache = obj._session_cache_
+ if cache is None or not obj._session_cache_.is_alive: throw_db_session_is_over('change collection', obj, attr)
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
attr.__set__(obj, ())
@cut_traceback
@@ -3128,16 +3578,18 @@ def select(wrapper, *args):
s = 'lambda item: JOIN(obj in item.%s)' if reverse.is_collection else 'lambda item: item.%s == obj'
query = query.filter(s % reverse.name, {'obj' : obj, 'JOIN': JOIN})
if args:
- func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=3)
+ func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+1)
query = query.filter(func, globals, locals)
return query
filter = select
- def limit(wrapper, limit, offset=None):
+ def limit(wrapper, limit=None, offset=None):
return wrapper.select().limit(limit, offset)
def page(wrapper, pagenum, pagesize=10):
return wrapper.select().page(pagenum, pagesize)
def order_by(wrapper, *args):
return wrapper.select().order_by(*args)
+ def sort_by(wrapper, *args):
+ return wrapper.select().sort_by(*args)
def random(wrapper, limit):
return wrapper.select().random(limit)
@@ -3168,7 +3620,8 @@ def distinct(multiset):
return multiset._items_.copy()
@cut_traceback
def __repr__(multiset):
- if multiset._obj_._session_cache_.is_alive:
+ cache = multiset._obj_._session_cache_
+ if cache is not None and cache.is_alive:
size = builtins.sum(itervalues(multiset._items_))
if size == 1: size_str = ' (1 item)'
else: size_str = ' (%d items)' % size
@@ -3305,11 +3758,26 @@ def __init__(entity, name, bases, cls_dict):
new_attrs.append(attr)
new_attrs.sort(key=attrgetter('id'))
+ interleave_attrs = []
+ for attr in new_attrs:
+ if attr.interleave is not None:
+ if attr.interleave:
+ interleave_attrs.append(attr)
+ entity._interleave_ = None
+ if interleave_attrs:
+ if len(interleave_attrs) > 1: throw(TypeError,
+ 'only one attribute may be marked as interleave. Got: %s'
+ % ', '.join(repr(attr) for attr in interleave_attrs))
+ interleave = interleave_attrs[0]
+ if not interleave.is_relation: throw(TypeError,
+ 'Interleave attribute should be part of relationship. Got: %r' % interleave)
+ entity._interleave_ = interleave
+
indexes = entity._indexes_ = entity.__dict__.get('_indexes_', [])
for attr in new_attrs:
if attr.is_unique: indexes.append(Index(attr, is_pk=isinstance(attr, PrimaryKey)))
for index in indexes: index._init_(entity)
- primary_keys = set(index.attrs for index in indexes if index.is_pk)
+ primary_keys = {index.attrs for index in indexes if index.is_pk}
if direct_bases:
if primary_keys: throw(ERDiagramError, 'Primary key cannot be redefined in derived classes')
base_indexes = []
@@ -3317,7 +3785,7 @@ def __init__(entity, name, bases, cls_dict):
for index in base._indexes_:
if index not in base_indexes and index not in indexes: base_indexes.append(index)
indexes[:0] = base_indexes
- primary_keys = set(index.attrs for index in indexes if index.is_pk)
+ primary_keys = {index.attrs for index in indexes if index.is_pk}
if len(primary_keys) > 1: throw(ERDiagramError, 'Only one primary key can be defined in each entity class')
elif not primary_keys:
@@ -3346,7 +3814,7 @@ def __init__(entity, name, bases, cls_dict):
entity._new_attrs_ = new_attrs
entity._attrs_ = base_attrs + new_attrs
- entity._adict_ = dict((attr.name, attr) for attr in entity._attrs_)
+ entity._adict_ = {attr.name: attr for attr in entity._attrs_}
entity._subclass_attrs_ = []
entity._subclass_adict_ = {}
for base in entity._all_bases_:
@@ -3438,7 +3906,7 @@ def _link_reverse_attrs_(entity):
database = entity._database_
for attr in entity._new_attrs_:
py_type = attr.py_type
- if not issubclass(py_type, Entity): continue
+ if not isinstance(py_type, EntityMeta): continue
entity2 = py_type
if entity2._database_ is not database:
@@ -3484,6 +3952,12 @@ def _link_reverse_attrs_(entity):
attr2.reverse = attr
attr.linked()
attr2.linked()
+ def _check_table_options_(entity):
+ if entity._root_ is not entity:
+ if '_table_options_' in entity.__dict__: throw(TypeError,
+ 'Cannot redefine %s options in %s entity' % (entity._root_.__name__, entity.__name__))
+ elif not hasattr(entity, '_table_options_'):
+ entity._table_options_ = {}
def _get_pk_columns_(entity):
if entity._pk_columns_ is not None: return entity._pk_columns_
pk_columns = []
@@ -3503,64 +3977,51 @@ def _get_pk_columns_(entity):
return pk_columns
def __iter__(entity):
return EntityIter(entity)
- def _normalize_args_(entity, kwargs, setdefault=False):
- avdict = {}
- if setdefault:
- for name in kwargs:
- if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name)
- for attr in entity._attrs_:
- val = kwargs.get(attr.name, DEFAULT)
- avdict[attr] = attr.validate(val, None, entity, from_db=False)
- else:
- get_attr = entity._adict_.get
- for name, val in iteritems(kwargs):
- attr = get_attr(name)
- if attr is None: throw(TypeError, 'Unknown attribute %r' % name)
- avdict[attr] = attr.validate(val, None, entity, from_db=False)
- if entity._pk_is_composite_:
- get_val = avdict.get
- pkval = tuple(get_val(attr) for attr in entity._pk_attrs_)
- if None in pkval: pkval = None
- else: pkval = avdict.get(entity._pk_attrs_[0])
- return pkval, avdict
@cut_traceback
def __getitem__(entity, key):
if type(key) is not tuple: key = (key,)
- if len(key) != len(entity._pk_attrs_):
- throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)'
- % (entity.__name__, len(key), len(entity._pk_attrs_)))
- kwargs = dict(izip(imap(attrgetter('name'), entity._pk_attrs_), key))
- return entity._find_one_(kwargs)
+ if len(key) == len(entity._pk_attrs_):
+ kwargs = {attr.name: value for attr, value in izip(entity._pk_attrs_, key)}
+ return entity._find_one_(kwargs)
+ if len(key) == len(entity._pk_columns_):
+ return entity._get_by_raw_pkval_(key, from_db=False, seed=False)
+
+ throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)'
+ % (entity.__name__, len(key), len(entity._pk_attrs_)))
@cut_traceback
def exists(entity, *args, **kwargs):
- if args: return entity._query_from_args_(args, kwargs, frame_depth=3).exists()
+ if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).exists()
try: obj = entity._find_one_(kwargs)
except ObjectNotFound: return False
except MultipleObjectsFoundError: return True
return True
@cut_traceback
def get(entity, *args, **kwargs):
- if args: return entity._query_from_args_(args, kwargs, frame_depth=3).get()
+ if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).get()
try: return entity._find_one_(kwargs) # can throw MultipleObjectsFoundError
except ObjectNotFound: return None
@cut_traceback
def get_for_update(entity, *args, **kwargs):
nowait = kwargs.pop('nowait', False)
- if args: return entity._query_from_args_(args, kwargs, frame_depth=3).for_update(nowait).get()
- try: return entity._find_one_(kwargs, True, nowait) # can throw MultipleObjectsFoundError
+ skip_locked = kwargs.pop('skip_locked', False)
+ if nowait and skip_locked:
+ throw(TypeError, 'nowait and skip_locked options are mutually exclusive')
+ if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1) \
+ .for_update(nowait, skip_locked).get()
+ try: return entity._find_one_(kwargs, True, nowait, skip_locked) # can throw MultipleObjectsFoundError
except ObjectNotFound: return None
@cut_traceback
def get_by_sql(entity, sql, globals=None, locals=None):
- objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=3) # can throw MultipleObjectsFoundError
+ objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=cut_traceback_depth+1) # can throw MultipleObjectsFoundError
if not objects: return None
assert len(objects) == 1
return objects[0]
@cut_traceback
def select(entity, *args):
- return entity._query_from_args_(args, kwargs=None, frame_depth=3)
+ return entity._query_from_args_(args, kwargs=None, frame_depth=cut_traceback_depth+1)
@cut_traceback
def select_by_sql(entity, sql, globals=None, locals=None):
- return entity._find_by_sql_(None, sql, globals, locals, frame_depth=3)
+ return entity._find_by_sql_(None, sql, globals, locals, frame_depth=cut_traceback_depth+1)
@cut_traceback
def select_random(entity, limit):
if entity._pk_is_composite_: return entity.select().random(limit)
@@ -3574,7 +4035,7 @@ def select_random(entity, limit):
if max_id is None:
max_id_sql = entity._cached_max_id_sql_
if max_id_sql is None:
- sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', [ 'COLUMN', None, pk.column ] ] ],
+ sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', None, [ 'COLUMN', None, pk.column ] ] ],
[ 'FROM', [ None, 'TABLE', entity._table_ ] ] ]
max_id_sql, adapter = database._ast2sql(sql_ast)
entity._cached_max_id_sql_ = max_id_sql
@@ -3623,15 +4084,24 @@ def select_random(entity, limit):
if obj in seeds: obj._load_()
if found_in_cache: shuffle(result)
return result
- def _find_one_(entity, kwargs, for_update=False, nowait=False):
+ def _find_one_(entity, kwargs, for_update=False, nowait=False, skip_locked=False):
if entity._database_.schema is None:
throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__)
- pkval, avdict = entity._normalize_args_(kwargs, False)
+ avdict = {}
+ get_attr = entity._adict_.get
+ for name, val in iteritems(kwargs):
+ attr = get_attr(name)
+ if attr is None: throw(TypeError, 'Unknown attribute %r' % name)
+ avdict[attr] = attr.validate(val, None, entity, from_db=False)
+ if entity._pk_is_composite_:
+ pkval = tuple(imap(avdict.get, entity._pk_attrs_))
+ if None in pkval: pkval = None
+ else: pkval = avdict.get(entity._pk_attrs_[0])
for attr in avdict:
if attr.is_collection:
throw(TypeError, 'Collection attribute %s cannot be specified as search criteria' % attr)
obj, unique = entity._find_in_cache_(pkval, avdict, for_update)
- if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait)
+ if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait, skip_locked)
if obj is None: throw(ObjectNotFound, entity, pkval)
return obj
def _find_in_cache_(entity, pkval, avdict, for_update=False):
@@ -3654,9 +4124,9 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False):
get_val = avdict.get
vals = tuple(get_val(attr) for attr in attrs)
if None in vals: continue
+ unique = True
cache_index = cache_indexes.get(attrs)
if cache_index is None: continue
- unique = True
obj = cache_index.get(vals)
if obj is not None: break
if obj is None:
@@ -3683,11 +4153,11 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False):
entity._set_rbits((obj,), avdict)
return obj, unique
return None, unique
- def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False):
+ def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False, skip_locked=False):
database = entity._database_
- query_attrs = dict((attr, value is None) for attr, value in iteritems(avdict))
+ query_attrs = {attr: value is None for attr, value in iteritems(avdict)}
limit = 2 if not unique else None
- sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait)
+ sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait, skip_locked)
arguments = adapter(avdict)
if for_update: database._get_cache().immediate = True
cursor = database._exec_sql(sql, arguments)
@@ -3719,11 +4189,12 @@ def _find_by_sql_(entity, max_fetch_count, sql, globals, locals, frame_depth):
objects = entity._fetch_objects(cursor, attr_offsets, max_fetch_count)
return objects
- def _construct_select_clause_(entity, alias=None, distinct=False,
- query_attrs=(), attrs_to_prefetch=(), all_attributes=False):
+ def _construct_select_clause_(entity, alias=None, distinct=False, query_attrs=(), all_attributes=False):
attr_offsets = {}
select_list = [ 'DISTINCT' ] if distinct else [ 'ALL' ]
root = entity._root_
+ pc = local.prefetch_context
+ attrs_to_prefetch = pc.attrs_to_prefetch_dict.get(entity, ()) if pc else ()
for attr in chain(root._attrs_, root._subclass_attrs_):
if not all_attributes and not issubclass(attr.entity, entity) \
and not issubclass(entity, attr.entity): continue
@@ -3738,12 +4209,13 @@ def _construct_select_clause_(entity, alias=None, distinct=False,
def _construct_discriminator_criteria_(entity, alias=None):
discr_attr = entity._discriminator_attr_
if discr_attr is None: return None
- code2cls = discr_attr.code2cls
discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in entity._subclasses_ ]
discr_values.append([ 'VALUE', entity._discriminator_])
return [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ]
def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True):
- query_key = batch_size, attr, from_seeds
+ pc = local.prefetch_context
+ attrs_to_prefetch = pc.get_frozen_attrs_to_prefetch(entity) if pc is not None else ()
+ query_key = batch_size, attr, from_seeds, attrs_to_prefetch
cached_sql = entity._batchload_sql_cache_.get(query_key)
if cached_sql is not None: return cached_sql
select_list, attr_offsets = entity._construct_select_clause_(all_attributes=True)
@@ -3763,12 +4235,10 @@ def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True):
cached_sql = sql, adapter, attr_offsets
entity._batchload_sql_cache_[query_key] = cached_sql
return cached_sql
- def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False):
- if limit and entity._database_.provider.dialect == 'MSSQL':
- order_by_pk = True # todo: use TOP 1 instead of FETCH NEXT and remove this line
- if nowait: assert for_update
+ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False, skip_locked=False):
+ if nowait or skip_locked: assert for_update
sorted_query_attrs = tuple(sorted(query_attrs.items()))
- query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait
+ query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait, skip_locked
cached_sql = entity._find_sql_cache_.get(query_key)
if cached_sql is not None: return cached_sql
select_list, attr_offsets = entity._construct_select_clause_(query_attrs=query_attrs)
@@ -3785,7 +4255,8 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda
if attr_is_none: where_list.append([ 'IS_NULL', [ 'COLUMN', None, attr.column ] ])
else:
if len(attr.converters) > 1: throw(NotImplementedError)
- where_list.append([ 'EQ', [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), attr.converters[0] ] ])
+ converter = attr.converters[0]
+ where_list.append([ converter.EQ, [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), converter ] ])
elif not attr.columns: throw(NotImplementedError)
else:
attr_entity = attr.py_type; assert attr_entity == attr.reverse.entity
@@ -3794,12 +4265,12 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda
where_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ])
else:
for j, (column, converter) in enumerate(izip(attr.columns, attr_entity._pk_converters_)):
- where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ])
+ where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ])
if not for_update: sql_ast = [ 'SELECT', select_list, from_list, where_list ]
- else: sql_ast = [ 'SELECT_FOR_UPDATE', bool(nowait), select_list, from_list, where_list ]
+ else: sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked, select_list, from_list, where_list ]
if order_by_pk: sql_ast.append([ 'ORDER_BY' ] + [ [ 'COLUMN', None, column ] for column in entity._pk_columns_ ])
- if limit is not None: sql_ast.append([ 'LIMIT', [ 'VALUE', limit ] ])
+ if limit is not None: sql_ast.append([ 'LIMIT', limit ])
database = entity._database_
sql, adapter = database._ast2sql(sql_ast)
cached_sql = sql, adapter, attr_offsets
@@ -3836,31 +4307,43 @@ def _set_rbits(entity, objects, attrs):
if wbits is None: continue
rbits = get_rbits(obj.__class__)
if rbits is None:
- rbits = sum(obj._bits_.get(attr, 0) for attr in attrs)
+ rbits = sum(obj._bits_except_volatile_.get(attr, 0) for attr in attrs)
rbits_dict[obj.__class__] = rbits
obj._rbits_ |= rbits & ~wbits
def _parse_row_(entity, row, attr_offsets):
discr_attr = entity._discriminator_attr_
- if not discr_attr: real_entity_subclass = entity
+ if not discr_attr:
+ discr_value = None
+ real_entity_subclass = entity
else:
discr_offset = attr_offsets[discr_attr][0]
discr_value = discr_attr.validate(row[discr_offset], None, entity, from_db=True)
real_entity_subclass = discr_attr.code2cls[discr_value]
+ discr_value = real_entity_subclass._discriminator_ # To convert unicode to str in Python 2.x
+
+ database = entity._database_
+ cache = local.db2cache[database]
avdict = {}
for attr in real_entity_subclass._attrs_:
offsets = attr_offsets.get(attr)
- if offsets is None or attr.is_discriminator: continue
- avdict[attr] = attr.parse_value(row, offsets)
- if not entity._pk_is_composite_: pkval = avdict.pop(entity._pk_attrs_[0], None)
- else: pkval = tuple(avdict.pop(attr, None) for attr in entity._pk_attrs_)
+ if offsets is None:
+ continue
+ if attr.is_discriminator:
+ avdict[attr] = discr_value
+ else:
+ avdict[attr] = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache)
+
+ pkval = tuple(avdict.pop(attr) for attr in entity._pk_attrs_)
+ assert None not in pkval
+ if not entity._pk_is_composite_: pkval = pkval[0]
return real_entity_subclass, pkval, avdict
def _load_many_(entity, objects):
database = entity._database_
cache = database._get_cache()
seeds = cache.seeds[entity._pk_attrs_]
if not seeds: return
- objects = set(obj for obj in objects if obj in seeds)
+ objects = {obj for obj in objects if obj in seeds}
objects = sorted(objects, key=attrgetter('_pkval_'))
max_batch_size = database.provider.max_params_count // len(entity._pk_columns_)
while objects:
@@ -3903,7 +4386,6 @@ def _query_from_args_(entity, args, kwargs, frame_depth):
for_expr = ast.GenExprFor(ast.AssName(name, 'OP_ASSIGN'), ast.Name('.0'), [ if_expr ])
inner_expr = ast.GenExprInner(ast.Name(name), [ for_expr ])
locals = locals.copy() if locals is not None else {}
- assert '.0' not in locals
locals['.0'] = entity
return Query(code_key, inner_expr, globals, locals, cells)
def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs=None, obj_to_init=None):
@@ -3928,7 +4410,9 @@ def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs=
if obj is None:
with cache.flush_disabled():
- obj = obj_to_init or object.__new__(entity)
+ obj = obj_to_init
+ if obj_to_init is None:
+ obj = object.__new__(entity)
cache.objects.add(obj)
obj._pkval_ = pkval
obj._status_ = status
@@ -3961,7 +4445,7 @@ def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs=
assert cache.in_transaction
cache.for_update.add(obj)
return obj
- def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True):
+ def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True, seed=True):
i = 0
pkval = []
for attr in entity._pk_attrs_:
@@ -3969,16 +4453,19 @@ def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True):
val = raw_pkval[i]
i += 1
if not attr.reverse: val = attr.validate(val, None, entity, from_db=from_db)
- else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db)
+ else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db, seed=seed)
else:
if not attr.reverse: throw(NotImplementedError)
vals = raw_pkval[i:i+len(attr.columns)]
- val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db)
+ val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db, seed=seed)
i += len(attr.columns)
pkval.append(val)
if not entity._pk_is_composite_: pkval = pkval[0]
else: pkval = tuple(pkval)
- obj = entity._get_from_identity_map_(pkval, 'loaded', for_update)
+ if seed:
+ obj = entity._get_from_identity_map_(pkval, 'loaded', for_update)
+ else:
+ obj = entity[pkval]
assert obj._status_ != 'cancelled'
return obj
def _get_propagation_mixin_(entity):
@@ -3990,6 +4477,8 @@ def _get_propagation_mixin_(entity):
def fget(wrapper, attr=attr):
attrnames = wrapper._attrnames_ + (attr.name,)
items = [ x for x in (attr.__get__(item) for item in wrapper) if x is not None ]
+ if attr.py_type is Json:
+ return [ item.get_untracked() if isinstance(item, TrackedValue) else item for item in items ]
return Multiset(wrapper._obj_, attrnames, items)
elif not attr.is_collection:
def fget(wrapper, attr=attr):
@@ -4079,21 +4568,21 @@ def _get_attrs_(entity, only=None, exclude=None, with_collections=False, with_la
entity._attrnames_cache_[key] = attrs
return attrs
-def populate_criteria_list(criteria_list, columns, converters, params_count=0, table_alias=None):
- assert len(columns) == len(converters)
- for column, converter in izip(columns, converters):
- if converter is not None:
- criteria_list.append([ 'EQ', [ 'COLUMN', table_alias, column ],
- [ 'PARAM', (params_count, None, None), converter ] ])
+def populate_criteria_list(criteria_list, columns, converters, operations,
+ params_count=0, table_alias=None, optimistic=False):
+ for column, op, converter in izip(columns, operations, converters):
+ if op == 'IS_NULL':
+ criteria_list.append([ op, [ 'COLUMN', None, column ] ])
else:
- criteria_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ])
+ criteria_list.append([ op, [ 'COLUMN', table_alias, column ],
+ [ 'PARAM', (params_count, None, None), converter, optimistic ] ])
params_count += 1
return params_count
-statuses = set(['created', 'cancelled', 'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted'])
-del_statuses = set(['marked_to_delete', 'deleted', 'cancelled'])
-created_or_deleted_statuses = set(['created']) | del_statuses
-saved_statuses = set(['inserted', 'updated', 'deleted'])
+statuses = {'created', 'cancelled', 'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted'}
+del_statuses = {'marked_to_delete', 'deleted', 'cancelled'}
+created_or_deleted_statuses = {'created'} | del_statuses
+saved_statuses = {'inserted', 'updated', 'deleted'}
def throw_object_was_deleted(obj):
assert obj._status_ in del_statuses
@@ -4119,6 +4608,64 @@ def unpickle_entity(d):
def safe_repr(obj):
return Entity.__repr__(obj)
+def make_proxy(obj):
+ proxy = EntityProxy(obj)
+ return proxy
+
+class EntityProxy(object):
+ def __init__(self, obj):
+ entity = obj.__class__
+ object.__setattr__(self, '_entity_', entity)
+ pkval = obj.get_pk()
+ if pkval is None:
+ cache = obj._session_cache_
+ if obj._status_ in del_statuses or cache is None or not cache.is_alive:
+ throw(ValueError, 'Cannot make a proxy for %s object: primary key is not specified' % entity.__name__)
+ flush()
+ pkval = obj.get_pk()
+ assert pkval is not None
+ object.__setattr__(self, '_obj_pk_', pkval)
+
+ def __repr__(self):
+ entity = self._entity_
+ pkval = self._obj_pk_
+ pkrepr = ','.join(repr(item) for item in pkval) if isinstance(pkval, tuple) else repr(pkval)
+ return '