diff --git a/.gitignore b/.gitignore index d58e61b2f..b42e08466 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ pony/orm/tests/coverage.bat pony/orm/tests/htmlcov/*.* MANIFEST docs/_build/ +pony.egg-info/ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..a8cb05707 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": "C:\\Users\\sostholm\\Envs\\pony\\Scripts\\python.exe" +} \ No newline at end of file diff --git a/BACKERS.md b/BACKERS.md new file mode 100644 index 000000000..1dc5db04c --- /dev/null +++ b/BACKERS.md @@ -0,0 +1,18 @@ +# Sponsors & Backers + +Pony ORM is an Apache 2.0 licensed open source project. If you would like to support Pony ORM development, please consider: + +[Become a backer or sponsor](https://ponyorm.org/donation.html) + +## Backers + +- [Vincere](https://vince.re) +- Sergio Aguilar Guerrero +- David ROUBLOT +- Elijas Dapšauskas +- Dan Swain +- Christian Macht +- Johnathan Nader +- Andrei Rachalouski +- Juan Pablo Scaletti +- Marcus Birkenkrahe diff --git a/CHANGELOG.md b/CHANGELOG.md index bde74fca1..f6e083879 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,333 @@ +# PonyORM release 0.7.13 (2020-03-03) + +This release contains no new features or bugfixes. The only reason for this release is to test our CI/CD process. + +# PonyORM release 0.7.12 (2020-02-04) + +## Features + +* CockroachDB support added +* CI testing for SQLite, PostgreSQL & CockroachDB + +## Bugfixes + +* Fix translation of getting array items with negative indexes +* Fix string getitem translation for slices and negative indexes +* PostgreSQL DISTINCT bug fixed for queries with ORDER BY clause +* Fix date difference syntax in PostgreSQL +* Fix casting json to double in PostgreSQL +* Fix count by several columns in PostgreSQL +* Fix PostgreSQL MIN and MAX expressions on boolean columns +* Fix determination of interactive mode in PyCharm +* Fix column definition when `sql_default` is specified: DEFAULT should be before NOT NULL +* Relax checks on updating in-memory cache indexes (don't throw CacheIndexError on valid cases) +* Fix deduplication logic for attribute values + + +# PonyORM release 0.7.11 (2019-10-23) + +## Features + +* #472: Python 3.8 support +* Support of hybrid functions (inlining simple Python functions into query) +* #438: support datetime-datetime, datetime-timedelta, datetime+timedelta in queries + +## Bugfixes + +* #430: add ON DELETE CASCADE for many-to-many relationships +* #465: Should reconnect to MySQL on OperationalError 2013 'Lost connection to MySQL server during query' +* #468: Tuple-value comparisons generate incorrect queries +* #470: fix PendingDeprecationWarning of imp module +* Fix incorrect unpickling of objects with Json attributes +* Check value of discriminator column on object creation if set explicitly +* Correctly handle Flask current_user proxy when adding new items to collections +* Some bugs in syntax of aggregated queries were fixed +* Fix syntax of bulk delete queries +* Bulk delete queries should clear query results cache so next select will get correct result from the database +* Fix error message when hybrid method is too complex to decompile + + +# PonyORM release 0.7.10 (2019-04-20) + +## Bugfixes + +* Python3.7 and PyPy decompiling fixes +* Fix reading NULL from Optional nullable array column +* Fix handling of empty arrays in queries +* #415: error message typo +* #432: PonyFlask - request object can trigger teardown_request without real request +* Fix
GROUP CONCAT separator for MySQL + + +# PonyORM release 0.7.9 (2019-01-21) + +## Bugfixes + +* Fix handling of empty arrays and empty lists in queries +* Fix reading optional nullable array columns from database + + +# PonyORM release 0.7.8 (2019-01-19) + +## Bugfixes + +* #414: prefetching Optional relationships fails on 0.7.7 +* Fix a bug caused by incorrect deduplication of column values + + +# PonyORM release 0.7.7 (2019-01-17) + +## Major features + +* Array type support for PostgreSQL and SQLite +* isinstance() support in queries +* Support of queries based on collections: select(x for x in y.items) + +## Other features + +* Support of Entity.select(**kwargs) +* Support of SKIP LOCKED option in 'SELECT ... FOR UPDATE' +* New function make_proxy(obj) to make cross-db_session proxy objects +* Specify ON DELETE CASCADE/SET NULL in foreign keys +* Support of LIMIT in `SELECT FROM (SELECT ...)` type of queries +* Support for negative JSON array indexes in SQLite + +## Improvements + +* Improved query prefetching: use fewer SQL queries +* Memory optimization: deduplication of values received from the database in the same session +* Increase DBAPIProvider.max_params_count value + +## Bugfixes + +* #405: breaking change with cx_Oracle 7.0: DML RETURNING now returns a list +* #380: db_session should work with async functions +* #385: test fails with python3.6 +* #386: release unlocked lock error in SQLite +* #390: TypeError: writable buffers are not hashable +* #398: add auto conversion of numpy numeric types +* #404: GAE local run detection +* Fix Flask compatibility: add support of LocalProxy object +* db_session(sql_debug=True) should log SQL commands also during db_session.__exit__() +* Fix duplicated table join in FROM clause +* Fix accessing global variables from hybrid methods and properties +* Fix m2m collection loading bug +* Fix composite index bug: stackoverflow.com/questions/53147694 +* Fix MyEntity[obj.get_pk()] if pk is composite +* MySQL group_concat_max_len option set to max of 32bit platforms to avoid truncation +* Show all attribute options in show(Entity) call +* For nested db_session retry option should be ignored +* Fix py_json_unwrap +* Other minor fixes + + +# Pony ORM Release 0.7.6 (2018-08-10) + +## Bugfixes + +* Fixed a bug with hybrid properties that use external functions + + +# Pony ORM Release 0.7.6rc1 (2018-08-08) + +## New features + +* f-strings support in queries: `select(f'{s.name} - {s.age}' for s in Student)` +* #344: It is now possible to specify offset without limit: `query.limit(offset=10)` +* #371: Support of explicit casting of JSON expressions to `str`, `int` or `float` +* `@db.on_connect` decorator added + +## Bugfixes + +* Fix bulk delete bug introduced in 0.7.4 +* #370 Fix memory leak introduced in 0.7.4 +* Now `exists()` in query does not throw away condition in generator expression: `exists(s.gpa > 3 for s in Student)` +* #373: 0.7.4/0.7.5 breaks queries using the `in` operator to test membership of another query result +* #374: `auto=True` can be used with all PrimaryKey types, not only `int` +* #369: Make QueryResult look like a list object again: add concatenation with lists, `.shuffle()` and `.to_list()` methods +* #355: Fix binary primary keys `PrimaryKey(buffer)` in Python2 +* Interactive mode support for PyCharm console +* Fix wrong table aliases in complex queries +* Fix query optimization code for complex queries + + +# Pony ORM Release 0.7.5 (2018-07-24) + +## Bugfixes + +* `query.where` and `query.filter` method bug
introduced in 0.7.4 was fixed + + +# Pony ORM Release 0.7.4 (2018-07-23) + +## Major features + +* Hybrid methods and properties added: https://docs.ponyorm.com/entities.html#hybrid-methods-and-properties +* Allow to base queries on another queries: `select(x.a for x in prev_query if x.b)` +* Added support of Python 3.7 +* Added support of PyPy +* `group_concat()` aggregate function added +* pony.flask subpackage added for integration with Flask + +## Other features + +* `distinct` option added to aggregate functions +* Support of explicit casting to `float` and `bool` in queries + +## Improvements + +* Apply @cut_traceback decorator only when pony.MODE is 'INTERACTIVE' + +## Bugfixes + +* In SQLite3 `LIKE` is case sensitive now +* #249: Fix incorrect mixin used for Timedelta +* #251: correct dealing with qualified table names +* #301: Fix aggregation over JSON Column +* #306: Support of frozenset constants added +* #308: Fixed an error when assigning JSON attribute value to the same attribute: obj.json_attr = obj.json_attr +* #313: Fix missed retry on exception raised during db_session.__exit__ +* #314: Fix AttributeError: 'NoneType' object has no attribute 'seeds' +* #315: Fix attribute lifting for JSON attributes +* #321: Fix KeyError on obj.delete() +* #325: duplicating percentage sign in raw SQL queries without parameters +* #331: Overriding __len__ in entity fails +* #336: entity declaration serialization +* #357: reconnect after PostgreSQL server closed the connection unexpectedly +* Fix Python implementation of between() function and rename arguments: between(a, x, y) -> between(x, a, b) +* Fix retry handling: in PostgreSQL and Oracle an error can be raised during commit +* Fix optimistic update checks for composite foreign keys +* Don't raise OptimisticCheckError if db_session is not optimistic +* Handling incorrect datetime values in MySQL +* Improved ImportError exception messages when MySQLdb, pymysql, psycopg2 or psycopg2cffi driver was not found +* desc() function fixed to allow reverse its effect by calling desc(desc(x)) +* __contains__ method should check if objects belong to the same db_session +* Fix pony.MODE detection; mod_wsgi detection according to official doc +* A lot of inner fixes + + +# Pony ORM Release 0.7.3 (2017-10-23) + +## New features + +* `where()` method added to query +* `coalesce()` function added +* `between(x, a, b)` function added +* #295: Add `_table_options_` for entity class to specify engine, tablespace, etc. 
+* Make debug flag thread-local +* `sql_debugging` context manager added +* `sql_debug` and show_values arguments to db_session added +* `set_sql_debug` function added as alias to (to be deprecated) `sql_debug` function +* Allow `db_session` to accept `ddl` parameter when used as context manager +* Add `optimistic=True` option to db_session +* Skip optimistic checks for queries in `db_session` with `serializable=True` +* `fk_name` option added for attributes in order to specify foreign key name +* #280: Now it's possible to specify `timeout` option, as well as pass other keyword arguments for `sqlite3.connect` function +* Add support of explicit casting to int in queries using `int()` function +* Added modulo division % native support in queries + +## Bugfixes + +* Fix bugs with composite table names +* Fix invalid foreign key & index names for tables which names include schema name +* For queries like `select(x for x in MyObject if not x.description)` add "OR x.info IS NULL" for nullable string columns +* Add optimistic checking for `delete()` method +* Show updated attributes when `OptimisticCheckError` is being raised +* Fix incorrect aliases in nested queries +* Correctly pass exception from user-defined functions in SQLite +* More clear error messages for `UnrepeatableReadError` +* Fix `db_session(strict=True)` which was broken in 2d3afb24 +* Fixes #170: Problem with a primary key column used as a part of another key +* Fixes #223: incorrect result of `getattr(entity, attrname)` when the same lambda applies to different entities +* Fixes #266: Add handler to `"pony.orm"` logger does not work +* Fixes #278: Cascade delete error: FOREIGN KEY constraint failed, with complex entity relationships +* Fixes #283: Lost Json update immediately after object creation +* Fixes #284: `query.order_by()` orders Json numbers like strings +* Fixes #288: Expression text parsing issue in Python 3 +* Fixes #293: translation of if-expressions in expression +* Fixes #294: Real stack traces swallowed within IPython shell +* `Collection.count()` method should check if session is alive +* Set `obj._session_cache_` to None after exiting from db session for better garbage collection +* Unload collections which are not fully loaded after exiting from db session for better garbage collection +* Raise on unknown options for attributes that are part of relationship + + +# Pony ORM Release 0.7.2 (2017-07-17) + +## New features + +* All arguments of db.bind() can be specified as keyword arguments. Previously Pony required the first positional argument which specified the database provider. Now you can pass all the database parameters using the dict: db.bind(**db_params). See https://docs.ponyorm.com/api_reference.html#Database.bind +* The `optimistic` attribute option is added https://docs.ponyorm.com/api_reference.html#cmdoption-arg-optimistic + +## Bugfixes + +* Fixes #219: when a database driver raises an error, sometimes this error was masked by the 'RollbackException: InterfaceError: connection already closed' exception. This happened because on error, Pony tried to rollback transaction, but the connection to the database was already closed and it masked the initial error. Now Pony displays the original error which helps to understand the cause of the problem. +* Fixes #276: Memory leak +* Fixes the __all__ declaration. Previously IDEs, such as PyCharm, could not understand what is going to be imported by 'from pony.orm import *'. Now it works fine. 
+* Fixes #232: negate check for numeric expressions now checks if value is zero or NULL +* Fixes #238, fixes #133: raise TransactionIntegrityError exception instead of AssertionError if obj.collection.create(**kwargs) creates a duplicate object +* Fixes #221: issue with unicode json path keys +* Fixes bug when discriminator column is used as a part of a primary key +* Handle situation when SQLite blob column contains non-binary value + + +# Pony ORM Release 0.7.1 (2017-01-10) + +## New features + +* New warning DatabaseContainsIncorrectEmptyValue added, it is raised when the required attribute is empty during loading an entity from the database + +## Bugfixes + +* Fixes #216: Added Python 3.6 support +* Fixes #203: subtranslator should use argnames from parent translator +* Change a way aliases in SQL query are generated in order to fix a problem when a subquery alias masks a base query alias +* Volatile attribute bug fixed +* Fix creation of self-referenced foreign keys - before this Pony didn't create the foreign key for self-referenced attributes +* Bug fixed: when required attribute is empty the loading from the database shouldn't raise the validation error. Now Pony raises the warning DatabaseContainsIncorrectEmptyValue +* Throw an error with more clear explanation when a list comprehension is used inside a query instead of a generator expression: "Use generator expression (... for ... in ...) instead of list comprehension [... for ... in ...] inside query" + + +# Pony ORM Release 0.7 (2016-10-11) + +Starting with this release Pony ORM is release under the Apache License, Version 2.0. + +## New features + +* Added getattr() support in queries: https://docs.ponyorm.com/api_reference.html#getattr + +## Backward incompatible changes + +* #159: exceptions happened during flush() should not be wrapped with CommitException + +Before this release an exception that happened in a hook(https://docs.ponyorm.com/api_reference.html#entity-hooks), could be raised in two ways - either wrapped into the CommitException or without wrapping. It depended if the exception happened during the execution of flush() or commit() function on the db_session exit. Now the exception happened inside the hook never will be wrapped into the CommitException. + +## Bugfixes + +* #190: Timedelta is not supported when using pymysql + + +# Pony ORM Release 0.6.6 (2016-08-22) + +## New features + +* Added native JSON data type support in all supported databases: https://docs.ponyorm.com/json.html + +## Backward incompatible changes + +* Dropped Python 2.6 support + +## Improvements + +* #179 Added the compatibility with PYPY using psycopg2cffi +* Added an experimental @db_session `strict` parameter: https://docs.ponyorm.com/transactions.html#strict + +## Bugfixes + +* #182 - LEFT JOIN doesn't work as expected for inherited entities when foreign key is None +* Some small bugs were fixed + + # Pony ORM Release 0.6.5 (2016-04-04) ## Improvements diff --git a/LICENSE b/LICENSE index 2def0e883..9a2a1c6bf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,661 +1,179 @@ - GNU AFFERO GENERAL PUBLIC LICENSE - Version 3, 19 November 2007 - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The GNU Affero General Public License is a free, copyleft license for -software and other kinds of works, specifically designed to ensure -cooperation with the community in the case of network server software. 
- - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -our General Public Licenses are intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. - - Developers that use our General Public Licenses protect your rights -with two steps: (1) assert copyright on the software, and (2) offer -you this License which gives you legal permission to copy, distribute -and/or modify the software. - - A secondary benefit of defending all users' freedom is that -improvements made in alternate versions of the program, if they -receive widespread use, become available for other developers to -incorporate. Many developers of free software are heartened and -encouraged by the resulting cooperation. However, in the case of -software used on network servers, this result may fail to come about. -The GNU General Public License permits making a modified version and -letting the public access it on a server without ever releasing its -source code to the public. - - The GNU Affero General Public License is designed specifically to -ensure that, in such cases, the modified source code becomes available -to the community. It requires the operator of a network server to -provide the source code of the modified version running there to the -users of that server. Therefore, public use of a modified version, on -a publicly accessible server, gives the public access to the source -code of the modified version. - - An older license, called the Affero General Public License and -published by Affero, was designed to accomplish similar goals. This is -a different license, not a version of the Affero GPL, but Affero has -released a new version of the Affero GPL which permits relicensing under -this license. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU Affero General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. 
Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. - - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. 
This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. - - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. - - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. 
- - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. - - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. - - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. 
- - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. - - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. 
(Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. - - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. - - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. 
- - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. - - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. 
- - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. - - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. 
If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Remote Network Interaction; Use with the GNU General Public License. - - Notwithstanding any other provision of this License, if you modify the -Program, your modified version must prominently offer all users -interacting with it remotely through a computer network (if your version -supports such interaction) an opportunity to receive the Corresponding -Source of your version by providing access to the Corresponding Source -from a network server at no charge, through some standard or customary -means of facilitating copying of software. This Corresponding Source -shall include the Corresponding Source for any work covered by version 3 -of the GNU General Public License that is incorporated pursuant to the -following paragraph. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the work with which it is combined will remain governed by version -3 of the GNU General Public License. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU Affero General Public License from time to time. Such new versions -will be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU Affero General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU Affero General Public License, you may choose any version ever published -by the Free Software Foundation. - - If the Program specifies that a proxy can decide which future -versions of the GNU Affero General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 
- - 16. Limitation of Liability. - - IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS -THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY -GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE -USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF -DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD -PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), -EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF -SUCH DAMAGES. - - 17. Interpretation of Sections 15 and 16. - - If the disclaimer of warranty and limitation of liability provided -above cannot be given local legal effect according to their terms, -reviewing courts shall apply local law that most closely approximates -an absolute waiver of all civil liability in connection with the -Program, unless a warranty or assumption of liability accompanies a -copy of the Program in return for a fee. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -state the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - - Copyright (C) - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . - -Also add information on how to contact you by electronic and paper mail. - - If your software can interact with users remotely through a computer -network, you should also make sure that it provides a way for users to -get its source. For example, if your program is a web application, its -interface could display a "Source" link that leads users to an archive -of the code. There are many ways you could offer source, and different -solutions will be better for different programs; see section 13 for the -specific requirements. - - You should also get your employer (if you work as a programmer) or school, -if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU AGPL, see -. \ No newline at end of file + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2016 Alexander Kozlovsky, Alexey Malashkevich diff --git a/MANIFEST.in b/MANIFEST.in index 1bf3c80a3..e05b34645 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ -include pony/orm/tests/queries.txt \ No newline at end of file +include pony/orm/tests/queries.txt +include pony/flask/example/templates *.html +include LICENSE diff --git a/README.md b/README.md index e7fd3ee59..c0f3e9eef 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,91 @@ +# Downloads +[![Downloads](https://pepy.tech/badge/pony)](https://pepy.tech/project/pony) [![Downloads](https://pepy.tech/badge/pony/month)](https://pepy.tech/project/pony/month) [![Downloads](https://pepy.tech/badge/pony/week)](https://pepy.tech/project/pony/week) + +# Tests + +#### PostgreSQL +Python 2 + + +Python 3 + + + +#### SQLite +Python 2 + + +Python 3 + + + +#### CockroachDB +Python 2 + + +Python 3 + + + + Pony Object-Relational Mapper -================================== +============================= -Pony is an object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using generator expressions. Pony works with entities which are mapped to a SQL database. 
Using generator syntax for writing queries allows the user to formulate very eloquent queries. It increases the level of abstraction and allows a programmer to concentrate on the business logic of the application. For this purpose Pony analyzes the abstract syntax tree of a generator and translates it to its SQL equivalent. +Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions and lambdas. Pony analyzes the abstract syntax tree of the expression and translates it into a SQL query. -Following is an example of a query in Pony: +Here is an example query in Pony: select(p for p in Product if p.name.startswith('A') and p.cost <= 1000) Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. -Pony ORM also include the ER Diagram Editor which is a great tool for prototyping. You can create your ER diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. +By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps to save resources. Pony achieves this ease of use through the following: + +* Compact entity definitions +* The concise query language +* Ability to work with Pony interactively in a Python interpreter +* Comprehensive error messages, showing the exact part where an error occurred in the query +* Display of the generated SQL in a readable format with indentation + +All this helps the developer to focus on implementing the business logic of an application, instead of struggling with a mapper trying to understand how to get the data from the database. + +See the example [here](https://github.com/ponyorm/pony/blob/orm/pony/orm/examples/estore.py) + -The package pony.orm.examples contains several examples. +Support Pony ORM Development +---------------------------- -Documenation is available at [https://docs.ponyorm.com](https://docs.ponyorm.com) -The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc), it is released under Apache 2.0 license. -Please create new documentation related issues [https://github.com/ponyorm/pony-doc/issues](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. +Pony ORM is an Apache 2.0 licensed open source project. If you would like to support Pony ORM development, please consider: + +[Become a backer or sponsor](https://ponyorm.org/donation.html) + + +Online tool for database design +------------------------------- + +Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool for prototyping. You can create your database diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. + + +Documentation +------------- + +Documentation is available at [https://docs.ponyorm.org](https://docs.ponyorm.org) +The documentation source is available at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc). +Please create new documentation related issues [here](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements.
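To make the README text above concrete, here is a minimal, self-contained sketch of the entity-definition and query style it describes. The `Product` entity, its attributes, the sample data, and the in-memory SQLite binding are illustrative assumptions for this sketch only; they are not part of this changeset.

```python
from decimal import Decimal
from pony.orm import Database, Required, db_session, select, set_sql_debug

# Bind to an in-memory SQLite database (an assumption made just for this sketch).
db = Database('sqlite', ':memory:')

class Product(db.Entity):   # a compact entity definition
    name = Required(str)
    cost = Required(Decimal)

db.generate_mapping(create_tables=True)
set_sql_debug(True)         # print the generated SQL via the set_sql_debug() API

with db_session:
    Product(name='Apple', cost=Decimal('0.50'))
    Product(name='Avocado', cost=Decimal('2.00'))
    # The generator expression below is decompiled into an AST and translated to SQL.
    cheap = select(p for p in Product if p.name.startswith('A') and p.cost <= 1000)[:]
    print(cheap)
```

This is the same `select(p for p in Product ...)` query quoted in the README; the only additions are the database binding and the sample data needed to make it runnable.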
-We are looking forward to your comments and suggestions at our mailing list [http://ponyorm-list.ponyorm.com](http://ponyorm-list.ponyorm.com) License ------------ +------- + +Pony ORM is released under the Apache 2.0 license. + -Pony ORM is released under multiple licenses, check [ponyorm.com](https://ponyorm.com/license-and-pricing.html) for more information. +PonyORM community +----------------- -Copyright (c) 2016 Pony ORM, LLC. All rights reserved. -team (at) ponyorm.com +Please post your questions on [Stack Overflow](http://stackoverflow.com/questions/tagged/ponyorm). +Meet the PonyORM team, chat with the community members, and get your questions answered on our community [Telegram group](https://t.me/ponyorm). +Join our newsletter at [ponyorm.org](https://ponyorm.org). +Reach us on [Twitter](https://twitter.com/ponyorm). -Please send your questions, comments and suggestions to our mailing list [http://ponyorm-list.ponyorm.com](http://ponyorm-list.ponyorm.com) +Copyright (c) 2013-2019 Pony ORM. All rights reserved. info (at) ponyorm.org diff --git a/models.py b/models.py new file mode 100644 index 000000000..abdb3be5a --- /dev/null +++ b/models.py @@ -0,0 +1,21 @@ +from datetime import datetime +from uuid import UUID, uuid4 +from pony.orm import * + +db = Database() + +class User(db.Entity): + id = PrimaryKey(UUID, default=uuid4) + created = Required(datetime, default=lambda: datetime.now()) + + + + +db.bind('mssql', driver='ODBC Driver 17 for SQL Server', server='10.24.219.31', database='testing', username="tester2", password="tester123!") +db.generate_mapping(create_tables=True) + +if __name__ == '__main__': + with db_session: + usr = User() + users = select(u for u in User)[:] + print(users) diff --git a/pony/__init__.py b/pony/__init__.py index 948b1ddef..ac2d859ae 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -1,40 +1,50 @@ from __future__ import absolute_import, print_function -import sys +import os, sys from os.path import dirname -__version__ = '0.6.6-dev' +__version__ = '0.7.14-dev' def detect_mode(): try: import google.appengine except ImportError: pass else: - try: import dev_appserver - except ImportError: return 'GAE-SERVER' - return 'GAE-LOCAL' + if os.getenv('SERVER_SOFTWARE', '').startswith('Development'): + return 'GAE-LOCAL' + return 'GAE-SERVER' - try: mod_wsgi = sys.modules['mod_wsgi'] - except KeyError: pass + try: from mod_wsgi import version + except: pass else: return 'MOD_WSGI' - if 'flup.server.fcgi' in sys.modules: return 'FCGI-FLUP' + main = sys.modules['__main__'] - if 'uwsgi' in sys.modules: return 'UWSGI' + if not hasattr(main, '__file__'): # console + return 'INTERACTIVE' + + if os.getenv('IPYTHONENABLE', '') == 'True': + return 'INTERACTIVE' + + if getattr(main, 'INTERACTIVE_MODE_AVAILABLE', False): # pycharm console + return 'INTERACTIVE' - try: sys.modules['__main__'].__file__ - except AttributeError: return 'INTERACTIVE' - return 'CHERRYPY' + if 'flup.server.fcgi' in sys.modules: return 'FCGI-FLUP' + if 'uwsgi' in sys.modules: return 'UWSGI' + if 'flask' in sys.modules: return 'FLASK' + if 'cherrypy' in sys.modules: return 'CHERRYPY' + if 'bottle' in sys.modules: return 'BOTTLE' + return 'UNKNOWN' MODE = detect_mode() MAIN_FILE = None -if MODE in ('CHERRYPY', 'GAE-LOCAL', 'GAE-SERVER', 'FCGI-FLUP'): - MAIN_FILE = sys.modules['__main__'].__file__ -elif MODE == 'MOD_WSGI': +if MODE == 'MOD_WSGI': for module_name, module in sys.modules.items(): if module_name.startswith('_mod_wsgi_'): MAIN_FILE = module.__file__ break +elif MODE != 
'INTERACTIVE': + MAIN_FILE = sys.modules['__main__'].__file__ if MAIN_FILE is not None: MAIN_DIR = dirname(MAIN_FILE) else: MAIN_DIR = None diff --git a/pony/flask/__init__.py b/pony/flask/__init__.py new file mode 100644 index 000000000..75e877cb5 --- /dev/null +++ b/pony/flask/__init__.py @@ -0,0 +1,23 @@ +from pony.orm import db_session +from flask import request + +def _enter_session(): + session = db_session() + request.pony_session = session + session.__enter__() + +def _exit_session(exception): + session = getattr(request, 'pony_session', None) + if session is not None: + session.__exit__(exc=exception) + +class Pony(object): + def __init__(self, app=None): + self.app = None + if app is not None: + self.init_app(app) + + def init_app(self, app): + self.app = app + self.app.before_request(_enter_session) + self.app.teardown_request(_exit_session) \ No newline at end of file diff --git a/pony/flask/example/__init__.py b/pony/flask/example/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pony/flask/example/__main__.py b/pony/flask/example/__main__.py new file mode 100644 index 000000000..93d11b071 --- /dev/null +++ b/pony/flask/example/__main__.py @@ -0,0 +1,7 @@ +from .views import * +from .app import app + +if __name__ == '__main__': + db.bind(**app.config['PONY']) + db.generate_mapping(create_tables=True) + app.run() \ No newline at end of file diff --git a/pony/flask/example/app.py b/pony/flask/example/app.py new file mode 100644 index 000000000..51d33ad90 --- /dev/null +++ b/pony/flask/example/app.py @@ -0,0 +1,16 @@ +from flask import Flask +from flask_login import LoginManager +from pony.flask import Pony +from .config import config +from .models import db + +app = Flask(__name__) +app.config.update(config) + +Pony(app) +login_manager = LoginManager(app) +login_manager.login_view = 'login' + +@login_manager.user_loader +def load_user(user_id): + return db.User.get(id=user_id) diff --git a/pony/flask/example/config.py b/pony/flask/example/config.py new file mode 100644 index 000000000..0c33fe839 --- /dev/null +++ b/pony/flask/example/config.py @@ -0,0 +1,9 @@ +config = dict( + DEBUG = False, + SECRET_KEY = 'secret_xxx', + PONY = { + 'provider': 'sqlite', + 'filename': 'db.db3', + 'create_db': True + } +) \ No newline at end of file diff --git a/pony/flask/example/models.py b/pony/flask/example/models.py new file mode 100644 index 000000000..f97d56425 --- /dev/null +++ b/pony/flask/example/models.py @@ -0,0 +1,10 @@ +from pony.orm import Database, Required, Optional +from flask_login import UserMixin +from datetime import datetime + +db = Database() + +class User(db.Entity, UserMixin): + login = Required(str, unique=True) + password = Required(str) + last_login = Optional(datetime) \ No newline at end of file diff --git a/pony/flask/example/templates/index.html b/pony/flask/example/templates/index.html new file mode 100644 index 000000000..a40948d31 --- /dev/null +++ b/pony/flask/example/templates/index.html @@ -0,0 +1,36 @@ + + + Hello! + + + + +
+ {% with messages = get_flashed_messages() %} + {% if messages %} + {% for message in messages %} + + {% endfor %} + {% endif %} + {% endwith %} + {% if not current_user.is_authenticated %} +

Hi, please log in or register


+ {% else %} +

Hi, {{ current_user.login }}. Your last login: {{ current_user.last_login.strftime('%Y-%m-%d') }}

+ Logout +

List of users

+
    + {% for user in users %} +
  • + {% if user == current_user %} + {{ user.login }} + {% else %} + {{ user.login }} + {% endif %} +
  • + {% endfor %} +
+ {% endif %} +
+ + \ No newline at end of file diff --git a/pony/flask/example/templates/login.html b/pony/flask/example/templates/login.html new file mode 100644 index 000000000..562525904 --- /dev/null +++ b/pony/flask/example/templates/login.html @@ -0,0 +1,30 @@ + + + Login page + + + + +
+ {% with messages = get_flashed_messages() %} + {% if messages %} + {% for message in messages %} + + {% endfor %} + {% endif %} + {% endwith %} +

Please login

+
+
+ + + +
+ {% if error %} +

Error: {{ error }} + {% endif %} +

+ + \ No newline at end of file diff --git a/pony/flask/example/templates/reg.html b/pony/flask/example/templates/reg.html new file mode 100644 index 000000000..ae9a27d91 --- /dev/null +++ b/pony/flask/example/templates/reg.html @@ -0,0 +1,30 @@ + + + Login page + + + + +
+ {% with messages = get_flashed_messages() %} + {% if messages %} + {% for message in messages %} + + {% endfor %} + {% endif %} + {% endwith %} +

Register

+
+
+ + + +
+ {% if error %} +

Error: {{ error }} + {% endif %} +

+ + \ No newline at end of file diff --git a/pony/flask/example/views.py b/pony/flask/example/views.py new file mode 100644 index 000000000..a8477a025 --- /dev/null +++ b/pony/flask/example/views.py @@ -0,0 +1,56 @@ +from .app import app +from .models import db +from flask import render_template, request, flash, redirect, abort +from flask_login import current_user, logout_user, login_user, login_required +from datetime import datetime +from pony.orm import flush + +@app.route('/') +def index(): + users = db.User.select() + return render_template('index.html', user=current_user, users=users) + +@app.route('/login', methods=['GET', 'POST']) +def login(): + if request.method == 'POST': + username = request.form['username'] + password = request.form['password'] + possible_user = db.User.get(login=username) + if not possible_user: + flash('Wrong username') + return redirect('/login') + if possible_user.password == password: + possible_user.last_login = datetime.now() + login_user(possible_user) + return redirect('/') + + flash('Wrong password') + return redirect('/login') + else: + return render_template('login.html') + +@app.route('/reg', methods=['GET', 'POST']) +def reg(): + if request.method == 'POST': + username = request.form['username'] + password = request.form['password'] + exist = db.User.get(login=username) + if exist: + flash('Username %s is already taken, choose another one' % username) + return redirect('/reg') + + user = db.User(login=username, password=password) + user.last_login = datetime.now() + flush() + login_user(user) + flash('Successfully registered') + return redirect('/') + else: + return render_template('reg.html') + +@app.route('/logout') +@login_required +def logout(): + logout_user() + flash('Logged out') + return redirect('/') \ No newline at end of file diff --git a/pony/options.py b/pony/options.py index 8e26fcad6..6c31ab487 100644 --- a/pony/options.py +++ b/pony/options.py @@ -59,7 +59,6 @@ CONSOLE_ENCODING = None # db options -PREFETCHING = True MAX_FETCH_COUNT = None # used for select(...).show() diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index f989d18c5..6c9068b60 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -1,39 +1,47 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import basestring, iteritems from functools import update_wrapper from pony.thirdparty.compiler import ast -from pony.utils import throw +from pony.utils import HashableDict, throw, copy_ast class TranslationError(Exception): pass +pre_method_caches = {} +post_method_caches = {} + class ASTTranslator(object): def __init__(translator, tree): translator.tree = tree - translator.pre_methods = {} - translator.post_methods = {} + translator_cls = translator.__class__ + pre_method_caches.setdefault(translator_cls, {}) + post_method_caches.setdefault(translator_cls, {}) def dispatch(translator, node): - cls = node.__class__ + translator_cls = translator.__class__ + pre_methods = pre_method_caches[translator_cls] + post_methods = post_method_caches[translator_cls] + node_cls = node.__class__ - try: pre_method = translator.pre_methods[cls] + try: pre_method = pre_methods[node_cls] except KeyError: - pre_method = getattr(translator, 'pre' + cls.__name__, translator.default_pre) - translator.pre_methods[cls] = pre_method - stop = translator.call(pre_method, node) + pre_method = getattr(translator_cls, 'pre' + node_cls.__name__, translator_cls.default_pre) + pre_methods[node_cls] = pre_method + stop = 
translator.call(pre_method, node) if stop: return for child in node.getChildNodes(): translator.dispatch(child) - try: post_method = translator.post_methods[cls] + try: post_method = post_methods[node_cls] except KeyError: - post_method = getattr(translator, 'post' + cls.__name__, translator.default_post) - translator.post_methods[cls] = post_method + post_method = getattr(translator_cls, 'post' + node_cls.__name__, translator_cls.default_post) + post_methods[node_cls] = post_method translator.call(post_method, node) def call(translator, method, node): - return method(node) + return method(translator, node) def default_pre(translator, node): pass def default_post(translator, node): @@ -53,15 +61,22 @@ def binop_src(op, node): return op.join((node.left.src, node.right.src)) def ast2src(tree): + src = getattr(tree, 'src', None) + if src is not None: + return src PythonTranslator(tree) return tree.src class PythonTranslator(ASTTranslator): def __init__(translator, tree): ASTTranslator.__init__(translator, tree) + translator.top_level_f_str = None translator.dispatch(tree) def call(translator, method, node): - node.src = method(node) + node.src = method(translator, node) + def default_pre(translator, node): + if getattr(node, 'src', None) is not None: + return True # node.src is already calculated, stop dispatching def default_post(translator, node): throw(NotImplementedError, node) def postGenExpr(translator, node): @@ -76,6 +91,20 @@ def postGenExprFor(translator, node): return src def postGenExprIf(translator, node): return 'if %s' % node.test.src + def postIfExp(translator, node): + return '%s if %s else %s' % (node.then.src, node.test.src, node.else_.src) + def postLambda(translator, node): + argnames = list(node.argnames) + kwargs_name = argnames.pop() if node.kwargs else None + varargs_name = argnames.pop() if node.varargs else None + def_argnames = argnames[-len(node.defaults):] if node.defaults else [] + nodef_argnames = argnames[:-len(node.defaults)] if node.defaults else argnames + args = ', '.join(nodef_argnames) + d_args = ', '.join('%s=%s' % (argname, default.src) for argname, default in zip(def_argnames, node.defaults)) + v_arg = '*%s' % varargs_name if varargs_name else None + kw_arg = '**%s' % kwargs_name if kwargs_name else None + args = ', '.join(x for x in [args, d_args, v_arg, kw_arg] if x) + return 'lambda %s: %s' % (args, node.code.src) @priority(14) def postOr(translator, node): return ' or '.join(expr.src for expr in node.nodes) @@ -171,6 +200,8 @@ def postConst(translator, node): s = str(value) if float(s) == value: return s return repr(value) + def postEllipsis(translator, node): + return '...' 
def postList(translator, node): node.priority = 1 return '[%s]' % ', '.join(item.src for item in node.nodes) @@ -199,20 +230,48 @@ def postAssName(translator, node): return node.name def postKeyword(translator, node): return '='.join((node.name, node.expr.src)) + def preStr(self, node): + if self.top_level_f_str is None: + self.top_level_f_str = node + def postStr(self, node): + if self.top_level_f_str is node: + self.top_level_f_str = None + return "f%r" % ('{%s}' % node.value.src) + return '{%s}' % node.value.src + def preJoinedStr(self, node): + if self.top_level_f_str is None: + self.top_level_f_str = node + def postJoinedStr(self, node): + result = ''.join( + value.value if isinstance(value, ast.Const) else value.src + for value in node.values) + if self.top_level_f_str is node: + self.top_level_f_str = None + return "f%r" % result + return result + def preFormattedValue(self, node): + if self.top_level_f_str is None: + self.top_level_f_str = node + def postFormattedValue(self, node): + res = '{%s:%s}' % (node.value.src, node.fmt_spec.src) + if self.top_level_f_str is node: + self.top_level_f_str = None + return "f%r" % res + return res nonexternalizable_types = (ast.Keyword, ast.Sliceobj, ast.List, ast.Tuple) class PreTranslator(ASTTranslator): def __init__(translator, tree, globals, locals, - special_functions, const_functions, additional_internal_names=()): + special_functions, const_functions, outer_names=()): ASTTranslator.__init__(translator, tree) translator.globals = globals translator.locals = locals translator.special_functions = special_functions translator.const_functions = const_functions translator.contexts = [] - if additional_internal_names: - translator.contexts.append(additional_internal_names) + if outer_names: + translator.contexts.append(outer_names) translator.externals = externals = set() translator.dispatch(tree) for node in externals.copy(): @@ -224,13 +283,13 @@ def __init__(translator, tree, globals, locals, def dispatch(translator, node): node.external = node.constant = None ASTTranslator.dispatch(translator, node) - childs = node.getChildNodes() - if node.external is None and childs and all( - getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in childs): + children = node.getChildNodes() + if node.external is None and children and all( + getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in children): node.external = True if node.external and not node.constant: externals = translator.externals - externals.difference_update(childs) + externals.difference_update(children) externals.add(node) def preGenExprInner(translator, node): translator.contexts.append(set()) @@ -260,6 +319,10 @@ def postName(translator, node): node.external = True def postConst(translator, node): node.external = node.constant = True + def postDict(translator, node): + node.external = True + def postList(translator, node): + node.external = True def postKeyword(translator, node): node.constant = node.expr.constant def postCallFunc(translator, node): @@ -274,32 +337,37 @@ def postCallFunc(translator, node): expr = '.'.join(reversed(attrs)) x = eval(expr, translator.globals, translator.locals) try: hash(x) - except TypeError: x = None - if x in translator.special_functions: - if x.__name__ == 'raw_sql': node.raw_sql = True - else: node.external = False - elif x in translator.const_functions: - for arg in node.args: - if not arg.constant: return - if node.star_args is not None and not node.star_args.constant: return - if 
node.dstar_args is not None and not node.dstar_args.constant: return - node.constant = True + except TypeError: pass + else: + if x in translator.special_functions: + if x.__name__ == 'raw_sql': node.raw_sql = True + elif x is getattr: + attr_node = node.args[1] + attr_node.parent_node = node + else: node.external = False + elif x in translator.const_functions: + for arg in node.args: + if not arg.constant: return + if node.star_args is not None and not node.star_args.constant: return + if node.dstar_args is not None and not node.dstar_args.constant: return + node.constant = True extractors_cache = {} -def create_extractors(code_key, tree, filter_num, globals, locals, - special_functions, const_functions, additional_internal_names=()): - cache_key = code_key, filter_num - result = extractors_cache.get(cache_key) - if result is None: - pretranslator = PreTranslator( - tree, globals, locals, special_functions, const_functions, additional_internal_names) +def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, outer_names=()): + result = extractors_cache.get(code_key) + if not result: + pretranslator = PreTranslator(tree, globals, locals, special_functions, const_functions, outer_names) extractors = {} for node in pretranslator.externals: src = node.src = ast2src(node) - if src == '.0': code = None - else: code = compile(src, src, 'eval') - extractors[filter_num, src] = code - varnames = list(sorted(extractors)) - result = extractors_cache[cache_key] = extractors, varnames, tree + if src == '.0': + def extractor(globals, locals): + return locals['.0'] + else: + code = compile(src, src, 'eval') + def extractor(globals, locals, code=code): + return eval(code, globals, locals) + extractors[src] = extractor + result = extractors_cache[code_key] = tree, extractors return result diff --git a/pony/orm/core.py b/pony/orm/core.py index 326563b6e..fc54cb498 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1,8 +1,8 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \ - basestring, unicode, buffer, int_types, builtins, pickle, with_metaclass + basestring, unicode, buffer, int_types, builtins, with_metaclass -import json, re, sys, types, datetime, logging, itertools +import json, re, sys, types, datetime, logging, itertools, warnings, inspect from operator import attrgetter, itemgetter from itertools import chain, starmap, repeat from time import time @@ -13,75 +13,97 @@ from collections import defaultdict from hashlib import md5 from inspect import isgeneratorfunction +from functools import wraps from pony.thirdparty.compiler import ast, parse import pony from pony import options from pony.orm.decompiling import decompile -from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of +from pony.orm.ormtypes import ( + LongStr, LongUnicode, numeric_types, raw_sql, RawSQL, normalize, Json, TrackedValue, QueryType, + Array, IntArray, StrArray, FloatArray + ) from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, OperationalError, IntegrityError, InternalError, ProgrammingError, NotSupportedError ) from pony import utils -from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \ - deprecated, import_module, 
parse_expr, is_ident, tostring, strjoin, concat - -__all__ = ''' - pony +from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ + get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ + between, concat, coalesce, HashableDict, deref_proxy, deduplicate - DBException RowNotFound MultipleRowsFound TooManyRowsFound +__all__ = [ + 'pony', - Warning Error InterfaceError DatabaseError DataError OperationalError - IntegrityError InternalError ProgrammingError NotSupportedError + 'DBException', 'RowNotFound', 'MultipleRowsFound', 'TooManyRowsFound', - OrmError ERDiagramError DBSchemaError MappingError - TableDoesNotExist TableIsNotEmpty ConstraintError CacheIndexError PermissionError - ObjectNotFound MultipleObjectsFoundError TooManyObjectsFoundError OperationWithDeletedObjectError - TransactionError ConnectionClosedError TransactionIntegrityError IsolationError CommitException RollbackException - UnrepeatableReadError OptimisticCheckError UnresolvableCyclicDependency UnexpectedError DatabaseSessionIsOver + 'Warning', 'Error', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', + 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - TranslationError ExprEvalError + 'OrmError', 'ERDiagramError', 'DBSchemaError', 'MappingError', 'BindingError', + 'TableDoesNotExist', 'TableIsNotEmpty', 'ConstraintError', 'CacheIndexError', + 'ObjectNotFound', 'MultipleObjectsFoundError', 'TooManyObjectsFoundError', 'OperationWithDeletedObjectError', + 'TransactionError', 'ConnectionClosedError', 'TransactionIntegrityError', 'IsolationError', + 'CommitException', 'RollbackException', 'UnrepeatableReadError', 'OptimisticCheckError', + 'UnresolvableCyclicDependency', 'UnexpectedError', 'DatabaseSessionIsOver', + 'PonyRuntimeWarning', 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', + 'TranslationError', 'ExprEvalError', 'PermissionError', - RowNotFound MultipleRowsFound TooManyRowsFound + 'Database', 'sql_debug', 'set_sql_debug', 'sql_debugging', 'show', - Database sql_debug show + 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator', + 'composite_key', 'composite_index', + 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', 'make_proxy', - PrimaryKey Required Optional Set Discriminator - composite_key composite_index - flush commit rollback db_session with_transaction + 'LongStr', 'LongUnicode', 'Json', 'IntArray', 'StrArray', 'FloatArray', - LongStr LongUnicode + 'select', 'left_join', 'get', 'exists', 'delete', - select left_join get exists delete + 'count', 'sum', 'min', 'max', 'avg', 'group_concat', 'distinct', - count sum min max avg distinct + 'JOIN', 'desc', 'between', 'concat', 'coalesce', 'raw_sql', - JOIN desc concat raw_sql + 'buffer', 'unicode', - buffer unicode + 'get_current_user', 'set_current_user', 'perm', 'has_perm', + 'get_user_groups', 'get_user_roles', 'get_object_labels', + 'user_groups_getter', 'user_roles_getter', 'obj_labels_getter' +] - get_current_user set_current_user perm has_perm - get_user_groups get_user_roles get_object_labels - user_groups_getter user_roles_getter obj_labels_getter - '''.split() - -debug = False suppress_debug_change = False def sql_debug(value): - global debug - if not suppress_debug_change: debug = value + # todo: make sql_debug deprecated + if not suppress_debug_change: + local.debug = value + + +def set_sql_debug(debug=True, show_values=None): + if not 
suppress_debug_change: + local.debug = debug + local.show_values = show_values + orm_logger = logging.getLogger('pony.orm') sql_logger = logging.getLogger('pony.orm.sql') orm_log_level = logging.INFO +def has_handlers(logger): + if not PY2: + return logger.hasHandlers() + while logger: + if logger.handlers: + return True + elif not logger.propagate: + return False + logger = logger.parent + return False + def log_orm(msg): - if logging.root.handlers: + if has_handlers(orm_logger): orm_logger.log(orm_log_level, msg) else: print(msg) @@ -89,15 +111,18 @@ def log_orm(msg): def log_sql(sql, arguments=None): if type(arguments) is list: sql = 'EXECUTEMANY (%d)\n%s' % (len(arguments), sql) - if logging.root.handlers: - sql_logger.log(orm_log_level, sql) # arguments can hold sensitive information + if has_handlers(sql_logger): + if local.show_values and arguments: + sql = '%s\n%s' % (sql, format_arguments(arguments)) + sql_logger.log(orm_log_level, sql) else: - print(sql) - if not arguments: pass - elif type(arguments) is list: - for args in arguments: print(args2str(args)) - else: print(args2str(arguments)) - print() + if (local.show_values is None or local.show_values) and arguments: + sql = '%s\n%s' % (sql, format_arguments(arguments)) + print(sql, end='\n\n') + +def format_arguments(arguments): + if type(arguments) is not list: return args2str(arguments) + return '\n'.join(args2str(args) for args in arguments) def args2str(args): if isinstance(args, (tuple, list)): @@ -113,6 +138,7 @@ class OrmError(Exception): pass class ERDiagramError(OrmError): pass class DBSchemaError(OrmError): pass class MappingError(OrmError): pass +class BindingError(OrmError): pass class TableDoesNotExist(OrmError): pass class TableIsNotEmpty(OrmError): pass @@ -180,13 +206,30 @@ def __init__(exc, msg, original_exc): class ExprEvalError(TranslationError): def __init__(exc, src, cause): assert isinstance(cause, Exception) - msg = '%s raises %s: %s' % (src, type(cause).__name__, str(cause)) + msg = '`%s` raises %s: %s' % (src, type(cause).__name__, str(cause)) TranslationError.__init__(exc, msg) exc.cause = cause -class OptimizationFailed(Exception): +class PonyInternalException(Exception): + pass + +class OptimizationFailed(PonyInternalException): pass # Internal exception, cannot be encountered in user code +class UseAnotherTranslator(PonyInternalException): + def __init__(self, translator): + Exception.__init__(self, 'This exception should be catched internally by PonyORM') + self.translator = translator + +class PonyRuntimeWarning(RuntimeWarning): + pass + +class DatabaseContainsIncorrectValue(PonyRuntimeWarning): + pass + +class DatabaseContainsIncorrectEmptyValue(DatabaseContainsIncorrectValue): + pass + def adapt_sql(sql, paramstyle): result = adapted_sql_cache.get((sql, paramstyle)) if result is not None: return result @@ -194,6 +237,7 @@ def adapt_sql(sql, paramstyle): result = [] args = [] kwargs = {} + original_sql = sql if paramstyle in ('format', 'pyformat'): sql = sql.replace('%', '%%') while True: try: i = sql.index('$', pos) @@ -229,31 +273,74 @@ def adapt_sql(sql, paramstyle): kwargs[key] = expr result.append('%%(%s)s' % key) else: throw(NotImplementedError) - adapted_sql = ''.join(result) - if args: - source = '(%s,)' % ', '.join(args) - code = compile(source, '', 'eval') - elif kwargs: - source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items()) + if args or kwargs: + adapted_sql = ''.join(result) + if args: source = '(%s,)' % ', '.join(args) + else: source = '{%s}' % ','.join('%r:%s' % 
item for item in kwargs.items()) code = compile(source, '', 'eval') else: + adapted_sql = original_sql.replace('$$', '$') code = compile('None', '', 'eval') - if paramstyle in ('format', 'pyformat'): sql = sql.replace('%%', '%') result = adapted_sql, code adapted_sql_cache[(sql, paramstyle)] = result return result -num_counter = itertools.count() + +class PrefetchContext(object): + def __init__(self, database=None): + self.database = database + self.attrs_to_prefetch_dict = defaultdict(set) + self.entities_to_prefetch = set() + self.relations_to_prefetch_cache = {} + def copy(self): + result = PrefetchContext(self.database) + result.attrs_to_prefetch_dict = self.attrs_to_prefetch_dict.copy() + result.entities_to_prefetch = self.entities_to_prefetch.copy() + return result + def __enter__(self): + assert local.prefetch_context is None + local.prefetch_context = self + def __exit__(self, exc_type, exc_val, exc_tb): + assert local.prefetch_context is self + local.prefetch_context = None + def get_frozen_attrs_to_prefetch(self, entity): + attrs_to_prefetch = self.attrs_to_prefetch_dict.get(entity, ()) + if type(attrs_to_prefetch) is set: + attrs_to_prefetch = frozenset(attrs_to_prefetch) + self.attrs_to_prefetch_dict[entity] = attrs_to_prefetch + return attrs_to_prefetch + def get_relations_to_prefetch(self, entity): + result = self.relations_to_prefetch_cache.get(entity) + if result is None: + attrs_to_prefetch = self.attrs_to_prefetch_dict[entity] + result = tuple(attr for attr in entity._attrs_ + if attr.is_relation and ( + attr in attrs_to_prefetch or + attr.py_type in self.entities_to_prefetch and not attr.is_collection)) + self.relations_to_prefetch_cache[entity] = result + return result + class Local(localbase): def __init__(local): + local.debug = False + local.show_values = None + local.debug_stack = [] local.db2cache = {} local.db_context_counter = 0 local.db_session = None + local.prefetch_context = None local.current_user = None local.perms_context = None local.user_groups_cache = {} local.user_roles_cache = defaultdict(dict) + def push_debug_state(local, debug, show_values): + local.debug_stack.append((local.debug, local.show_values)) + if not suppress_debug_change: + local.debug = debug + local.show_values = show_values + def pop_debug_state(local): + local.debug, local.show_values = local.debug_stack.pop() local = Local() @@ -276,14 +363,28 @@ def transact_reraise(exc_class, exceptions): reraise(exc_class, new_exc, tb) finally: del exceptions, exc, tb, new_exc +def rollback_and_reraise(exc_info): + try: + rollback() + finally: + reraise(*exc_info) + @cut_traceback def commit(): caches = _get_caches() if not caches: return + + try: + for cache in caches: + cache.flush() + except: + rollback_and_reraise(sys.exc_info()) + primary_cache = caches[0] other_caches = caches[1:] exceptions = [] - try: primary_cache.commit() + try: + primary_cache.commit() except: exceptions.append(sys.exc_info()) for cache in other_caches: @@ -315,10 +416,12 @@ def rollback(): select_re = re.compile(r'\s*select\b', re.IGNORECASE) class DBSessionContextManager(object): - __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', 'immediate', 'ddl', 'serializable', 'strict' - def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, - retry_exceptions=(TransactionError,), allowed_exceptions=()): - if retry is not 0: + __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', \ + 'immediate', 'ddl', 'serializable', 'strict', 'optimistic', \ + 'sql_debug', 
'show_values' + def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, optimistic=True, + retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None): + if retry != 0: if type(retry) is not int: throw(TypeError, "'retry' parameter of db_session must be of integer type. Got: %s" % type(retry)) if retry < 0: throw(TypeError, @@ -332,10 +435,13 @@ def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False db_session.retry = retry db_session.ddl = ddl db_session.serializable = serializable - db_session.immediate = immediate or ddl or serializable + db_session.immediate = immediate or ddl or serializable or not optimistic db_session.strict = strict + db_session.optimistic = optimistic and not serializable db_session.retry_exceptions = retry_exceptions db_session.allowed_exceptions = allowed_exceptions + db_session.sql_debug = sql_debug + db_session.show_values = show_values def __call__(db_session, *args, **kwargs): if not args and not kwargs: return db_session if len(args) > 1: throw(TypeError, @@ -344,41 +450,50 @@ def __call__(db_session, *args, **kwargs): if kwargs: throw(TypeError, 'Pass only keyword arguments to db_session or use db_session as decorator') func = args[0] - if not isgeneratorfunction(func): - return db_session._wrap_function(func) - return db_session._wrap_generator_function(func) + if isgeneratorfunction(func) or hasattr(inspect, 'iscoroutinefunction') and inspect.iscoroutinefunction(func): + return db_session._wrap_coroutine_or_generator_function(func) + return db_session._wrap_function(func) def __enter__(db_session): - if db_session.retry is not 0: throw(TypeError, + if db_session.retry != 0: throw(TypeError, "@db_session can accept 'retry' parameter only when used as decorator and not as context manager") - if db_session.ddl: throw(TypeError, - "@db_session can accept 'ddl' parameter only when used as decorator and not as context manager") db_session._enter() def _enter(db_session): if local.db_session is None: assert not local.db_context_counter local.db_session = db_session + elif db_session.ddl and not local.db_session.ddl: throw(TransactionError, + 'Cannot start ddl transaction inside non-ddl transaction') elif db_session.serializable and not local.db_session.serializable: throw(TransactionError, 'Cannot start serializable transaction inside non-serializable transaction') local.db_context_counter += 1 + if db_session.sql_debug is not None: + local.push_debug_state(db_session.sql_debug, db_session.show_values) def __exit__(db_session, exc_type=None, exc=None, tb=None): local.db_context_counter -= 1 - if local.db_context_counter: return - assert local.db_session is db_session + try: + if not local.db_context_counter: + assert local.db_session is db_session + db_session._commit_or_rollback(exc_type, exc, tb) + finally: + if db_session.sql_debug is not None: + local.pop_debug_state() + def _commit_or_rollback(db_session, exc_type, exc, tb): try: if exc_type is None: can_commit = True elif not callable(db_session.allowed_exceptions): can_commit = issubclass(exc_type, tuple(db_session.allowed_exceptions)) else: - # exc can be None in Python 2.6 even if exc_type is not None - try: can_commit = exc is not None and db_session.allowed_exceptions(exc) - except: - rollback() - raise + assert exc is not None # exc can be None in Python 2.6 even if exc_type is not None + try: can_commit = db_session.allowed_exceptions(exc) + except: rollback_and_reraise(sys.exc_info()) if 
can_commit: commit() for cache in _get_caches(): cache.release() assert not local.db2cache - else: rollback() + else: + try: rollback() + except: + if exc_type is None: raise # if exc_type is not None it will be reraised outside of __exit__ finally: del exc, tb local.db_session = None @@ -386,28 +501,54 @@ def __exit__(db_session, exc_type=None, exc=None, tb=None): local.user_roles_cache.clear() def _wrap_function(db_session, func): def new_func(func, *args, **kwargs): - if db_session.ddl and local.db_context_counter: - if isinstance(func, types.FunctionType): func = func.__name__ + '()' - throw(TransactionError, '%s cannot be called inside of db_session' % func) + if local.db_context_counter: + if db_session.ddl: + fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func + throw(TransactionError, '@db_session-decorated %s function with `ddl` option ' + 'cannot be called inside of another db_session' % fname) + if db_session.retry: + fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func + message = '@db_session decorator with `retry=%d` option is ignored for %s function ' \ + 'because it is called inside another db_session' % (db_session.retry, fname) + warnings.warn(message, PonyRuntimeWarning, stacklevel=3) + if db_session.sql_debug is None: + return func(*args, **kwargs) + local.push_debug_state(db_session.sql_debug, db_session.show_values) + try: + return func(*args, **kwargs) + finally: + local.pop_debug_state() + exc = tb = None try: for i in xrange(db_session.retry+1): db_session._enter() exc_type = exc = tb = None - try: return func(*args, **kwargs) + try: + result = func(*args, **kwargs) + commit() + return result except: - exc_type, exc, tb = sys.exc_info() # exc can be None in Python 2.6 - retry_exceptions = db_session.retry_exceptions - if not callable(retry_exceptions): - do_retry = issubclass(exc_type, tuple(retry_exceptions)) + exc_type, exc, tb = sys.exc_info() + if getattr(exc, 'should_retry', False): + do_retry = True else: - do_retry = exc is not None and retry_exceptions(exc) - if not do_retry: raise - finally: db_session.__exit__(exc_type, exc, tb) + retry_exceptions = db_session.retry_exceptions + if not callable(retry_exceptions): + do_retry = issubclass(exc_type, tuple(retry_exceptions)) + else: + assert exc is not None # exc can be None in Python 2.6 + do_retry = retry_exceptions(exc) + if not do_retry: + raise + rollback() + finally: + db_session.__exit__(exc_type, exc, tb) reraise(exc_type, exc, tb) - finally: del exc, tb + finally: + del exc, tb return decorator(new_func, func) - def _wrap_generator_function(db_session, gen_func): + def _wrap_coroutine_or_generator_function(db_session, gen_func): for option in ('ddl', 'retry', 'serializable'): if getattr(db_session, option, None): throw(TypeError, "db_session with `%s` option cannot be applied to generator function" % option) @@ -425,7 +566,8 @@ def interact(iterator, input=None, exc_info=None): if throw_ is None: reraise(*exc_info) return throw_(*exc_info) - def new_gen_func(gen_func, *args, **kwargs): + @wraps(gen_func) + def new_gen_func(*args, **kwargs): db2cache_copy = {} def wrapped_interact(iterator, input=None, exc_info=None): @@ -436,28 +578,99 @@ def wrapped_interact(iterator, input=None, exc_info=None): local.db_session = db_session local.db2cache.update(db2cache_copy) db2cache_copy.clear() + if db_session.sql_debug is not None: + local.push_debug_state(db_session.sql_debug, db_session.show_values) try: try: output = interact(iterator, input, exc_info) 
except StopIteration as e: + commit() for cache in _get_caches(): - if cache.modified or cache.in_transaction: throw(TransactionError, - 'You need to manually commit() changes before exiting from the generator') - raise + cache.release() + assert not local.db2cache + raise e for cache in _get_caches(): if cache.modified or cache.in_transaction: throw(TransactionError, - 'You need to manually commit() changes before yielding from the generator') + 'You need to manually commit() changes before suspending the generator') except: - rollback() - raise + rollback_and_reraise(sys.exc_info()) else: return output finally: + if db_session.sql_debug is not None: + local.pop_debug_state() db2cache_copy.update(local.db2cache) local.db2cache.clear() local.db_context_counter = 0 local.db_session = None + gen = gen_func(*args, **kwargs) + iterator = gen.__await__() if hasattr(gen, '__await__') else iter(gen) + try: + output = wrapped_interact(iterator) + while True: + try: + input = yield output + except: + output = wrapped_interact(iterator, exc_info=sys.exc_info()) + else: + output = wrapped_interact(iterator, input) + except StopIteration: + assert not db2cache_copy and not local.db2cache + return + + if hasattr(types, 'coroutine'): + new_gen_func = types.coroutine(new_gen_func) + return new_gen_func + +db_session = DBSessionContextManager() + + +class SQLDebuggingContextManager(object): + def __init__(self, debug=True, show_values=None): + self.debug = debug + self.show_values = show_values + def __call__(self, *args, **kwargs): + if not kwargs and len(args) == 1 and callable(args[0]): + arg = args[0] + if not isgeneratorfunction(arg): + return self._wrap_function(arg) + return self._wrap_generator_function(arg) + return self.__class__(*args, **kwargs) + def __enter__(self): + local.push_debug_state(self.debug, self.show_values) + def __exit__(self, exc_type=None, exc=None, tb=None): + local.pop_debug_state() + def _wrap_function(self, func): + def new_func(func, *args, **kwargs): + self.__enter__() + try: + return func(*args, **kwargs) + finally: + self.__exit__() + return decorator(new_func, func) + def _wrap_generator_function(self, gen_func): + def interact(iterator, input=None, exc_info=None): + if exc_info is None: + return next(iterator) if input is None else iterator.send(input) + + if exc_info[0] is GeneratorExit: + close = getattr(iterator, 'close', None) + if close is not None: close() + reraise(*exc_info) + + throw_ = getattr(iterator, 'throw', None) + if throw_ is None: reraise(*exc_info) + return throw_(*exc_info) + + def new_gen_func(gen_func, *args, **kwargs): + def wrapped_interact(iterator, input=None, exc_info=None): + self.__enter__() + try: + return interact(iterator, input, exc_info) + finally: + self.__exit__() + gen = gen_func(*args, **kwargs) iterator = iter(gen) output = wrapped_interact(iterator) @@ -473,11 +686,12 @@ def wrapped_interact(iterator, input=None, exc_info=None): return return decorator(new_gen_func, gen_func) -db_session = DBSessionContextManager() +sql_debugging = SQLDebuggingContextManager() -def throw_db_session_is_over(obj, attr): - throw(DatabaseSessionIsOver, 'Cannot read value of %s.%s: the database session is over' - % (safe_repr(obj), attr.name)) + +def throw_db_session_is_over(action, obj, attr=None): + msg = 'Cannot %s %s%s: the database session is over' + throw(DatabaseSessionIsOver, msg % (action, safe_repr(obj), '.%s' % attr.name if attr else '')) def with_transaction(*args, **kwargs): deprecated(3, "@with_transaction decorator is deprecated, use 
@db_session decorator instead") @@ -494,6 +708,31 @@ def db_decorator(func, *args, **kwargs): if web: throw(web.Http404NotFound) raise +known_providers = ('sqlite', 'postgres', 'mysql', 'oracle') + +class OnConnectDecorator(object): + + @staticmethod + def check_provider(provider): + if provider: + if not isinstance(provider, basestring): + throw(TypeError, "'provider' option should be type of 'string', got %r" % type(provider).__name__) + if provider not in known_providers: + throw(BindingError, 'Unknown provider %s' % provider) + + def __init__(self, database, provider): + OnConnectDecorator.check_provider(provider) + self.provider = provider + self.database = database + + def __call__(self, func=None, provider=None): + if isinstance(func, types.FunctionType): + self.database._on_connect_funcs.append((func, provider or self.provider)) + if not provider and func is basestring: + provider = func + OnConnectDecorator.check_provider(provider) + return OnConnectDecorator(self.database, provider) + class Database(object): def __deepcopy__(self, memo): return self # Database cannot be cloned by deepcopy() @@ -516,27 +755,40 @@ def __init__(self, *args, **kwargs): self._global_stats_lock = RLock() self._dblocal = DbLocal() - self.provider = None + self.on_connect = OnConnectDecorator(self, None) + self._on_connect_funcs = [] + self.provider = self.provider_name = None if args or kwargs: self._bind(*args, **kwargs) + def call_on_connect(database, con): + for func, provider in database._on_connect_funcs: + if not provider or provider == database.provider_name: + func(database, con) + con.commit() @cut_traceback def bind(self, *args, **kwargs): self._bind(*args, **kwargs) def _bind(self, *args, **kwargs): # argument 'self' cannot be named 'database', because 'database' can be in kwargs if self.provider is not None: - throw(TypeError, 'Database object was already bound to %s provider' % self.provider.dialect) - if not args: - throw(TypeError, 'Database provider should be specified as a first positional argument') - provider, args = args[0], args[1:] + throw(BindingError, 'Database object was already bound to %s provider' % self.provider.dialect) + if len(args) == 1 and not kwargs and hasattr(args[0], 'keys'): + args, kwargs = (), args[0] + provider = None + if args: provider, args = args[0], args[1:] + elif 'provider' not in kwargs: throw(TypeError, 'Database provider is not specified') + else: provider = kwargs.pop('provider') if isinstance(provider, type) and issubclass(provider, DBAPIProvider): provider_cls = provider else: - if not isinstance(provider, basestring): throw(TypeError) + if not isinstance(provider, basestring): + throw(TypeError, 'Provider name should be string. Got: %r' % type(provider).__name__) if provider == 'pygresql': throw(TypeError, 'Pony no longer supports PyGreSQL module. Please use psycopg2 instead.') + self.provider_name = provider provider_module = import_module('pony.orm.dbproviders.' 
+ provider) provider_cls = provider_module.provider_cls - self.provider = provider = provider_cls(*args, **kwargs) + kwargs['pony_call_on_connect'] = self.call_on_connect + self.provider = provider_cls(*args, **kwargs) @property def last_sql(database): return database._dblocal.last_sql @@ -547,20 +799,31 @@ def _update_local_stat(database, sql, query_start_time): dblocal = database._dblocal dblocal.last_sql = sql stats = dblocal.stats + query_end_time = time() + duration = query_end_time - query_start_time + stat = stats.get(sql) - if stat is not None: stat.query_executed(query_start_time) - else: stats[sql] = QueryStat(sql, query_start_time) + if stat is not None: + stat.query_executed(duration) + else: + stats[sql] = QueryStat(sql, duration) + + total_stat = stats.get(None) + if total_stat is not None: + total_stat.query_executed(duration) + else: + stats[None] = QueryStat(None, duration) def merge_local_stats(database): setdefault = database._global_stats.setdefault with database._global_stats_lock: for sql, stat in iteritems(database._dblocal.stats): global_stat = setdefault(sql, stat) if global_stat is not stat: global_stat.merge(stat) - database._dblocal.stats.clear() + database._dblocal.stats = {None: QueryStat(None)} @property def global_stats(database): with database._global_stats_lock: - return dict((sql, stat.copy()) for sql, stat in iteritems(database._global_stats)) + return {sql: stat.copy() for sql, stat in iteritems(database._global_stats)} @property def global_stats_lock(database): deprecated(3, "global_stats_lock is deprecated, just use global_stats property without any locking") @@ -579,7 +842,7 @@ def get_connection(database): def disconnect(database): provider = database.provider if provider is None: return - if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_sesison') + if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_session') cache = local.db2cache.get(database) if cache is not None: cache.rollback() provider.disconnect() @@ -598,14 +861,17 @@ def flush(database): @cut_traceback def commit(database): cache = local.db2cache.get(database) - if cache is not None: cache.commit() + if cache is not None: + cache.flush_and_commit() @cut_traceback def rollback(database): cache = local.db2cache.get(database) - if cache is not None: cache.rollback() + if cache is not None: + try: cache.rollback() + except: transact_reraise(RollbackException, [sys.exc_info()]) @cut_traceback def execute(database, sql, globals=None, locals=None): - return database._exec_raw_sql(sql, globals, locals, frame_depth=3, start_transaction=True) + return database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1, start_transaction=True) def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction=False): provider = database.provider if provider is None: throw(MappingError, 'Database object is not bound with a provider yet') @@ -621,7 +887,7 @@ def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction @cut_traceback def select(database, sql, globals=None, locals=None, frame_depth=0): if not select_re.match(sql): sql = 'select ' + sql - cursor = database._exec_raw_sql(sql, globals, locals, frame_depth + 3) + cursor = database._exec_raw_sql(sql, globals, locals, frame_depth+cut_traceback_depth+1) max_fetch_count = options.MAX_FETCH_COUNT if max_fetch_count is not None: result = cursor.fetchmany(max_fetch_count) @@ -637,7 +903,7 @@ def select(database, 
sql, globals=None, locals=None, frame_depth=0): return [ row_class(row) for row in result ] @cut_traceback def get(database, sql, globals=None, locals=None): - rows = database.select(sql, globals, locals, frame_depth=3) + rows = database.select(sql, globals, locals, frame_depth=cut_traceback_depth+1) if not rows: throw(RowNotFound) if len(rows) > 1: throw(MultipleRowsFound) row = rows[0] @@ -645,7 +911,7 @@ def get(database, sql, globals=None, locals=None): @cut_traceback def exists(database, sql, globals=None, locals=None): if not select_re.match(sql): sql = 'select ' + sql - cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=3) + cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1) result = cursor.fetchone() return bool(result) @cut_traceback @@ -675,17 +941,18 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti if start_transaction: cache.immediate = True connection = cache.prepare_connection_for_query_execution() cursor = connection.cursor() - if debug: log_sql(sql, arguments) + if local.debug: log_sql(sql, arguments) provider = database.provider t = time() try: new_id = provider.execute(cursor, sql, arguments, returning_id) except Exception as e: connection = cache.reconnect(e) cursor = connection.cursor() - if debug: log_sql(sql, arguments) + if local.debug: log_sql(sql, arguments) t = time() new_id = provider.execute(cursor, sql, arguments, returning_id) - if cache.immediate: cache.in_transaction = True + if cache.immediate: + cache.in_transaction = True database._update_local_stat(sql, t) if not returning_id: return cursor if PY2 and type(new_id) is long: new_id = int(new_id) @@ -694,7 +961,7 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti def generate_mapping(database, filename=None, check_tables=True, create_tables=False): provider = database.provider if provider is None: throw(MappingError, 'Database object is not bound with a provider yet') - if database.schema: throw(MappingError, 'Mapping was already generated') + if database.schema: throw(BindingError, 'Mapping was already generated') if filename is not None: throw(NotImplementedError) schema = database.schema = provider.dbschema_cls(provider) entities = list(sorted(database.entities.values(), key=attrgetter('_id_'))) @@ -702,6 +969,8 @@ def generate_mapping(database, filename=None, check_tables=True, create_tables=F entity._resolve_attr_types_() for entity in entities: entity._link_reverse_attrs_() + for entity in entities: + entity._check_table_options_() def get_columns(table, column_names): column_dict = table.column_dict @@ -713,7 +982,8 @@ def get_columns(table, column_names): is_subclass = entity._root_ is not entity if is_subclass: - if table_name is not None: throw(NotImplementedError) + if table_name is not None: throw(NotImplementedError, + 'Cannot specify table name for entity %r which is subclass of %r' % (entity.__name__, entity._root_.__name__)) table_name = entity._root_._table_ entity._table_ = table_name elif table_name is None: @@ -722,14 +992,8 @@ def get_columns(table, column_names): else: assert isinstance(table_name, (basestring, tuple)) table = schema.tables.get(table_name) - if table is None: table = schema.add_table(table_name) - elif table.entities: - for e in table.entities: - if e._root_ is not entity._root_: - throw(MappingError, "Entities %s and %s cannot be mapped to table %s " - "because they don't belong to the same hierarchy" - % (e, entity, table_name)) - 
table.entities.add(entity) + if table is None: table = schema.add_table(table_name, entity) + else: table.add_entity(entity) for attr in entity._new_attrs_: if attr.is_collection: @@ -760,12 +1024,15 @@ def get_columns(table, column_names): if not attr.table: seq_counter = itertools.count(2) while m2m_table is not None: - new_table_name = table_name + '_%d' % next(seq_counter) + if isinstance(table_name, basestring): + new_table_name = table_name + '_%d' % next(seq_counter) + else: + schema_name, base_name = provider.split_table_name(table_name) + new_table_name = schema_name, base_name + '_%d' % next(seq_counter) m2m_table = schema.tables.get(new_table_name) table_name = new_table_name - elif m2m_table.entities or m2m_table.m2m: - if isinstance(table_name, tuple): table_name = '.'.join(table_name) - throw(MappingError, "Table name '%s' is already in use" % table_name) + elif m2m_table.entities or m2m_table.m2m: throw(MappingError, + "Table name %s is already in use" % provider.format_table_name(table_name)) else: throw(NotImplementedError) attr.table = reverse.table = table_name m2m_table = schema.add_table(table_name) @@ -782,7 +1049,7 @@ def get_columns(table, column_names): m2m_table.m2m.add(reverse) else: if attr.is_required: pass - elif not attr.is_string: + elif not attr.type_has_empty_value: if attr.nullable is False: throw(TypeError, 'Optional attribute with non-string type %s must be nullable' % attr) attr.nullable = True @@ -844,19 +1111,32 @@ def get_columns(table, column_names): m2m_table = schema.tables[attr.table] parent_columns = get_columns(table, entity._pk_columns_) child_columns = get_columns(m2m_table, reverse.columns) - m2m_table.add_foreign_key(None, child_columns, table, parent_columns, attr.index) + on_delete = 'CASCADE' + m2m_table.add_foreign_key(reverse.fk_name, child_columns, table, parent_columns, + attr.index, on_delete) if attr.symmetric: - child_columns = get_columns(m2m_table, attr.reverse_columns) - m2m_table.add_foreign_key(None, child_columns, table, parent_columns) + reverse_child_columns = get_columns(m2m_table, attr.reverse_columns) + m2m_table.add_foreign_key(attr.reverse_fk_name, reverse_child_columns, table, parent_columns, + attr.reverse_index, on_delete) elif attr.reverse and attr.columns: rentity = attr.reverse.entity parent_table = schema.tables[rentity._table_] parent_columns = get_columns(parent_table, rentity._pk_columns_) child_columns = get_columns(table, attr.columns) - table.add_foreign_key(None, child_columns, parent_table, parent_columns, attr.index) + if attr.reverse.cascade_delete: + on_delete = 'CASCADE' + elif isinstance(attr, Optional) and attr.nullable: + on_delete = 'SET NULL' + else: + on_delete = None + table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index, + on_delete, interleave=attr.interleave) elif attr.index and attr.columns: - columns = tuple(imap(table.column_dict.__getitem__, attr.columns)) - table.add_index(attr.index, columns, is_unique=attr.is_unique) + if isinstance(attr.py_type, Array) and provider.dialect != 'PostgreSQL': + pass # GIN indexes are supported only in PostgreSQL + else: + columns = tuple(imap(table.column_dict.__getitem__, attr.columns)) + table.add_index(attr.index, columns, is_unique=attr.is_unique) entity._initialize_bits_() if create_tables: database.create_tables(check_tables) @@ -864,7 +1144,6 @@ def get_columns(table, column_names): @cut_traceback @db_session(ddl=True) def drop_table(database, table_name, if_exists=False, with_all_data=False): - 
table_name = database._get_table_name(table_name) database._drop_tables([ table_name ], if_exists, with_all_data, try_normalized=True) def _get_table_name(database, table_name): if isinstance(table_name, EntityMeta): @@ -878,9 +1157,13 @@ def _get_table_name(database, table_name): elif table_name is None: if database.schema is None: throw(MappingError, 'No mapping was generated for the database') else: throw(TypeError, 'Table name cannot be None') - elif not isinstance(table_name, basestring): - throw(TypeError, 'Invalid table name: %r' % table_name) - table_name = table_name[:] # table_name = templating.plainstr(table_name) + elif isinstance(table_name, tuple): + for component in table_name: + if not isinstance(component, basestring): + throw(TypeError, 'Invalid table name component: {}'.format(component)) + elif isinstance(table_name, basestring): + table_name = table_name[:] # table_name = templating.plainstr(table_name) + else: throw(TypeError, 'Invalid table name: {}'.format(table_name)) return table_name @cut_traceback @db_session(ddl=True) @@ -897,19 +1180,24 @@ def _drop_tables(database, table_names, if_exists, with_all_data, try_normalized if provider.table_exists(connection, table_name): existed_tables.append(table_name) elif not if_exists: if try_normalized: - normalized_table_name = provider.normalize_name(table_name) - if normalized_table_name != table_name \ - and provider.table_exists(connection, normalized_table_name): - throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)' - % (table_name, normalized_table_name)) - throw(TableDoesNotExist, 'Table %s does not exist' % table_name) + if isinstance(table_name, basestring): + normalized_table_name = provider.normalize_name(table_name) + else: + schema_name, base_name = provider.split_table_name(table_name) + normalized_table_name = schema_name, provider.normalize_name(base_name) + if normalized_table_name != table_name and provider.table_exists(connection, normalized_table_name): + throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)' % ( + provider.format_table_name(table_name), + provider.format_table_name(normalized_table_name))) + throw(TableDoesNotExist, 'Table %s does not exist' % provider.format_table_name(table_name)) if not with_all_data: for table_name in existed_tables: if provider.table_has_data(connection, table_name): throw(TableIsNotEmpty, 'Cannot drop table %s because it is not empty. 
Specify option ' - 'with_all_data=True if you want to drop table with all data' % table_name) + 'with_all_data=True if you want to drop table with all data' + % provider.format_table_name(table_name)) for table_name in existed_tables: - if debug: log_orm('DROPPING TABLE %s' % table_name) + if local.debug: log_orm('DROPPING TABLE %s' % provider.format_table_name(table_name)) provider.drop_table(connection, table_name) @cut_traceback @db_session(ddl=True) @@ -971,7 +1259,7 @@ def _get_schema_dict(database): return result def _get_schema_json(database): schema_json = json.dumps(database._get_schema_dict(), default=basic_converter, sort_keys=True) - schema_hash = md5(schema_json).hexdigest() + schema_hash = md5(schema_json.encode('utf-8')).hexdigest() return schema_json, schema_hash @cut_traceback def to_json(database, data, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): @@ -994,7 +1282,8 @@ def user_has_no_rights_to_see(obj, attr=None): caches = set() def obj_converter(obj): if not isinstance(obj, Entity): return converter(obj) - caches.add(obj._session_cache_) + cache = obj._session_cache_ + if cache is not None: caches.add(cache) if len(caches) > 1: throw(TransactionError, 'An attempt to serialize objects belonging to different transactions') if not can_view(user, obj): @@ -1160,7 +1449,7 @@ def deserialize(x): if t is list: return list(imap(deserialize, x)) if t is dict: if '_id_' not in x: - return dict((key, deserialize(val)) for key, val in iteritems(x)) + return {key: deserialize(val) for key, val in iteritems(x)} obj = objmap.get(x['_id_']) if obj is None: entity_name = x['class'] @@ -1324,7 +1613,7 @@ def get_user_groups(user): result = local.user_groups_cache.get(user) if result is not None: return result if user is None: return anybody_frozenset - result = set(['anybody']) + result = {'anybody'} for cls, func in usergroup_functions: if cls is None or isinstance(user, cls): groups = func(user) @@ -1400,14 +1689,12 @@ def decorator(func): class DbLocal(localbase): def __init__(dblocal): - dblocal.stats = {} + dblocal.stats = {None: QueryStat(None)} dblocal.last_sql = None class QueryStat(object): - def __init__(stat, sql, query_start_time=None): - if query_start_time is not None: - query_end_time = time() - duration = query_end_time - query_start_time + def __init__(stat, sql, duration=None): + if duration is not None: stat.min_time = stat.max_time = stat.sum_time = duration stat.db_count = 1 stat.cache_count = 0 @@ -1420,9 +1707,7 @@ def copy(stat): result = object.__new__(QueryStat) result.__dict__.update(stat.__dict__) return result - def query_executed(stat, query_start_time): - query_end_time = time() - duration = query_end_time - query_start_time + def query_executed(stat, duration): if stat.db_count: stat.min_time = builtins.min(stat.min_time, duration) stat.max_time = builtins.max(stat.max_time, duration) @@ -1447,6 +1732,8 @@ def avg_time(stat): if not stat.db_count: return None return stat.sum_time / stat.db_count +num_counter = itertools.count() + class SessionCache(object): def __init__(cache, database): cache.is_alive = True @@ -1463,6 +1750,7 @@ def __init__(cache, database): cache.objects_to_save = [] cache.saved_objects = [] cache.query_results = {} + cache.dbvals_deduplication_cache = defaultdict(dict) cache.modified = False cache.db_session = db_session = local.db_session cache.immediate = db_session is not None and db_session.immediate @@ -1476,12 +1764,17 @@ def connect(cache): assert cache.connection is None if cache.in_transaction: 
throw(ConnectionClosedError, 'Transaction cannot be continued because database connection failed') - provider = cache.database.provider - connection = provider.connect() - try: provider.set_transaction_mode(connection, cache) # can set cache.in_transaction + database = cache.database + provider = database.provider + connection, is_new_connection = provider.connect() + if is_new_connection: + database.call_on_connect(connection) + try: + provider.set_transaction_mode(connection, cache) # can set cache.in_transaction except: provider.drop(connection, cache) raise + cache.connection = connection return connection def reconnect(cache, exc): @@ -1489,7 +1782,7 @@ def reconnect(cache, exc): if exc is not None: exc = getattr(exc, 'original_exc', exc) if not provider.should_reconnect(exc): reraise(*sys.exc_info()) - if debug: log_orm('CONNECTION FAILED: %s' % exc) + if local.debug: log_orm('CONNECTION FAILED: %s' % exc) connection = cache.connection assert connection is not None cache.connection = None @@ -1503,7 +1796,7 @@ def prepare_connection_for_query_execution(cache): # in the interactive mode, outside of the db_session if cache.in_transaction or cache.modified: local.db_session = None - try: cache.commit() + try: cache.flush_and_commit() finally: local.db_session = db_session cache.db_session = db_session cache.immediate = cache.immediate or db_session.immediate @@ -1516,16 +1809,23 @@ def prepare_connection_for_query_execution(cache): except Exception as e: connection = cache.reconnect(e) if not cache.noflush_counter and cache.modified: cache.flush() return connection + def flush_and_commit(cache): + try: cache.flush() + except: + cache.rollback() + raise + try: cache.commit() + except: transact_reraise(CommitException, [sys.exc_info()]) def commit(cache): assert cache.is_alive - database = cache.database - provider = database.provider try: if cache.modified: cache.flush() if cache.in_transaction: assert cache.connection is not None - provider.commit(cache.connection, cache) + cache.database.provider.commit(cache.connection, cache) cache.for_update.clear() + cache.query_results.clear() + cache.max_id_cache.clear() cache.immediate = True except: cache.rollback() @@ -1544,21 +1844,30 @@ def close(cache, rollback=True): connection = cache.connection if connection is None: return cache.connection = None - if rollback: - try: provider.rollback(connection, cache) - except: - provider.drop(connection, cache) - raise - provider.release(connection, cache) - db_session = cache.db_session or local.db_session - if db_session and db_session.strict: - cache.clear() - def clear(cache): - for obj in cache.objects: - obj._vals_ = obj._dbvals_ = obj._session_cache_ = None - cache.objects = cache.indexes = cache.seeds = cache.for_update = cache.modified_collections \ - = cache.objects_to_save = cache.saved_objects = cache.query_results \ - = cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + + try: + if rollback: + try: provider.rollback(connection, cache) + except: + provider.drop(connection, cache) + raise + provider.release(connection, cache) + finally: + db_session = cache.db_session or local.db_session + if db_session and db_session.strict: + for obj in cache.objects: + obj._vals_ = obj._dbvals_ = obj._session_cache_ = None + cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + else: + for obj in cache.objects: + obj._dbvals_ = obj._session_cache_ = None + for attr, setdata in iteritems(obj._vals_): + if attr.is_collection: + if not 
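The connect() hunk above now unpacks (connection, is_new_connection) from the provider and triggers database.call_on_connect() only for brand-new connections, so functions registered with the existing @db.on_connect decorator run once per physical connection rather than on every checkout from the pool. A minimal sketch, not part of the patch, assuming a Database instance `db`:

    @db.on_connect(provider='sqlite')
    def setup_connection(db, connection):
        # runs once for every newly created SQLite connection
        cursor = connection.cursor()
        cursor.execute('PRAGMA case_sensitive_like = OFF')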
setdata.is_fully_loaded: obj._vals_[attr] = None + + cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ + = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ + = cache.modified_collections = cache.collection_statistics = cache.dbvals_deduplication_cache = None @contextmanager def flush_disabled(cache): cache.noflush_counter += 1 @@ -1568,34 +1877,39 @@ def flush(cache): if cache.noflush_counter: return assert cache.is_alive assert not cache.saved_objects - if not cache.immediate: cache.immediate = True - for i in xrange(50): - if not cache.modified: return - - with cache.flush_disabled(): - for obj in cache.objects_to_save: # can grow during iteration - if obj is not None: obj._before_save_() - - cache.query_results.clear() - modified_m2m = cache._calc_modified_m2m() - for attr, (added, removed) in iteritems(modified_m2m): - if not removed: continue - attr.remove_m2m(removed) - for obj in cache.objects_to_save: - if obj is not None: obj._save_() - for attr, (added, removed) in iteritems(modified_m2m): - if not added: continue - attr.add_m2m(added) - - cache.max_id_cache.clear() - cache.modified_collections.clear() - cache.objects_to_save[:] = () - cache.modified = False - - cache.call_after_save_hooks() - else: - if cache.modified: throw(TransactionError, - 'Recursion depth limit reached in obj._after_save_() call') + prev_immediate = cache.immediate + cache.immediate = True + try: + for i in xrange(50): + if not cache.modified: return + + with cache.flush_disabled(): + for obj in cache.objects_to_save: # can grow during iteration + if obj is not None: obj._before_save_() + + cache.query_results.clear() + modified_m2m = cache._calc_modified_m2m() + for attr, (added, removed) in iteritems(modified_m2m): + if not removed: continue + attr.remove_m2m(removed) + for obj in cache.objects_to_save: + if obj is not None: obj._save_() + for attr, (added, removed) in iteritems(modified_m2m): + if not added: continue + attr.add_m2m(added) + + cache.max_id_cache.clear() + cache.modified_collections.clear() + cache.objects_to_save[:] = () + cache.modified = False + + cache.call_after_save_hooks() + else: + if cache.modified: throw(TransactionError, + 'Recursion depth limit reached in obj._after_save_() call') + finally: + if not cache.in_transaction: + cache.immediate = prev_immediate def call_after_save_hooks(cache): saved_objects = cache.saved_objects cache.saved_objects = [] @@ -1627,7 +1941,7 @@ def _calc_modified_m2m(cache): cache.modified_collections.clear() return modified_m2m def update_simple_index(cache, obj, attr, old_val, new_val, undo): - assert old_val != new_val + if old_val == new_val: return cache_index = cache.indexes[attr] if new_val is not None: obj2 = cache_index.setdefault(new_val, obj) @@ -1636,7 +1950,7 @@ def update_simple_index(cache, obj, attr, old_val, new_val, undo): if old_val is not None: del cache_index[old_val] undo.append((cache_index, old_val, new_val)) def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval): - assert old_dbval != new_dbval + if old_dbval == new_dbval: return cache_index = cache.indexes[attr] if new_dbval is not None: obj2 = cache_index.setdefault(new_dbval, obj) @@ -1649,6 +1963,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): if None in prev_vals: prev_vals = None if None in new_vals: new_vals = None if prev_vals is None and new_vals is None: return + if prev_vals == new_vals: return cache_index = cache.indexes[attrs] if new_vals is not None: obj2 = 
cache_index.setdefault(new_vals, obj) @@ -1659,6 +1974,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): if prev_vals is not None: del cache_index[prev_vals] undo.append((cache_index, prev_vals, new_vals)) def db_update_composite_index(cache, obj, attrs, prev_vals, new_vals): + if prev_vals == new_vals: return cache_index = cache.indexes[attrs] if None not in new_vals: obj2 = cache_index.setdefault(new_vals, obj) @@ -1700,7 +2016,8 @@ class Attribute(object): 'id', 'pk_offset', 'pk_columns_offset', 'py_type', 'sql_type', 'entity', 'name', \ 'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \ 'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \ - 'cascade_delete', 'index', 'original_default', 'sql_default', 'py_check', 'hidden' + 'cascade_delete', 'index', 'reverse_index', 'original_default', 'sql_default', 'py_check', 'hidden', \ + 'optimistic', 'fk_name', 'type_has_empty_value', 'interleave' def __deepcopy__(attr, memo): return attr # Attribute cannot be cloned by deepcopy() @cut_traceback @@ -1720,12 +2037,13 @@ def __init__(attr, py_type, *args, **kwargs): if attr.is_pk: attr.pk_offset = 0 else: attr.pk_offset = None attr.id = next(attr_id_counter) - if not isinstance(py_type, (type, basestring, types.FunctionType)): + if not isinstance(py_type, (type, basestring, types.FunctionType, Array)): if py_type is datetime: throw(TypeError, 'datetime is the module and cannot be used as attribute type. Use datetime.datetime instead') throw(TypeError, 'Incorrect type of attribute: %r' % py_type) attr.py_type = py_type attr.is_string = type(py_type) is type and issubclass(py_type, basestring) + attr.type_has_empty_value = attr.is_string or hasattr(attr.py_type, 'default_empty_value') attr.is_collection = isinstance(attr, Collection) attr.is_relation = isinstance(attr.py_type, (EntityMeta, basestring, types.FunctionType)) attr.is_basic = not attr.is_collection and not attr.is_relation @@ -1759,15 +2077,19 @@ def __init__(attr, py_type, *args, **kwargs): if len(attr.columns) == 1: attr.column = attr.columns[0] else: attr.columns = [] attr.index = kwargs.pop('index', None) + attr.reverse_index = kwargs.pop('reverse_index', None) + attr.fk_name = kwargs.pop('fk_name', None) attr.col_paths = [] attr._columns_checked = False attr.composite_keys = [] attr.lazy = kwargs.pop('lazy', getattr(py_type, 'lazy', False)) attr.lazy_sql_cache = None attr.is_volatile = kwargs.pop('volatile', False) + attr.optimistic = kwargs.pop('optimistic', None) attr.sql_default = kwargs.pop('sql_default', None) attr.py_check = kwargs.pop('py_check', None) attr.hidden = kwargs.pop('hidden', False) + attr.interleave = kwargs.pop('interleave', None) attr.kwargs = kwargs attr.converters = [] def _init_(attr, entity, name): @@ -1796,8 +2118,8 @@ def _init_(attr, entity, name): 'Default value for required attribute %s cannot be empty string' % attr) elif attr.default is None and not attr.nullable: throw(TypeError, 'Default value for non-nullable attribute %s cannot be set to None' % attr) - elif attr.is_string and not attr.is_required and not attr.nullable: - attr.default = '' + elif attr.type_has_empty_value and not attr.is_required and not attr.nullable: + attr.default = '' if attr.is_string else attr.py_type.default_empty_value() else: attr.default = None @@ -1820,6 +2142,12 @@ def _init_(attr, entity, name): elif attr.is_unique: throw(TypeError, 'Unique attribute %s cannot be of type float' % attr) if attr.is_volatile and (attr.is_pk or 
attr.is_collection): throw(TypeError, '%s attribute %s cannot be volatile' % (attr.__class__.__name__, attr)) + + if attr.interleave is not None: + if attr.is_collection: throw(TypeError, + '`interleave` option cannot be specified for %s attribute %r' % (attr.__class__.__name__, attr)) + if attr.interleave not in (True, False): throw(TypeError, + '`interleave` option value should be True, False or None. Got: %r' % attr.interleave) def linked(attr): reverse = attr.reverse if attr.cascade_delete is None: @@ -1831,13 +2159,25 @@ def linked(attr): if reverse.is_collection: throw(TypeError, "'cascade_delete' option cannot be set for attribute %s, " "because reverse attribute %s is collection" % (attr, reverse)) + if attr.is_collection and not reverse.is_collection: + if attr.fk_name is not None: + throw(TypeError, 'You should specify fk_name in %s instead of %s' % (reverse, attr)) + for option in attr.kwargs: + throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option)) @cut_traceback def __repr__(attr): owner_name = attr.entity.__name__ if attr.entity else '?' return '%s.%s' % (owner_name, attr.name or '?') def __lt__(attr, other): return attr.id < other.id + def _get_entity(attr, obj, entity): + if entity is not None: + return entity + if obj is not None: + return obj.__class__ + return attr.entity def validate(attr, val, obj=None, entity=None, from_db=False): + val = deref_proxy(val) if val is None: if not attr.nullable and not from_db and not attr.is_required: # for required attribute the exception will be thrown later with another message @@ -1850,10 +2190,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if callable(default): val = default() else: val = default - if entity is not None: pass - elif obj is not None: entity = obj.__class__ - else: entity = attr.entity - + entity = attr._get_entity(obj, entity) reverse = attr.reverse if not reverse: if isinstance(val, Entity): throw(TypeError, 'Attribute %s must be of %s type. Got: %s' @@ -1865,7 +2202,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if converter is not None: try: if from_db: return converter.sql2py(val) - val = converter.validate(val) + val = converter.validate(val, obj) except UnicodeDecodeError as e: throw(ValueError, 'Value for attribute %s cannot be converted to %s: %s' % (attr, unicode.__name__, truncate_repr(val))) @@ -1880,29 +2217,30 @@ def validate(attr, val, obj=None, entity=None, from_db=False): except TypeError: throw(TypeError, 'Attribute %s must be of %s type. Got: %r' % (attr, rentity.__name__, val)) else: - if obj is not None: cache = obj._session_cache_ + if obj is not None and obj._status_ is not None: cache = obj._session_cache_ else: cache = entity._database_._get_cache() if cache is not val._session_cache_: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') if attr.py_check is not None and not attr.py_check(val): throw(ValueError, 'Check for attribute %s failed. 
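The two hunks above add per-attribute options (`fk_name`, `reverse_index`, `optimistic`, `interleave`) and validate that `interleave` is only used on non-collection relation attributes with a True/False/None value. A minimal declaration sketch, not part of the patch; the entity and constraint names are hypothetical, and `interleave` only matters for providers with interleaved-table support such as CockroachDB:

    from pony.orm import Database, Required, Optional, Set

    db = Database()

    class Customer(db.Entity):
        name = Required(str)
        orders = Set('Order')

    class Order(db.Entity):
        # fk_name names the generated foreign-key constraint explicitly;
        # interleave=True marks the single relation used for table interleaving
        customer = Required(Customer, fk_name='fk_order_customer', interleave=True)
        note = Optional(str)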
Value: %s' % (attr, truncate_repr(val))) return val - def parse_value(attr, row, offsets): + def parse_value(attr, row, offsets, dbvals_deduplication_cache): assert len(attr.columns) == len(offsets) if not attr.reverse: if len(offsets) > 1: throw(NotImplementedError) offset = offsets[0] - val = attr.validate(row[offset], None, attr.entity, from_db=True) + dbval = attr.validate(row[offset], None, attr.entity, from_db=True) + dbval = deduplicate(dbval, dbvals_deduplication_cache) else: - vals = [ row[offset] for offset in offsets ] - if None in vals: - assert len(set(vals)) == 1 - val = None - else: val = attr.py_type._get_by_raw_pkval_(vals) - return val + dbvals = [ row[offset] for offset in offsets ] + if None in dbvals: + assert len(set(dbvals)) == 1 + dbval = None + else: dbval = attr.py_type._get_by_raw_pkval_(dbvals) + return dbval def load(attr, obj): - if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + cache = obj._session_cache_ + if cache is None or not cache.is_alive: throw_db_session_is_over('load attribute', obj, attr) if not attr.columns: reverse = attr.reverse assert reverse is not None and reverse.columns @@ -1919,7 +2257,7 @@ def load(attr, obj): from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] pk_columns = entity._pk_columns_ pk_converters = entity._pk_converters_ - criteria_list = [ [ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] + criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ] sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ] sql, adapter = database._ast2sql(sql_ast) @@ -1929,7 +2267,7 @@ def load(attr, obj): arguments = adapter(obj._get_raw_pkval_()) cursor = database._exec_sql(sql, arguments) row = cursor.fetchone() - dbval = attr.parse_value(row, offsets) + dbval = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) attr.db_set(obj, dbval) else: obj._load_() return obj._vals_[attr] @@ -1937,26 +2275,26 @@ def load(attr, obj): def __get__(attr, obj, cls=None): if obj is None: return attr if attr.pk_offset is not None: return attr.get(obj) - result = attr.get(obj) + value = attr.get(obj) bit = obj._bits_except_volatile_[attr] wbits = obj._wbits_ if wbits is not None and not wbits & bit: obj._rbits_ |= bit - return result + return value def get(attr, obj): if attr.pk_offset is None and obj._status_ in ('deleted', 'cancelled'): throw_object_was_deleted(obj) vals = obj._vals_ - if vals is None: throw_db_session_is_over(obj, attr) + if vals is None: throw_db_session_is_over('read value of', obj, attr) val = vals[attr] if attr in vals else attr.load(obj) if val is not None and attr.reverse and val._subclasses_ and val._status_ not in ('deleted', 'cancelled'): - seeds = obj._session_cache_.seeds[val._pk_attrs_] - if val in seeds: val._load_() + cache = obj._session_cache_ + if cache is not None and val in cache.seeds[val._pk_attrs_]: + val._load_() return val @cut_traceback def __set__(attr, obj, new_val, undo_funcs=None): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot assign new value to attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr) if obj._status_ in del_statuses: 
throw_object_was_deleted(obj) reverse = attr.reverse new_val = attr.validate(new_val, obj, from_db=False) @@ -2039,25 +2377,28 @@ def undo_func(): raise def db_set(attr, obj, new_dbval, is_reverse_call=False): cache = obj._session_cache_ - assert cache.is_alive + assert cache is not None and cache.is_alive assert obj._status_ not in created_or_deleted_statuses assert attr.pk_offset is None if new_dbval is NOT_LOADED: assert is_reverse_call old_dbval = obj._dbvals_.get(attr, NOT_LOADED) + if old_dbval is not NOT_LOADED: + if old_dbval == new_dbval or ( + not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): + return - if attr.py_type is float: - if old_dbval is NOT_LOADED: pass - elif attr.converters[0].equals(old_dbval, new_dbval): return - elif old_dbval == new_dbval: return - - bit = obj._bits_[attr] + bit = obj._bits_except_volatile_[attr] if obj._rbits_ & bit: assert old_dbval is not NOT_LOADED - if new_dbval is NOT_LOADED: diff = '' - else: diff = ' (was: %s, now: %s)' % (old_dbval, new_dbval) - throw(UnrepeatableReadError, - 'Value of %s.%s for %s was updated outside of current transaction%s' - % (obj.__class__.__name__, attr.name, obj, diff)) + msg = 'Value of %s for %s was updated outside of current transaction' % (attr, obj) + if new_dbval is not NOT_LOADED: + msg = '%s (was: %s, now: %s)' % (msg, old_dbval, new_dbval) + elif isinstance(attr.reverse, Optional): + assert old_dbval is not None + msg = "Multiple %s objects linked with the same %s object. " \ + "Maybe %s attribute should be Set instead of Optional" \ + % (attr.entity.__name__, old_dbval, attr.reverse) + throw(UnrepeatableReadError, msg) if new_dbval is NOT_LOADED: obj._dbvals_.pop(attr, None) else: obj._dbvals_[attr] = new_dbval @@ -2065,9 +2406,8 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): wbit = bool(obj._wbits_ & bit) if not wbit: old_val = obj._vals_.get(attr, NOT_LOADED) - assert old_val == old_dbval + assert old_val == old_dbval, (old_val, old_dbval) if attr.is_part_of_unique_index: - cache = obj._session_cache_ if attr.is_unique: cache.db_update_simple_index(obj, attr, old_val, new_dbval) get_val = obj._vals_.get for attrs, i in attr.composite_keys: @@ -2076,8 +2416,13 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): vals[i] = new_dbval new_vals = tuple(vals) cache.db_update_composite_index(obj, attrs, old_vals, new_vals) - if new_dbval is NOT_LOADED: obj._vals_.pop(attr, None) - else: obj._vals_[attr] = new_dbval + if new_dbval is NOT_LOADED: + obj._vals_.pop(attr, None) + elif attr.reverse: + obj._vals_[attr] = new_dbval + else: + assert len(attr.converters) == 1 + obj._vals_[attr] = attr.converters[0].dbval2val(new_dbval, obj) reverse = attr.reverse if not reverse: pass @@ -2171,6 +2516,8 @@ def describe(attr): options = [] if attr.args: options.append(', '.join(imap(str, attr.args))) if attr.auto: options.append('auto=True') + for k, v in sorted(attr.kwargs.items()): + options.append('%s=%r' % (k, v)) if not isinstance(attr, PrimaryKey) and attr.is_unique: options.append('unique=True') if attr.default is not None: options.append('default=%r' % attr.default) if not options: options = '' @@ -2186,7 +2533,13 @@ class Required(Attribute): def validate(attr, val, obj=None, entity=None, from_db=False): val = Attribute.validate(attr, val, obj, entity, from_db) if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): - throw(ValueError, 'Attribute %s is required' % (attr if obj is None else '%r.%s' % (obj, attr.name))) + if not 
from_db: + throw(ValueError, 'Attribute %s is required' % ( + attr if obj is None or obj._status_ is None else '%r.%s' % (obj, attr.name))) + else: + warnings.warn('Database contains %s for required attribute %s' + % ('NULL' if val is None else 'empty string', attr), + DatabaseContainsIncorrectEmptyValue) return val class Discriminator(Required): @@ -2217,7 +2570,7 @@ def process_entity_inheritance(attr, entity): entity._discriminator_ = entity.__name__ discr_value = entity._discriminator_ if discr_value is not None: - try: entity._discriminator_ = discr_value = attr.validate(discr_value) + try: entity._discriminator_ = discr_value = attr.validate(discr_value, None, entity) except ValueError: throw(TypeError, "Incorrect discriminator value is set for %s attribute '%s' of '%s' type: %r" % (entity.__name__, attr.name, attr.py_type.__name__, discr_value)) @@ -2228,10 +2581,18 @@ def process_entity_inheritance(attr, entity): % (entity.__name__, attr.name, attr.py_type.__name__)) attr.code2cls[discr_value] = entity def validate(attr, val, obj=None, entity=None, from_db=False): - if from_db: return val - elif val is DEFAULT: + if from_db: + return val + entity = attr._get_entity(obj, entity) + if val is DEFAULT: assert entity is not None return entity._discriminator_ + if val != entity._discriminator_: + for cls in entity._subclasses_: + if val == cls._discriminator_: + break + else: throw(TypeError, 'Invalid discriminator attribute value for %s. Expected: %r, got: %r' + % (entity.__name__, entity._discriminator_, val)) return Attribute.validate(attr, val, obj, entity) def load(attr, obj): assert False # pragma: no cover @@ -2332,7 +2693,7 @@ def __new__(cls, *args, **kwargs): class Collection(Attribute): __slots__ = 'table', 'wrapper_class', 'symmetric', 'reverse_column', 'reverse_columns', \ 'nplus1_threshold', 'cached_load_sql', 'cached_add_m2m_sql', 'cached_remove_m2m_sql', \ - 'cached_count_sql', 'cached_empty_sql' + 'cached_count_sql', 'cached_empty_sql', 'reverse_fk_name' def __init__(attr, py_type, *args, **kwargs): if attr.__class__ is Collection: throw(TypeError, "'Collection' is abstract type") table = kwargs.pop('table', None) # TODO: rename table to link_table or m2m_table @@ -2365,8 +2726,9 @@ def __init__(attr, py_type, *args, **kwargs): if len(attr.reverse_columns) == 1: attr.reverse_column = attr.reverse_columns[0] else: attr.reverse_columns = [] + attr.reverse_fk_name = kwargs.pop('reverse_fk_name', None) + attr.nplus1_threshold = kwargs.pop('nplus1_threshold', 1) - for option in attr.kwargs: throw(TypeError, 'Unknown option %r' % option) attr.cached_load_sql = {} attr.cached_add_m2m_sql = None attr.cached_remove_m2m_sql = None @@ -2379,8 +2741,11 @@ def _init_(attr, entity, name): if attr.default is not None: throw(TypeError, 'Default value could not be set for collection attribute') attr.symmetric = (attr.py_type == entity.__name__ and attr.reverse == name) - if not attr.symmetric and attr.reverse_columns: throw(TypeError, - "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only") + if not attr.symmetric: + if attr.reverse_columns: + throw(TypeError, "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only") + if attr.reverse_index: + throw(TypeError, "'reverse_index' option can be set for symmetric relations only") if attr.py_check is not None: throw(NotImplementedError, "'py_check' parameter is not supported for collection attributes") def load(attr, obj): @@ -2411,7 +2776,7 @@ def param(i, j, converter): else: 
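With the Required.validate() change above, an empty string or NULL coming back from the database for a required attribute now emits a DatabaseContainsIncorrectEmptyValue warning instead of raising, so legacy rows no longer break reads. A minimal sketch for tightening that back up, not part of the patch; the import location of the warning class is an assumption:

    import warnings
    from pony.orm.core import DatabaseContainsIncorrectEmptyValue  # assumed location

    # escalate the new warning to an exception, e.g. while running a test suite
    warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue)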
return [ 'PARAM', (i, j, None), converter ] if batch_size == 1: - return [ [ 'EQ', [ 'COLUMN', alias, column ], param(start, j, converter) ] + return [ [ converter.EQ, [ 'COLUMN', alias, column ], param(start, j, converter) ] for j, (column, converter) in enumerate(izip(columns, converters)) ] if len(columns) == 1: column = columns[0] @@ -2426,7 +2791,7 @@ def param(i, j, converter): condition = [ 'IN', row, param_list ] return [ condition ] else: - conditions = [ [ 'AND' ] + [ [ 'EQ', [ 'COLUMN', alias, column ], param(i+start, j, converter) ] + conditions = [ [ 'AND' ] + [ [ converter.EQ, [ 'COLUMN', alias, column ], param(i+start, j, converter) ] for j, (column, converter) in enumerate(izip(columns, converters)) ] for i in xrange(batch_size) ] return [ [ 'OR' ] + conditions ] @@ -2434,6 +2799,7 @@ def param(i, j, converter): class Set(Collection): __slots__ = [] def validate(attr, val, obj=None, entity=None, from_db=False): + val = deref_proxy(val) assert val is not NOT_LOADED if val is DEFAULT: return set() reverse = attr.reverse @@ -2450,19 +2816,76 @@ def validate(attr, val, obj=None, entity=None, from_db=False): except TypeError: throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r' % (entity.__name__, attr.name, rentity.__name__, val)) for item in items: + item = deref_proxy(item) if not isinstance(item, rentity): throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r' % (entity.__name__, attr.name, rentity.__name__, item)) - if obj is not None: cache = obj._session_cache_ + if obj is not None and obj._status_ is not None: cache = obj._session_cache_ else: cache = entity._database_._get_cache() for item in items: if item._session_cache_ is not cache: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') return items + def prefetch_load_all(attr, objects): + entity = attr.entity + database = entity._database_ + cache = database._get_cache() + if cache is None or not cache.is_alive: + throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over') + reverse = attr.reverse + rentity = reverse.entity + objects = sorted(objects, key=entity._get_raw_pkval_) + max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) + result = set() + if not reverse.is_collection: + for i in xrange(0, len(objects), max_batch_size): + batch = objects[i:i+max_batch_size] + sql, adapter, attr_offsets = rentity._construct_batchload_sql_(len(batch), reverse) + arguments = adapter(batch) + cursor = database._exec_sql(sql, arguments) + result.update(rentity._fetch_objects(cursor, attr_offsets)) + else: + pk_len = len(entity._pk_columns_) + m2m_dict = defaultdict(set) + for i in xrange(0, len(objects), max_batch_size): + batch = objects[i:i+max_batch_size] + sql, adapter = attr.construct_sql_m2m(len(batch)) + arguments = adapter(batch) + cursor = database._exec_sql(sql, arguments) + if len(batch) > 1: + for row in cursor.fetchall(): + obj = entity._get_by_raw_pkval_(row[:pk_len]) + item = rentity._get_by_raw_pkval_(row[pk_len:]) + m2m_dict[obj].add(item) + else: + obj = batch[0] + m2m_dict[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} + + for obj2, items in iteritems(m2m_dict): + setdata2 = obj2._vals_.get(attr) + if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData() + else: + phantoms = setdata2 - items + if setdata2.added: phantoms -= setdata2.added + if phantoms: throw(UnrepeatableReadError, + 'Phantom object %s disappeared from collection 
%s.%s' + % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name)) + items -= setdata2 + if setdata2.removed: items -= setdata2.removed + setdata2 |= items + reverse.db_reverse_add(items, obj2) + result.update(items) + for obj in objects: + setdata = obj._vals_.get(attr) + if setdata is None: + setdata = obj._vals_[attr] = SetData() + setdata.is_fully_loaded = True + setdata.absent = None + setdata.count = len(setdata) + return result def load(attr, obj, items=None): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load collection %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load collection', obj, attr) assert obj._status_ not in del_statuses setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() @@ -2470,14 +2893,13 @@ def load(attr, obj, items=None): entity = attr.entity reverse = attr.reverse rentity = reverse.entity - if not reverse: throw(NotImplementedError) database = obj._database_ if cache is not database._get_cache(): throw(TransactionError, "Transaction of object %s belongs to different thread") if items: if not reverse.is_collection: - items = set(item for item in items if reverse not in item._vals_) + items = {item for item in items if reverse not in item._vals_} else: items = set(items) items -= setdata @@ -2497,15 +2919,14 @@ def load(attr, obj, items=None): items.append(obj) arguments = adapter(items) cursor = database._exec_sql(sql, arguments) - loaded_items = set(imap(rentity._get_by_raw_pkval_, cursor.fetchall())) + loaded_items = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} setdata |= loaded_items reverse.db_reverse_add(loaded_items, obj) return setdata counter = cache.collection_statistics.setdefault(attr, 0) nplus1_threshold = attr.nplus1_threshold - prefetching = options.PREFETCHING and not attr.lazy and nplus1_threshold is not None \ - and (counter >= nplus1_threshold or cache.noflush_counter) + prefetching = not attr.lazy and nplus1_threshold is not None and counter >= nplus1_threshold objects = [ obj ] setdata_list = [ setdata ] @@ -2540,16 +2961,16 @@ def load(attr, obj, items=None): items = d.get(obj2) if items is None: items = d[obj2] = set() items.add(item) - else: d[obj] = set(imap(rentity._get_by_raw_pkval_, cursor.fetchall())) + else: d[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} for obj2, items in iteritems(d): setdata2 = obj2._vals_.get(attr) - if setdata2 is None: setdata2 = obj._vals_[attr] = SetData() + if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData() else: phantoms = setdata2 - items if setdata2.added: phantoms -= setdata2.added if phantoms: throw(UnrepeatableReadError, 'Phantom object %s disappeared from collection %s.%s' - % (safe_repr(phantoms.pop()), safe_repr(obj), attr.name)) + % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name)) items -= setdata2 if setdata2.removed: items -= setdata2.removed setdata2 |= items @@ -2599,7 +3020,7 @@ def construct_sql_m2m(attr, batch_size=1, items_count=0): return sql, adapter def copy(attr, obj): if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) reverse = attr.reverse @@ -2623,8 +3044,7 @@ def __set__(attr, obj, new_items, 
undo_funcs=None): if isinstance(new_items, SetInstance) and new_items._obj_ is obj and new_items._attr_ is attr: return # after += or -= cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): new_items = attr.validate(new_items, obj) @@ -2794,7 +3214,7 @@ def remove_m2m(attr, removed): columns = reverse.columns + attr.columns converters = reverse.converters + attr.converters for i, (column, converter) in enumerate(izip(columns, converters)): - where_list.append([ 'EQ', ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ]) + where_list.append([ converter.EQ, ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ]) from_ast = [ 'FROM', [ None, 'TABLE', attr.table ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) @@ -2842,6 +3262,35 @@ def unpickle_setwrapper(obj, attrname, items): setdata.count = len(setdata) return wrapper + +class SetIterator(object): + def __init__(self, wrapper): + self._wrapper = wrapper + self._query = None + self._iter = None + + def __iter__(self): + return self + + def next(self): + if self._iter is None: + self._iter = iter(self._wrapper.copy()) + return next(self._iter) + + __next__ = next + + def _get_query(self): + if self._query is None: + self._query = self._wrapper.select() + return self._query + + def _get_type_(self): + return QueryType(self._get_query()) + + def _normalize_var(self, query_type): + return query_type, self._get_query() + + class SetInstance(object): __slots__ = '_obj_', '_attr_', '_attrnames_' _parent_ = None @@ -2859,7 +3308,8 @@ def __repr__(wrapper): return '<%s %r.%s>' % (wrapper.__class__.__name__, wrapper._obj_, wrapper._attr_.name) @cut_traceback def __str__(wrapper): - if not wrapper._obj_._session_cache_.is_alive: content = '...' + cache = wrapper._obj_._session_cache_ + if cache is None or not cache.is_alive: content = '...' 
else: content = ', '.join(imap(str, wrapper)) return '%s([%s])' % (wrapper.__class__.__name__, content) @cut_traceback @@ -2867,7 +3317,7 @@ def __nonzero__(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = attr.load(obj) if setdata: return True @@ -2878,7 +3328,7 @@ def is_empty(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.is_fully_loaded: return not setdata @@ -2892,7 +3342,7 @@ def is_empty(wrapper): if cached_sql is None: where_list = [ 'WHERE' ] for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)): - where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) + where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = rentity._table_ select_list, attr_offsets = rentity._construct_select_clause_() @@ -2901,7 +3351,7 @@ def is_empty(wrapper): select_list = [ 'ALL' ] + [ [ 'COLUMN', None, column ] for column in attr.columns ] attr_offsets = None sql_ast = [ 'SELECT', select_list, [ 'FROM', [ None, 'TABLE', table_name ] ], - where_list, [ 'LIMIT', [ 'VALUE', 1 ] ] ] + where_list, [ 'LIMIT', 1 ] ] sql, adapter = database._ast2sql(sql_ast) attr.cached_empty_sql = sql, adapter, attr_offsets else: sql, adapter, attr_offsets = cached_sql @@ -2924,7 +3374,7 @@ def __len__(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) return len(setdata) @@ -2934,10 +3384,11 @@ def count(wrapper): obj = wrapper._obj_ cache = obj._session_cache_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.count is not None: return setdata.count + if cache is None or not cache.is_alive: throw_db_session_is_over('read value of', obj, attr) entity = attr.entity reverse = attr.reverse database = entity._database_ @@ -2945,10 +3396,10 @@ def count(wrapper): if cached_sql is None: where_list = [ 'WHERE' ] for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)): - where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) + where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = reverse.entity._table_ else: table_name = attr.table - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ None, 'TABLE', table_name ] ], 
where_list ] sql, adapter = database._ast2sql(sql_ast) attr.cached_count_sql = sql, adapter @@ -2962,7 +3413,7 @@ def count(wrapper): return setdata.count @cut_traceback def __iter__(wrapper): - return iter(wrapper.copy()) + return SetIterator(wrapper) @cut_traceback def __eq__(wrapper, other): if isinstance(other, SetInstance): @@ -2985,8 +3436,10 @@ def __contains__(wrapper, item): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) if not isinstance(item, attr.py_type): return False + if item._session_cache_ is not obj._session_cache_: + throw(TransactionError, 'An attempt to mix objects belonging to different transactions') reverse = attr.reverse if not reverse.is_collection: @@ -3021,15 +3474,13 @@ def create(wrapper, **kwargs): kwargs[reverse.name] = wrapper._obj_ item_type = attr.py_type item = item_type(**kwargs) - wrapper.add(item) return item @cut_traceback def add(wrapper, new_items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): reverse = attr.reverse @@ -3068,8 +3519,7 @@ def remove(wrapper, items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): reverse = attr.reverse @@ -3111,8 +3561,8 @@ def __isub__(wrapper, items): def clear(wrapper): obj = wrapper._obj_ attr = wrapper._attr_ - if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + cache = obj._session_cache_ + if cache is None or not obj._session_cache_.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) attr.__set__(obj, ()) @cut_traceback @@ -3128,16 +3578,18 @@ def select(wrapper, *args): s = 'lambda item: JOIN(obj in item.%s)' if reverse.is_collection else 'lambda item: item.%s == obj' query = query.filter(s % reverse.name, {'obj' : obj, 'JOIN': JOIN}) if args: - func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+1) query = query.filter(func, globals, locals) return query filter = select - def limit(wrapper, limit, offset=None): + def limit(wrapper, limit=None, offset=None): return wrapper.select().limit(limit, offset) def page(wrapper, pagenum, pagesize=10): return wrapper.select().page(pagenum, pagesize) def order_by(wrapper, *args): return wrapper.select().order_by(*args) + def sort_by(wrapper, *args): + return wrapper.select().sort_by(*args) def random(wrapper, limit): return wrapper.select().random(limit) @@ -3168,7 +3620,8 @@ def distinct(multiset): return multiset._items_.copy() @cut_traceback def 
__repr__(multiset): - if multiset._obj_._session_cache_.is_alive: + cache = multiset._obj_._session_cache_ + if cache is not None and cache.is_alive: size = builtins.sum(itervalues(multiset._items_)) if size == 1: size_str = ' (1 item)' else: size_str = ' (%d items)' % size @@ -3305,11 +3758,26 @@ def __init__(entity, name, bases, cls_dict): new_attrs.append(attr) new_attrs.sort(key=attrgetter('id')) + interleave_attrs = [] + for attr in new_attrs: + if attr.interleave is not None: + if attr.interleave: + interleave_attrs.append(attr) + entity._interleave_ = None + if interleave_attrs: + if len(interleave_attrs) > 1: throw(TypeError, + 'only one attribute may be marked as interleave. Got: %s' + % ', '.join(repr(attr) for attr in interleave_attrs)) + interleave = interleave_attrs[0] + if not interleave.is_relation: throw(TypeError, + 'Interleave attribute should be part of relationship. Got: %r' % attr) + entity._interleave_ = interleave + indexes = entity._indexes_ = entity.__dict__.get('_indexes_', []) for attr in new_attrs: if attr.is_unique: indexes.append(Index(attr, is_pk=isinstance(attr, PrimaryKey))) for index in indexes: index._init_(entity) - primary_keys = set(index.attrs for index in indexes if index.is_pk) + primary_keys = {index.attrs for index in indexes if index.is_pk} if direct_bases: if primary_keys: throw(ERDiagramError, 'Primary key cannot be redefined in derived classes') base_indexes = [] @@ -3317,7 +3785,7 @@ def __init__(entity, name, bases, cls_dict): for index in base._indexes_: if index not in base_indexes and index not in indexes: base_indexes.append(index) indexes[:0] = base_indexes - primary_keys = set(index.attrs for index in indexes if index.is_pk) + primary_keys = {index.attrs for index in indexes if index.is_pk} if len(primary_keys) > 1: throw(ERDiagramError, 'Only one primary key can be defined in each entity class') elif not primary_keys: @@ -3346,7 +3814,7 @@ def __init__(entity, name, bases, cls_dict): entity._new_attrs_ = new_attrs entity._attrs_ = base_attrs + new_attrs - entity._adict_ = dict((attr.name, attr) for attr in entity._attrs_) + entity._adict_ = {attr.name: attr for attr in entity._attrs_} entity._subclass_attrs_ = [] entity._subclass_adict_ = {} for base in entity._all_bases_: @@ -3438,7 +3906,7 @@ def _link_reverse_attrs_(entity): database = entity._database_ for attr in entity._new_attrs_: py_type = attr.py_type - if not issubclass(py_type, Entity): continue + if not isinstance(py_type, EntityMeta): continue entity2 = py_type if entity2._database_ is not database: @@ -3484,6 +3952,12 @@ def _link_reverse_attrs_(entity): attr2.reverse = attr attr.linked() attr2.linked() + def _check_table_options_(entity): + if entity._root_ is not entity: + if '_table_options_' in entity.__dict__: throw(TypeError, + 'Cannot redefine %s options in %s entity' % (entity._root_.__name__, entity.__name__)) + elif not hasattr(entity, '_table_options_'): + entity._table_options_ = {} def _get_pk_columns_(entity): if entity._pk_columns_ is not None: return entity._pk_columns_ pk_columns = [] @@ -3503,64 +3977,51 @@ def _get_pk_columns_(entity): return pk_columns def __iter__(entity): return EntityIter(entity) - def _normalize_args_(entity, kwargs, setdefault=False): - avdict = {} - if setdefault: - for name in kwargs: - if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name) - for attr in entity._attrs_: - val = kwargs.get(attr.name, DEFAULT) - avdict[attr] = attr.validate(val, None, entity, from_db=False) - else: - get_attr = 
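The SetInstance hunks above make __iter__ return a lazy SetIterator (so a collection wrapper can be normalized into a query when it appears in a generator expression), allow limit() to be called with only an offset, and expose sort_by() on collection wrappers next to order_by(). A minimal sketch, not part of the patch, reusing the hypothetical Customer/Order entities from the earlier sketch and assuming the database has been bound and mapped:

    from pony.orm import db_session

    with db_session:
        customer = Customer[1]
        # sort_by() on the wrapper delegates to .select().sort_by(...)
        recent = customer.orders.sort_by(Order.id).limit(10)
        for order in recent:
            print(order.note)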
entity._adict_.get - for name, val in iteritems(kwargs): - attr = get_attr(name) - if attr is None: throw(TypeError, 'Unknown attribute %r' % name) - avdict[attr] = attr.validate(val, None, entity, from_db=False) - if entity._pk_is_composite_: - get_val = avdict.get - pkval = tuple(get_val(attr) for attr in entity._pk_attrs_) - if None in pkval: pkval = None - else: pkval = avdict.get(entity._pk_attrs_[0]) - return pkval, avdict @cut_traceback def __getitem__(entity, key): if type(key) is not tuple: key = (key,) - if len(key) != len(entity._pk_attrs_): - throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)' - % (entity.__name__, len(key), len(entity._pk_attrs_))) - kwargs = dict(izip(imap(attrgetter('name'), entity._pk_attrs_), key)) - return entity._find_one_(kwargs) + if len(key) == len(entity._pk_attrs_): + kwargs = {attr.name: value for attr, value in izip(entity._pk_attrs_, key)} + return entity._find_one_(kwargs) + if len(key) == len(entity._pk_columns_): + return entity._get_by_raw_pkval_(key, from_db=False, seed=False) + + throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)' + % (entity.__name__, len(key), len(entity._pk_attrs_))) @cut_traceback def exists(entity, *args, **kwargs): - if args: return entity._query_from_args_(args, kwargs, frame_depth=3).exists() + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).exists() try: obj = entity._find_one_(kwargs) except ObjectNotFound: return False except MultipleObjectsFoundError: return True return True @cut_traceback def get(entity, *args, **kwargs): - if args: return entity._query_from_args_(args, kwargs, frame_depth=3).get() + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).get() try: return entity._find_one_(kwargs) # can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_for_update(entity, *args, **kwargs): nowait = kwargs.pop('nowait', False) - if args: return entity._query_from_args_(args, kwargs, frame_depth=3).for_update(nowait).get() - try: return entity._find_one_(kwargs, True, nowait) # can throw MultipleObjectsFoundError + skip_locked = kwargs.pop('skip_locked', False) + if nowait and skip_locked: + throw(TypeError, 'nowait and skip_locked options are mutually exclusive') + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1) \ + .for_update(nowait, skip_locked).get() + try: return entity._find_one_(kwargs, True, nowait, skip_locked) # can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_by_sql(entity, sql, globals=None, locals=None): - objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=3) # can throw MultipleObjectsFoundError + objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=cut_traceback_depth+1) # can throw MultipleObjectsFoundError if not objects: return None assert len(objects) == 1 return objects[0] @cut_traceback def select(entity, *args): - return entity._query_from_args_(args, kwargs=None, frame_depth=3) + return entity._query_from_args_(args, kwargs=None, frame_depth=cut_traceback_depth+1) @cut_traceback def select_by_sql(entity, sql, globals=None, locals=None): - return entity._find_by_sql_(None, sql, globals, locals, frame_depth=3) + return entity._find_by_sql_(None, sql, globals, locals, frame_depth=cut_traceback_depth+1) @cut_traceback def select_random(entity, limit): if entity._pk_is_composite_: return 
entity.select().random(limit) @@ -3574,7 +4035,7 @@ def select_random(entity, limit): if max_id is None: max_id_sql = entity._cached_max_id_sql_ if max_id_sql is None: - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', [ 'COLUMN', None, pk.column ] ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', None, [ 'COLUMN', None, pk.column ] ] ], [ 'FROM', [ None, 'TABLE', entity._table_ ] ] ] max_id_sql, adapter = database._ast2sql(sql_ast) entity._cached_max_id_sql_ = max_id_sql @@ -3623,15 +4084,24 @@ def select_random(entity, limit): if obj in seeds: obj._load_() if found_in_cache: shuffle(result) return result - def _find_one_(entity, kwargs, for_update=False, nowait=False): + def _find_one_(entity, kwargs, for_update=False, nowait=False, skip_locked=False): if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) - pkval, avdict = entity._normalize_args_(kwargs, False) + avdict = {} + get_attr = entity._adict_.get + for name, val in iteritems(kwargs): + attr = get_attr(name) + if attr is None: throw(TypeError, 'Unknown attribute %r' % name) + avdict[attr] = attr.validate(val, None, entity, from_db=False) + if entity._pk_is_composite_: + pkval = tuple(imap(avdict.get, entity._pk_attrs_)) + if None in pkval: pkval = None + else: pkval = avdict.get(entity._pk_attrs_[0]) for attr in avdict: if attr.is_collection: throw(TypeError, 'Collection attribute %s cannot be specified as search criteria' % attr) obj, unique = entity._find_in_cache_(pkval, avdict, for_update) - if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait) + if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait, skip_locked) if obj is None: throw(ObjectNotFound, entity, pkval) return obj def _find_in_cache_(entity, pkval, avdict, for_update=False): @@ -3654,9 +4124,9 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False): get_val = avdict.get vals = tuple(get_val(attr) for attr in attrs) if None in vals: continue + unique = True cache_index = cache_indexes.get(attrs) if cache_index is None: continue - unique = True obj = cache_index.get(vals) if obj is not None: break if obj is None: @@ -3683,11 +4153,11 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False): entity._set_rbits((obj,), avdict) return obj, unique return None, unique - def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False): + def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False, skip_locked=False): database = entity._database_ - query_attrs = dict((attr, value is None) for attr, value in iteritems(avdict)) + query_attrs = {attr: value is None for attr, value in iteritems(avdict)} limit = 2 if not unique else None - sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait) + sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait, skip_locked) arguments = adapter(avdict) if for_update: database._get_cache().immediate = True cursor = database._exec_sql(sql, arguments) @@ -3719,11 +4189,12 @@ def _find_by_sql_(entity, max_fetch_count, sql, globals, locals, frame_depth): objects = entity._fetch_objects(cursor, attr_offsets, max_fetch_count) return objects - def _construct_select_clause_(entity, alias=None, distinct=False, - query_attrs=(), attrs_to_prefetch=(), all_attributes=False): + def _construct_select_clause_(entity, alias=None, distinct=False, query_attrs=(), all_attributes=False): attr_offsets = {} select_list 
= [ 'DISTINCT' ] if distinct else [ 'ALL' ] root = entity._root_ + pc = local.prefetch_context + attrs_to_prefetch = pc.attrs_to_prefetch_dict.get(entity, ()) if pc else () for attr in chain(root._attrs_, root._subclass_attrs_): if not all_attributes and not issubclass(attr.entity, entity) \ and not issubclass(entity, attr.entity): continue @@ -3738,12 +4209,13 @@ def _construct_select_clause_(entity, alias=None, distinct=False, def _construct_discriminator_criteria_(entity, alias=None): discr_attr = entity._discriminator_attr_ if discr_attr is None: return None - code2cls = discr_attr.code2cls discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in entity._subclasses_ ] discr_values.append([ 'VALUE', entity._discriminator_]) return [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True): - query_key = batch_size, attr, from_seeds + pc = local.prefetch_context + attrs_to_prefetch = pc.get_frozen_attrs_to_prefetch(entity) if pc is not None else () + query_key = batch_size, attr, from_seeds, attrs_to_prefetch cached_sql = entity._batchload_sql_cache_.get(query_key) if cached_sql is not None: return cached_sql select_list, attr_offsets = entity._construct_select_clause_(all_attributes=True) @@ -3763,12 +4235,10 @@ def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True): cached_sql = sql, adapter, attr_offsets entity._batchload_sql_cache_[query_key] = cached_sql return cached_sql - def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False): - if limit and entity._database_.provider.dialect == 'MSSQL': - order_by_pk = True # todo: use TOP 1 instead of FETCH NEXT and remove this line - if nowait: assert for_update + def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False, skip_locked=False): + if nowait or skip_locked: assert for_update sorted_query_attrs = tuple(sorted(query_attrs.items())) - query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait + query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait, skip_locked cached_sql = entity._find_sql_cache_.get(query_key) if cached_sql is not None: return cached_sql select_list, attr_offsets = entity._construct_select_clause_(query_attrs=query_attrs) @@ -3785,7 +4255,8 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda if attr_is_none: where_list.append([ 'IS_NULL', [ 'COLUMN', None, attr.column ] ]) else: if len(attr.converters) > 1: throw(NotImplementedError) - where_list.append([ 'EQ', [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), attr.converters[0] ] ]) + converter = attr.converters[0] + where_list.append([ converter.EQ, [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), converter ] ]) elif not attr.columns: throw(NotImplementedError) else: attr_entity = attr.py_type; assert attr_entity == attr.reverse.entity @@ -3794,12 +4265,12 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda where_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ]) else: for j, (column, converter) in enumerate(izip(attr.columns, attr_entity._pk_converters_)): - where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ]) + where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ]) if not for_update: sql_ast = [ 'SELECT', select_list, from_list, where_list ] - else: 
sql_ast = [ 'SELECT_FOR_UPDATE', bool(nowait), select_list, from_list, where_list ] + else: sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked, select_list, from_list, where_list ] if order_by_pk: sql_ast.append([ 'ORDER_BY' ] + [ [ 'COLUMN', None, column ] for column in entity._pk_columns_ ]) - if limit is not None: sql_ast.append([ 'LIMIT', [ 'VALUE', limit ] ]) + if limit is not None: sql_ast.append([ 'LIMIT', limit ]) database = entity._database_ sql, adapter = database._ast2sql(sql_ast) cached_sql = sql, adapter, attr_offsets @@ -3836,31 +4307,43 @@ def _set_rbits(entity, objects, attrs): if wbits is None: continue rbits = get_rbits(obj.__class__) if rbits is None: - rbits = sum(obj._bits_.get(attr, 0) for attr in attrs) + rbits = sum(obj._bits_except_volatile_.get(attr, 0) for attr in attrs) rbits_dict[obj.__class__] = rbits obj._rbits_ |= rbits & ~wbits def _parse_row_(entity, row, attr_offsets): discr_attr = entity._discriminator_attr_ - if not discr_attr: real_entity_subclass = entity + if not discr_attr: + discr_value = None + real_entity_subclass = entity else: discr_offset = attr_offsets[discr_attr][0] discr_value = discr_attr.validate(row[discr_offset], None, entity, from_db=True) real_entity_subclass = discr_attr.code2cls[discr_value] + discr_value = real_entity_subclass._discriminator_ # To convert unicode to str in Python 2.x + + database = entity._database_ + cache = local.db2cache[database] avdict = {} for attr in real_entity_subclass._attrs_: offsets = attr_offsets.get(attr) - if offsets is None or attr.is_discriminator: continue - avdict[attr] = attr.parse_value(row, offsets) - if not entity._pk_is_composite_: pkval = avdict.pop(entity._pk_attrs_[0], None) - else: pkval = tuple(avdict.pop(attr, None) for attr in entity._pk_attrs_) + if offsets is None: + continue + if attr.is_discriminator: + avdict[attr] = discr_value + else: + avdict[attr] = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) + + pkval = tuple(avdict.pop(attr) for attr in entity._pk_attrs_) + assert None not in pkval + if not entity._pk_is_composite_: pkval = pkval[0] return real_entity_subclass, pkval, avdict def _load_many_(entity, objects): database = entity._database_ cache = database._get_cache() seeds = cache.seeds[entity._pk_attrs_] if not seeds: return - objects = set(obj for obj in objects if obj in seeds) + objects = {obj for obj in objects if obj in seeds} objects = sorted(objects, key=attrgetter('_pkval_')) max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) while objects: @@ -3903,7 +4386,6 @@ def _query_from_args_(entity, args, kwargs, frame_depth): for_expr = ast.GenExprFor(ast.AssName(name, 'OP_ASSIGN'), ast.Name('.0'), [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(name), [ for_expr ]) locals = locals.copy() if locals is not None else {} - assert '.0' not in locals locals['.0'] = entity return Query(code_key, inner_expr, globals, locals, cells) def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs=None, obj_to_init=None): @@ -3928,7 +4410,9 @@ def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs= if obj is None: with cache.flush_disabled(): - obj = obj_to_init or object.__new__(entity) + obj = obj_to_init + if obj_to_init is None: + obj = object.__new__(entity) cache.objects.add(obj) obj._pkval_ = pkval obj._status_ = status @@ -3961,7 +4445,7 @@ def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs= assert cache.in_transaction cache.for_update.add(obj) return 
obj - def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True): + def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True, seed=True): i = 0 pkval = [] for attr in entity._pk_attrs_: @@ -3969,16 +4453,19 @@ def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True): val = raw_pkval[i] i += 1 if not attr.reverse: val = attr.validate(val, None, entity, from_db=from_db) - else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db) + else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db, seed=seed) else: if not attr.reverse: throw(NotImplementedError) vals = raw_pkval[i:i+len(attr.columns)] - val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db) + val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db, seed=seed) i += len(attr.columns) pkval.append(val) if not entity._pk_is_composite_: pkval = pkval[0] else: pkval = tuple(pkval) - obj = entity._get_from_identity_map_(pkval, 'loaded', for_update) + if seed: + obj = entity._get_from_identity_map_(pkval, 'loaded', for_update) + else: + obj = entity[pkval] assert obj._status_ != 'cancelled' return obj def _get_propagation_mixin_(entity): @@ -3990,6 +4477,8 @@ def _get_propagation_mixin_(entity): def fget(wrapper, attr=attr): attrnames = wrapper._attrnames_ + (attr.name,) items = [ x for x in (attr.__get__(item) for item in wrapper) if x is not None ] + if attr.py_type is Json: + return [ item.get_untracked() if isinstance(item, TrackedValue) else item for item in items ] return Multiset(wrapper._obj_, attrnames, items) elif not attr.is_collection: def fget(wrapper, attr=attr): @@ -4079,21 +4568,21 @@ def _get_attrs_(entity, only=None, exclude=None, with_collections=False, with_la entity._attrnames_cache_[key] = attrs return attrs -def populate_criteria_list(criteria_list, columns, converters, params_count=0, table_alias=None): - assert len(columns) == len(converters) - for column, converter in izip(columns, converters): - if converter is not None: - criteria_list.append([ 'EQ', [ 'COLUMN', table_alias, column ], - [ 'PARAM', (params_count, None, None), converter ] ]) +def populate_criteria_list(criteria_list, columns, converters, operations, + params_count=0, table_alias=None, optimistic=False): + for column, op, converter in izip(columns, operations, converters): + if op == 'IS_NULL': + criteria_list.append([ op, [ 'COLUMN', None, column ] ]) else: - criteria_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ]) + criteria_list.append([ op, [ 'COLUMN', table_alias, column ], + [ 'PARAM', (params_count, None, None), converter, optimistic ] ]) params_count += 1 return params_count -statuses = set(['created', 'cancelled', 'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted']) -del_statuses = set(['marked_to_delete', 'deleted', 'cancelled']) -created_or_deleted_statuses = set(['created']) | del_statuses -saved_statuses = set(['inserted', 'updated', 'deleted']) +statuses = {'created', 'cancelled', 'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted'} +del_statuses = {'marked_to_delete', 'deleted', 'cancelled'} +created_or_deleted_statuses = {'created'} | del_statuses +saved_statuses = {'inserted', 'updated', 'deleted'} def throw_object_was_deleted(obj): assert obj._status_ in del_statuses @@ -4119,6 +4608,64 @@ def unpickle_entity(d): def safe_repr(obj): return Entity.__repr__(obj) +def make_proxy(obj): + proxy = EntityProxy(obj) + return proxy + +class EntityProxy(object): + def __init__(self, obj): + entity = obj.__class__ 
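(Aside for reviewers: the `make_proxy()` / `EntityProxy` pair introduced in this hunk is intended for holding a reference to an entity instance across `db_session` boundaries — the proxy stores only the entity class and primary key, flushing first if the key is not yet assigned, and re-resolves the real object from the identity map or the database on each access. A minimal usage sketch under assumed setup — the `Person` entity and SQLite database below are illustrative and not part of this diff:)

```python
from pony.orm import Database, Required, db_session
from pony.orm.core import make_proxy  # helper added in this diff

db = Database('sqlite', ':memory:')   # hypothetical database, for illustration only

class Person(db.Entity):
    name = Required(str)

db.generate_mapping(create_tables=True)

with db_session:
    p = Person(name='John')
    proxy = make_proxy(p)      # flushes if needed so the primary key is known

with db_session:               # a different session: the proxy re-resolves the object
    print(proxy.name)          # -> 'John'
    proxy.name = 'Jane'        # attribute writes go through the freshly loaded object
```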
+ object.__setattr__(self, '_entity_', entity) + pkval = obj.get_pk() + if pkval is None: + cache = obj._session_cache_ + if obj._status_ in del_statuses or cache is None or not cache.is_alive: + throw(ValueError, 'Cannot make a proxy for %s object: primary key is not specified' % entity.__name__) + flush() + pkval = obj.get_pk() + assert pkval is not None + object.__setattr__(self, '_obj_pk_', pkval) + + def __repr__(self): + entity = self._entity_ + pkval = self._obj_pk_ + pkrepr = ','.join(repr(item) for item in pkval) if isinstance(pkval, tuple) else repr(pkval) + return '' % (entity.__name__, pkrepr) + + def _get_object(self): + entity = self._entity_ + pkval = self._obj_pk_ + cache = entity._database_._get_cache() + attrs = entity._pk_attrs_ + if attrs in cache.indexes and pkval in cache.indexes[attrs]: + obj = cache.indexes[attrs][pkval] + else: + obj = entity[pkval] + return obj + + def __getattr__(self, name): + obj = self._get_object() + return getattr(obj, name) + + def __setattr__(self, name, value): + obj = self._get_object() + setattr(obj, name, value) + + def __eq__(self, other): + entity = self._entity_ + pkval = self._obj_pk_ + if isinstance(other, EntityProxy): + entity2 = other._entity_ + pkval2 = other._obj_pk_ + return entity == entity2 and pkval == pkval2 + elif isinstance(other, entity): + return pkval == other._pkval_ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + class Entity(with_metaclass(EntityMeta)): __slots__ = '_session_cache_', '_status_', '_pkval_', '_newid_', '_dbvals_', '_vals_', '_rbits_', '_wbits_', '_save_pos_', '__weakref__' def __reduce__(obj): @@ -4128,19 +4675,29 @@ def __reduce__(obj): OrmError, '%s object %s has to be stored in DB before it can be pickled' % (obj._status_.capitalize(), safe_repr(obj))) d = {'__class__' : obj.__class__} - adict = obj._adict_ for attr, val in iteritems(obj._vals_): if not attr.is_collection: d[attr.name] = val return unpickle_entity, (d,) @cut_traceback def __init__(obj, *args, **kwargs): + obj._status_ = None entity = obj.__class__ if args: raise TypeError('%s constructor accept only keyword arguments. 
Got: %d positional argument%s' % (entity.__name__, len(args), len(args) > 1 and 's' or '')) if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) - pkval, avdict = entity._normalize_args_(kwargs, True) + avdict = {} + for name in kwargs: + if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name) + for attr in entity._attrs_: + val = kwargs.get(attr.name, DEFAULT) + avdict[attr] = attr.validate(val, obj, from_db=False) + if entity._pk_is_composite_: + pkval = tuple(imap(avdict.get, entity._pk_attrs_)) + if None in pkval: pkval = None + else: pkval = avdict.get(entity._pk_attrs_[0]) + undo_funcs = [] cache = entity._database_._get_cache() cache_indexes = cache.indexes @@ -4177,6 +4734,7 @@ def __init__(obj, *args, **kwargs): obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True + @cut_traceback def get_pk(obj): pkval = obj._get_raw_pkval_() if len(pkval) == 1: return pkval[0] @@ -4224,10 +4782,23 @@ def __repr__(obj): if obj._pk_is_composite_: pkval = ','.join(imap(repr, pkval)) else: pkval = repr(pkval) return '%s[%s]' % (obj.__class__.__name__, pkval) + @classmethod + def _prefetch_load_all_(entity, objects): + objects = sorted(objects, key=entity._get_raw_pkval_) + database = entity._database_ + cache = database._get_cache() + if cache is None or not cache.is_alive: + throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over') + max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) + for i in xrange(0, len(objects), max_batch_size): + batch = objects[i:i+max_batch_size] + sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(batch)) + arguments = adapter(batch) + cursor = database._exec_sql(sql, arguments) + entity._fetch_objects(cursor, attr_offsets) def _load_(obj): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) entity = obj.__class__ database = entity._database_ if cache is not database._get_cache(): @@ -4235,10 +4806,9 @@ def _load_(obj): seeds = cache.seeds[entity._pk_attrs_] max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) objects = [ obj ] - if options.PREFETCHING: - for seed in seeds: - if len(objects) >= max_batch_size: break - if seed is not obj: objects.append(seed) + for seed in seeds: + if len(objects) >= max_batch_size: break + if seed is not obj: objects.append(seed) sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(objects)) arguments = adapter(objects) cursor = database._exec_sql(sql, arguments) @@ -4248,8 +4818,7 @@ def _load_(obj): @cut_traceback def load(obj, *attrs): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) entity = obj.__class__ database = entity._database_ if cache is not database._get_cache(): @@ -4292,7 +4861,7 @@ def load(obj, *attrs): offsets.append(len(select_list) - 1) select_list.append([ 'COLUMN', None, column ]) from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ]] - criteria_list = [ [ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] + criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column 
], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(obj._pk_columns_, obj._pk_converters_)) ] where_list = [ 'WHERE' ] + criteria_list @@ -4307,10 +4876,27 @@ def load(obj, *attrs): objects = entity._fetch_objects(cursor, attr_offsets) if obj not in objects: throw(UnrepeatableReadError, 'Phantom object %s disappeared' % safe_repr(obj)) + def _attr_changed_(obj, attr): + cache = obj._session_cache_ + if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr) + if obj._status_ in del_statuses: throw_object_was_deleted(obj) + status = obj._status_ + wbits = obj._wbits_ + bit = obj._bits_[attr] + objects_to_save = cache.objects_to_save + if wbits is not None and bit: + obj._wbits_ |= bit + if status != 'modified': + assert status in ('loaded', 'inserted', 'updated') + assert obj._save_pos_ is None + obj._status_ = 'modified' + obj._save_pos_ = len(objects_to_save) + objects_to_save.append(obj) + cache.modified = True def _db_set_(obj, avdict, unpickling=False): assert obj._status_ not in created_or_deleted_statuses cache = obj._session_cache_ - assert cache.is_alive + assert cache is not None and cache.is_alive cache.seeds[obj._pk_attrs_].discard(obj) if not avdict: return @@ -4322,50 +4908,62 @@ def _db_set_(obj, avdict, unpickling=False): assert attr.pk_offset is None assert new_dbval is not NOT_LOADED old_dbval = get_dbval(attr, NOT_LOADED) - if unpickling and old_dbval is not NOT_LOADED: - del avdict[attr] - continue - elif attr.py_type is float: - if old_dbval is NOT_LOADED: pass - elif attr.converters[0].equals(old_dbval, new_dbval): + if old_dbval is not NOT_LOADED: + if unpickling or old_dbval == new_dbval or ( + not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): del avdict[attr] continue - elif old_dbval == new_dbval: - del avdict[attr] - continue - bit = obj._bits_[attr] - if rbits & bit: throw(UnrepeatableReadError, - 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' - % (obj.__class__.__name__, attr.name, obj, old_dbval, new_dbval)) + if unpickling: + new_vals = avdict + new_dbvals = {attr: attr.converters[0].val2dbval(val, obj) if not attr.reverse else val + for attr, val in iteritems(avdict)} + else: + new_dbvals = avdict + new_vals = {attr: attr.converters[0].dbval2val(dbval, obj) if not attr.reverse else dbval + for attr, dbval in iteritems(avdict)} + + for attr, new_val in items_list(new_vals): + new_dbval = new_dbvals[attr] + old_dbval = get_dbval(attr, NOT_LOADED) + bit = obj._bits_except_volatile_[attr] + if rbits & bit: + errormsg = 'Please contact PonyORM developers so they can ' \ + 'reproduce your error and fix a bug: support@ponyorm.org' + assert old_dbval is not NOT_LOADED, errormsg + throw(UnrepeatableReadError, + 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' + % (obj.__class__.__name__, attr.name, obj, old_dbval, new_dbval)) if attr.reverse: attr.db_update_reverse(obj, old_dbval, new_dbval) obj._dbvals_[attr] = new_dbval - if wbits & bit: del avdict[attr] + if wbits & bit: + del new_vals[attr] + + for attr, new_val in iteritems(new_vals): if attr.is_unique: old_val = get_val(attr) - if old_val != new_dbval: - cache.db_update_simple_index(obj, attr, old_val, new_dbval) + if old_val != new_val: + cache.db_update_simple_index(obj, attr, old_val, new_val) for attrs in obj._composite_keys_: - for attr in attrs: - if attr in avdict: break - else: continue - vals = [ get_val(a) for a in attrs ] # 
In Python 2 var name leaks into the function scope! - prev_vals = tuple(vals) - for i, attr in enumerate(attrs): - if attr in avdict: vals[i] = avdict[attr] - new_vals = tuple(vals) - cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) - - for attr, new_dbval in iteritems(avdict): - obj._vals_[attr] = new_dbval + if any(attr in new_vals for attr in attrs): + key_vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! + prev_key_vals = tuple(key_vals) + for i, attr in enumerate(attrs): + if attr in new_vals: key_vals[i] = new_vals[attr] + new_key_vals = tuple(key_vals) + if prev_key_vals != new_key_vals: + cache.db_update_composite_index(obj, attrs, prev_key_vals, new_key_vals) + + obj._vals_.update(new_vals) def _delete_(obj, undo_funcs=None): status = obj._status_ if status in del_statuses: return is_recursive_call = undo_funcs is not None if not is_recursive_call: undo_funcs = [] cache = obj._session_cache_ + assert cache is not None and cache.is_alive with cache.flush_disabled(): get_val = obj._vals_.get undo_list = [] @@ -4387,9 +4985,22 @@ def undo_func(): undo_funcs.append(undo_func) try: for attr in obj._attrs_: - reverse = attr.reverse - if not reverse: continue + if not attr.is_collection: continue + if isinstance(attr, Set): + set_wrapper = attr.__get__(obj) + if not set_wrapper.__nonzero__(): pass + elif attr.cascade_delete: + for robj in set_wrapper: robj._delete_(undo_funcs) + elif not attr.reverse.is_required: attr.__set__(obj, (), undo_funcs) + else: throw(ConstraintError, "Cannot delete object %s, because it has non-empty set of %s, " + "and 'cascade_delete' option of %s is not set" + % (obj, attr.name, attr)) + else: throw(NotImplementedError) + + for attr in obj._attrs_: if not attr.is_collection: + reverse = attr.reverse + if not reverse: continue if not reverse.is_collection: val = get_val(attr) if attr in obj._vals_ else attr.load(obj) if val is None: continue @@ -4404,16 +5015,6 @@ def undo_func(): if val is None: continue reverse.reverse_remove((val,), obj, undo_funcs) else: throw(NotImplementedError) - elif isinstance(attr, Set): - set_wrapper = attr.__get__(obj) - if not set_wrapper.__nonzero__(): pass - elif attr.cascade_delete: - for robj in set_wrapper: robj._delete_(undo_funcs) - elif not reverse.is_required: attr.__set__(obj, (), undo_funcs) - else: throw(ConstraintError, "Cannot delete object %s, because it has non-empty set of %s, " - "and 'cascade_delete' option of %s is not set" - % (obj, attr.name, attr)) - else: throw(NotImplementedError) cache_indexes = cache.indexes for attr in obj._simple_keys_: @@ -4459,14 +5060,13 @@ def undo_func(): raise @cut_traceback def delete(obj): - if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, - 'Cannot delete object %s: the database session is over' % safe_repr(obj)) + cache = obj._session_cache_ + if cache is None or not cache.is_alive: throw_db_session_is_over('delete object', obj) obj._delete_() @cut_traceback def set(obj, **kwargs): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change object', obj) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): avdict, collection_avdict = obj._keyargs_to_avdicts_(kwargs) @@ -4478,6 +5078,7 @@ def set(obj, **kwargs): for attr in avdict: if attr not in obj._vals_ and attr.reverse and not 
attr.reverse.is_collection: attr.load(obj) # loading of one-to-one relations + if wbits is not None: new_wbits = wbits for attr in avdict: new_wbits |= obj._bits_[attr] @@ -4489,12 +5090,16 @@ def set(obj, **kwargs): obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True + if not collection_avdict: - for attr in avdict: - if attr.reverse or attr.is_part_of_unique_index: break - else: + if not any(attr.reverse or attr.is_part_of_unique_index for attr in avdict): obj._vals_.update(avdict) return + + for attr, value in items_list(avdict): + if value == get_val(attr): + avdict.pop(attr) + undo_funcs = [] undo = [] def undo_func(): @@ -4513,17 +5118,15 @@ def undo_func(): if attr not in avdict: continue new_val = avdict[attr] old_val = get_val(attr) - if old_val != new_val: cache.update_simple_index(obj, attr, old_val, new_val, undo) + cache.update_simple_index(obj, attr, old_val, new_val, undo) for attrs in obj._composite_keys_: - for attr in attrs: - if attr in avdict: break - else: continue - vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! - prev_vals = tuple(vals) - for i, attr in enumerate(attrs): - if attr in avdict: vals[i] = avdict[attr] - new_vals = tuple(vals) - cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo) + if any(attr in avdict for attr in attrs): + vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! + prev_vals = tuple(vals) + for i, attr in enumerate(attrs): + if attr in avdict: vals[i] = avdict[attr] + new_vals = tuple(vals) + cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo) for attr, new_val in iteritems(avdict): if not attr.reverse: continue old_val = get_val(attr) @@ -4555,14 +5158,19 @@ def _construct_optimistic_criteria_(obj): optimistic_columns = [] optimistic_converters = [] optimistic_values = [] + optimistic_operations = [] for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): + converters = attr.converters + assert converters + optimistic = attr.optimistic if attr.optimistic is not None else converters[0].optimistic + if not optimistic: continue dbval = obj._dbvals_[attr] optimistic_columns.extend(attr.columns) - if dbval is not None: converters = attr.converters - else: converters = repeat(None, len(attr.converters)) - optimistic_converters.extend(converters) - optimistic_values.extend(attr.get_raw_values(dbval)) - return optimistic_columns, optimistic_converters, optimistic_values + optimistic_converters.extend(attr.converters) + values = attr.get_raw_values(dbval) + optimistic_values.extend(values) + optimistic_operations.extend('IS_NULL' if dbval is None else converter.EQ for converter in converters) + return optimistic_operations, optimistic_columns, optimistic_converters, optimistic_values def _save_principal_objects_(obj, dependent_objects): if dependent_objects is None: dependent_objects = [] elif obj in dependent_objects: @@ -4578,7 +5186,7 @@ def _save_principal_objects_(obj, dependent_objects): val = obj._vals_[attr] if val is not None and val._status_ == 'created': val._save_(dependent_objects) - def _update_dbvals_(obj, after_create): + def _update_dbvals_(obj, after_create, new_dbvals): bits = obj._bits_ vals = obj._vals_ dbvals = obj._dbvals_ @@ -4594,21 +5202,34 @@ def _update_dbvals_(obj, after_create): for key, i in attr.composite_keys: keyval = tuple(get_val(attr) for attr in key) cache_indexes[key].pop(keyval, None) - del vals[attr] elif after_create and val is None: 
obj._rbits_ &= ~bits[attr] - del vals[attr] - else: dbvals[attr] = val + else: + if attr in new_dbvals: + dbvals[attr] = new_dbvals[attr] + continue + # Clear value of volatile attribute or null values after create, because the value may be changed in the DB + del vals[attr] + dbvals.pop(attr, None) + def _save_created_(obj): auto_pk = (obj._pkval_ is None) attrs = [] values = [] + new_dbvals = {} for attr in obj._attrs_with_columns_: if auto_pk and attr.is_pk: continue val = obj._vals_[attr] if val is not None: attrs.append(attr) - values.extend(attr.get_raw_values(val)) + if not attr.reverse: + assert len(attr.converters) == 1 + dbval = attr.converters[0].val2dbval(val, obj) + new_dbvals[attr] = dbval + values.append(dbval) + else: + new_dbvals[attr] = val + values.extend(attr.get_raw_values(val)) attrs = tuple(attrs) database = obj._database_ @@ -4658,26 +5279,34 @@ def _save_created_(obj): obj._status_ = 'inserted' obj._rbits_ = obj._all_bits_except_volatile_ obj._wbits_ = 0 - obj._update_dbvals_(True) + obj._update_dbvals_(True, new_dbvals) def _save_updated_(obj): update_columns = [] values = [] + new_dbvals = {} for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._wbits_): update_columns.extend(attr.columns) val = obj._vals_[attr] - values.extend(attr.get_raw_values(val)) + if not attr.reverse: + assert len(attr.converters) == 1 + dbval = attr.converters[0].val2dbval(val, obj) + new_dbvals[attr] = dbval + values.append(dbval) + else: + new_dbvals[attr] = val + values.extend(attr.get_raw_values(val)) if update_columns: for attr in obj._pk_attrs_: val = obj._vals_[attr] values.extend(attr.get_raw_values(val)) cache = obj._session_cache_ - if obj not in cache.for_update: - optimistic_columns, optimistic_converters, optimistic_values = \ + optimistic_session = cache.db_session is None or cache.db_session.optimistic + if optimistic_session and obj not in cache.for_update: + optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) - else: optimistic_columns = optimistic_converters = () - query_key = (tuple(update_columns), tuple(optimistic_columns), - tuple(converter is not None for converter in optimistic_converters)) + else: optimistic_columns = optimistic_converters = optimistic_ops = () + query_key = tuple(update_columns), tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._update_sql_cache_.get(query_key) if cached_sql is None: @@ -4690,52 +5319,101 @@ def _save_updated_(obj): where_list = [ 'WHERE' ] pk_columns = obj._pk_columns_ pk_converters = obj._pk_converters_ - params_count = populate_criteria_list(where_list, pk_columns, pk_converters, params_count) - if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_converters, params_count) + params_count = populate_criteria_list(where_list, pk_columns, pk_converters, repeat('EQ'), params_count) + if optimistic_columns: populate_criteria_list( + where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count, optimistic=True) sql_ast = [ 'UPDATE', obj._table_, list(izip(update_columns, update_params)), where_list ] sql, adapter = database._ast2sql(sql_ast) obj._update_sql_cache_[query_key] = sql, adapter else: sql, adapter = cached_sql arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) - if cursor.rowcount != 1: - throw(OptimisticCheckError, 'Object %s was updated outside of current 
transaction' % safe_repr(obj)) + if cursor.rowcount == 0 and cache.db_session.optimistic: + throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'updated' obj._rbits_ |= obj._wbits_ & obj._all_bits_except_volatile_ obj._wbits_ = 0 - obj._update_dbvals_(False) + obj._update_dbvals_(False, new_dbvals) def _save_deleted_(obj): values = [] values.extend(obj._get_raw_pkval_()) cache = obj._session_cache_ - if obj not in cache.for_update: - optimistic_columns, optimistic_converters, optimistic_values = \ + optimistic_session = cache.db_session is None or cache.db_session.optimistic + if optimistic_session and obj not in cache.for_update: + optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) - else: optimistic_columns = optimistic_converters = () - query_key = (tuple(optimistic_columns), tuple(converter is not None for converter in optimistic_converters)) + else: optimistic_columns = optimistic_converters = optimistic_ops = () + query_key = tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._delete_sql_cache_.get(query_key) if cached_sql is None: where_list = [ 'WHERE' ] - params_count = populate_criteria_list(where_list, obj._pk_columns_, obj._pk_converters_) - if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_converters, params_count) + params_count = populate_criteria_list(where_list, obj._pk_columns_, obj._pk_converters_, repeat('EQ')) + if optimistic_columns: populate_criteria_list( + where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count, optimistic=True) from_ast = [ 'FROM', [ None, 'TABLE', obj._table_ ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) obj.__class__._delete_sql_cache_[query_key] = sql, adapter else: sql, adapter = cached_sql arguments = adapter(values) - database._exec_sql(sql, arguments, start_transaction=True) + cursor = database._exec_sql(sql, arguments, start_transaction=True) + if cursor.rowcount == 0 and cache.db_session.optimistic: + throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) + + def find_updated_attributes(obj): + entity = obj.__class__ + attrs_to_select = [] + attrs_to_select.extend(entity._pk_attrs_) + discr = entity._discriminator_attr_ + if discr is not None and discr.pk_offset is None: + attrs_to_select.append(discr) + for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): + optimistic = attr.optimistic if attr.optimistic is not None else attr.converters[0].optimistic + if optimistic: + attrs_to_select.append(attr) + + optimistic_converters = [] + attr_offsets = {} + select_list = [ 'ALL' ] + for attr in attrs_to_select: + optimistic_converters.extend(attr.converters) + attr_offsets[attr] = offsets = [] + for columns in attr.columns: + select_list.append([ 'COLUMN', None, columns]) + offsets.append(len(select_list) - 2) + + from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] + pk_columns = entity._pk_columns_ + pk_converters = entity._pk_converters_ + criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] + for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ] + sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ] + database = entity._database_ + sql, adapter = 
database._ast2sql(sql_ast) + arguments = adapter(obj._get_raw_pkval_()) + cursor = database._exec_sql(sql, arguments) + row = cursor.fetchone() + if row is None: + return "Object %s was deleted outside of current transaction" % safe_repr(obj) + + real_entity_subclass, pkval, avdict = entity._parse_row_(row, attr_offsets) + diff = [] + for attr, new_dbval in avdict.items(): + old_dbval = obj._dbvals_[attr] + converter = attr.converters[0] + if old_dbval != new_dbval and ( + attr.reverse or not converter.dbvals_equal(old_dbval, new_dbval)): + diff.append('%s (%r -> %r)' % (attr.name, old_dbval, new_dbval)) + + return "Object %s was updated outside of current transaction%s" % ( + safe_repr(obj), ('. Changes: %s' % ', '.join(diff) if diff else '')) + def _save_(obj, dependent_objects=None): - cache = obj._session_cache_ - assert cache.is_alive status = obj._status_ - if status in ('created', 'modified'): obj._save_principal_objects_(dependent_objects) @@ -4746,6 +5424,7 @@ def _save_(obj, dependent_objects=None): assert obj._status_ in saved_statuses cache = obj._session_cache_ + assert cache is not None and cache.is_alive cache.saved_objects.append((obj, obj._status_)) objects_to_save = cache.objects_to_save save_pos = obj._save_pos_ @@ -4760,7 +5439,7 @@ def flush(obj): assert obj._save_pos_ is not None, 'save_pos is None for %s object' % obj._status_ cache = obj._session_cache_ - assert not cache.saved_objects + assert cache is not None and cache.is_alive and not cache.saved_objects with cache.flush_disabled(): obj._before_save_() # should be inside flush_disabled to prevent infinite recursion # TODO: add to documentation that flush is disabled inside before_xxx hooks @@ -4789,7 +5468,8 @@ def after_delete(obj): pass @cut_traceback def to_dict(obj, only=None, exclude=None, with_collections=False, with_lazy=False, related_objects=False): - if obj._session_cache_.modified: obj._session_cache_.flush() + cache = obj._session_cache_ + if cache is not None and cache.is_alive and cache.modified: cache.flush() attrs = obj.__class__._get_attrs_(only, exclude, with_collections, with_lazy) result = {} for attr in attrs: @@ -4852,18 +5532,19 @@ def get_globals_and_locals(args, kwargs, frame_depth, from_generator=False): % (len(args) > 4 and 's' or '', ', '.join(imap(repr, args[3:])))) else: locals = {} - locals.update(sys._getframe(frame_depth+1).f_locals) + if frame_depth is not None: + locals.update(sys._getframe(frame_depth+1).f_locals) if type(func) is types.GeneratorType: globals = func.gi_frame.f_globals locals.update(func.gi_frame.f_locals) - else: + elif frame_depth is not None: globals = sys._getframe(frame_depth+1).f_globals if kwargs: throw(TypeError, 'Keyword arguments cannot be specified together with positional arguments') return func, globals, locals def make_query(args, frame_depth, left_join=False): gen, globals, locals = get_globals_and_locals( - args, kwargs=None, frame_depth=frame_depth+1, from_generator=True) + args, kwargs=None, frame_depth=frame_depth+1 if frame_depth is not None else None, from_generator=True) if isinstance(gen, types.GeneratorType): tree, external_names, cells = decompile(gen) code_key = id(gen.gi_frame.f_code) @@ -4878,35 +5559,35 @@ def make_query(args, frame_depth, left_join=False): @cut_traceback def select(*args): - return make_query(args, frame_depth=3) + return make_query(args, frame_depth=cut_traceback_depth+1) @cut_traceback def left_join(*args): - return make_query(args, frame_depth=3, left_join=True) + return make_query(args, 
frame_depth=cut_traceback_depth+1, left_join=True) @cut_traceback def get(*args): - return make_query(args, frame_depth=3).get() + return make_query(args, frame_depth=cut_traceback_depth+1).get() @cut_traceback def exists(*args): - return make_query(args, frame_depth=3).exists() + return make_query(args, frame_depth=cut_traceback_depth+1).exists() @cut_traceback def delete(*args): - return make_query(args, frame_depth=3).delete() + return make_query(args, frame_depth=cut_traceback_depth+1).delete() def make_aggrfunc(std_func): def aggrfunc(*args, **kwargs): - if kwargs: return std_func(*args, **kwargs) - if len(args) != 1: return std_func(*args) + if not args: + return std_func(**kwargs) arg = args[0] if type(arg) is types.GeneratorType: try: iterator = arg.gi_frame.f_locals['.0'] - except: return std_func(*args) + except: return std_func(*args, **kwargs) if isinstance(iterator, EntityIter): - return getattr(select(arg), std_func.__name__)() - return std_func(*args) + return getattr(select(arg), std_func.__name__)(*args[1:], **kwargs) + return std_func(*args, **kwargs) aggrfunc.__name__ = std_func.__name__ return aggrfunc @@ -4915,6 +5596,7 @@ def aggrfunc(*args, **kwargs): min = make_aggrfunc(builtins.min) max = make_aggrfunc(builtins.max) avg = make_aggrfunc(utils.avg) +group_concat = make_aggrfunc(utils.group_concat) distinct = make_aggrfunc(utils.distinct) @@ -4924,34 +5606,48 @@ def JOIN(expr): def desc(expr): if isinstance(expr, Attribute): return expr.desc - if isinstance(expr, int_types) and expr > 0: + if isinstance(expr, DescWrapper): + return expr.attr + if isinstance(expr, int_types): return -expr if isinstance(expr, basestring): return 'desc(%s)' % expr return expr -def raw_sql(sql, result_type=None): - globals = sys._getframe(1).f_globals - locals = sys._getframe(1).f_locals - return RawSQL(sql, globals, locals, result_type) - -def extract_vars(extractors, globals, locals, cells=None): +def extract_vars(code_key, filter_num, extractors, globals, locals, cells=None): if cells: locals = locals.copy() for name, cell in cells.items(): - locals[name] = cell.cell_contents + try: + locals[name] = cell.cell_contents + except ValueError: + throw(NameError, 'Free variable `%s` referenced before assignment in enclosing scope' % name) vars = {} - vartypes = {} - for key, code in iteritems(extractors): - filter_num, src = key - if src == '.0': value = locals['.0'] - else: - try: value = eval(code, globals, locals) - except Exception as cause: raise ExprEvalError(src, cause) - if src == 'None' and value is not None: throw(TranslationError) - if src == 'True' and value is not True: throw(TranslationError) - if src == 'False' and value is not False: throw(TranslationError) - try: vartypes[key] = get_normalized_type_of(value) + vartypes = HashableDict() + for src, extractor in iteritems(extractors): + varkey = filter_num, src, code_key + try: value = extractor(globals, locals) + except Exception as cause: raise ExprEvalError(src, cause) + + if isinstance(value, types.GeneratorType): + value = make_query((value,), frame_depth=None) + + if isinstance(value, QueryResultIterator): + qr = value._query_result + value = qr if not qr._items else tuple(qr._items[value._position:]) + + if isinstance(value, QueryResult) and value._items: + value = tuple(value._items) + + if isinstance(value, (Query, QueryResult, SetIterator)): + query = value._get_query() + vars.update(query._vars) + vartypes.update(query._translator.vartypes) + + if src == 'None' and value is not None: throw(TranslationError) + if src 
== 'True' and value is not True: throw(TranslationError) + if src == 'False' and value is not False: throw(TranslationError) + try: vartypes[varkey], value = normalize(value) except TypeError: if not isinstance(value, dict): unsupported = False @@ -4960,10 +5656,11 @@ def extract_vars(extractors, globals, locals, cells=None): else: unsupported = True if unsupported: typename = type(value).__name__ - if src == '.0': throw(TypeError, 'Cannot iterate over non-entity object') + if src == '.0': + throw(TypeError, 'Query cannot iterate over anything but entity class or another query') throw(TypeError, 'Expression `%s` has unsupported type %r' % (src, typename)) - vartypes[key] = get_normalized_type_of(value) - vars[key] = value + vartypes[varkey], value = normalize(value) + vars[varkey] = value return vars, vartypes def unpickle_query(query_result): @@ -4972,46 +5669,81 @@ def unpickle_query(query_result): class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) - extractors, varnames, tree = create_extractors( - code_key, tree, 0, globals, locals, special_functions, const_functions) - vars, vartypes = extract_vars(extractors, globals, locals, cells) + tree, extractors = create_extractors(code_key, tree, globals, locals, special_functions, const_functions) + filter_num = 0 + vars, vartypes = extract_vars(code_key, filter_num, extractors, globals, locals, cells) node = tree.quals[0].iter - origin = vars[0, node.src] - if isinstance(origin, EntityIter): origin = origin.entity - elif not isinstance(origin, EntityMeta): - if node.src == '.0': throw(TypeError, 'Cannot iterate over non-entity object') - throw(TypeError, 'Cannot iterate over non-entity object %s' % node.src) - query._origin = origin - database = origin._database_ - if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) - if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) + varkey = filter_num, node.src, code_key + origin = vars[varkey] + if isinstance(origin, Query): + prev_query = origin + elif isinstance(origin, QueryResult): + prev_query = origin._query + elif isinstance(origin, QueryResultIterator): + prev_query = origin._query_result._query + elif isinstance(origin, SetIterator): + prev_query = origin._query + else: + prev_query = None + if not isinstance(origin, EntityMeta): + if node.src == '.0': throw(TypeError, + 'Query can only iterate over entity or another query (not a list of objects)') + throw(TypeError, 'Cannot iterate over non-entity object %s' % node.src) + database = origin._database_ + if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) + if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) + + if prev_query is not None: + database = prev_query._translator.database + filter_num = prev_query._filter_num + 1 + vars, vartypes = extract_vars(code_key, filter_num, extractors, globals, locals, cells) + + query._filter_num = filter_num database.provider.normalize_vars(vars, vartypes) - query._vars = vars - query._key = code_key, tuple(vartypes[name] for name in varnames), left_join + + query._code_key = code_key + query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) query._database = database - translator = database._translator_cache.get(query._key) + translator, vars = 
query._get_translator(query._key, vars) + query._vars = vars + if translator is None: - pickled_tree = pickle.dumps(tree, 2) - tree = pickle.loads(pickled_tree) # tree = deepcopy(tree) + pickled_tree = pickle_ast(tree) + tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls - translator = translator_cls(tree, extractors, vartypes, left_join=left_join) + try: + translator = translator_cls(tree_copy, None, code_key, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) + except UseAnotherTranslator as e: + translator = e.translator name_path = translator.can_be_optimized() if name_path: - tree = pickle.loads(pickled_tree) # tree = deepcopy(tree) - try: translator = translator_cls(tree, extractors, vartypes, left_join=True, optimize=name_path) - except OptimizationFailed: translator.optimization_failed = True + tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) + try: + translator = translator_cls(tree_copy, None, code_key, filter_num, extractors, vars, vartypes.copy(), + left_join=True, optimize=name_path) + except UseAnotherTranslator as e: + translator = e.translator + except OptimizationFailed: + translator.optimization_failed = True translator.pickled_tree = pickled_tree - database._translator_cache[query._key] = translator + if translator.can_be_cached: + database._translator_cache[query._key] = translator + query._translator = translator query._filters = () query._next_kwarg_id = 0 - query._for_update = query._nowait = False + query._for_update = query._nowait = query._skip_locked = False query._distinct = None query._prefetch = False - query._entities_to_prefetch = set() - query._attrs_to_prefetch_dict = defaultdict(set) + query._prefetch_context = PrefetchContext(query._database) + def _get_query(query): + return query + def _get_type_(query): + return QueryType(query) + def _normalize_var(query, query_type): + return query_type, query def _clone(query, **kwargs): new_query = object.__new__(Query) new_query.__dict__.update(query.__dict__) @@ -5019,20 +5751,57 @@ def _clone(query, **kwargs): return new_query def __reduce__(query): return unpickle_query, (query._fetch(),) - def _construct_sql_and_arguments(query, range=None, aggr_func_name=None): + def _get_translator(query, query_key, vars): + new_vars = vars.copy() + database = query._database + translator = database._translator_cache.get(query_key) + all_func_vartypes = {} + if translator is not None: + if translator.func_extractors_map: + for func, func_extractors in iteritems(translator.func_extractors_map): + func_id = id(func.func_code if PY2 else func.__code__) + func_filter_num = translator.filter_num, 'func', func_id + func_vars, func_vartypes = extract_vars( + func_id, func_filter_num, func_extractors, func.__globals__, {}, func.__closure__) # todo closures + database.provider.normalize_vars(func_vars, func_vartypes) + new_vars.update(func_vars) + all_func_vartypes.update(func_vartypes) + if all_func_vartypes != translator.func_vartypes: + return None, vars.copy() + for key, val in iteritems(translator.fixed_param_values): + assert key in new_vars + if val != new_vars[key]: + del database._translator_cache[query_key] + return None, vars.copy() + return translator, new_vars + def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type - if isinstance(expr_type, EntityMeta) and query._attrs_to_prefetch_dict: 
- attrs_to_prefetch = tuple(sorted(query._attrs_to_prefetch_dict.get(expr_type, ()))) + attrs_to_prefetch_dict = query._prefetch_context.attrs_to_prefetch_dict + if isinstance(expr_type, EntityMeta) and attrs_to_prefetch_dict: + attrs_to_prefetch = tuple(sorted(attrs_to_prefetch_dict.get(expr_type, ()))) else: attrs_to_prefetch = () - sql_key = query._key + (range, query._distinct, aggr_func_name, query._for_update, query._nowait, - options.INNER_JOIN_SYNTAX, attrs_to_prefetch) + sql_key = HashableDict( + query._key, + vartypes=HashableDict(query._translator.vartypes), + fixed_param_values=HashableDict(translator.fixed_param_values), + limit=limit, + offset=offset, + distinct=query._distinct, + aggr_func=(aggr_func_name, aggr_func_distinct, sep), + for_update=query._for_update, + nowait=query._nowait, + skip_locked=query._skip_locked, + inner_join_syntax=options.INNER_JOIN_SYNTAX, + attrs_to_prefetch=attrs_to_prefetch + ) database = query._database cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( - range, query._distinct, aggr_func_name, query._for_update, query._nowait, attrs_to_prefetch) + limit, offset, query._distinct, aggr_func_name, aggr_func_distinct, sep, + query._for_update, query._nowait, query._skip_locked) cache = database._get_cache() sql, adapter = database.provider.ast2sql(sql_ast) cache_entry = sql, adapter, attr_offsets @@ -5040,132 +5809,124 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None): else: sql, adapter, attr_offsets = cache_entry arguments = adapter(query._vars) if query._translator.query_result_is_cacheable: - arguments_type = type(arguments) - if arguments_type is tuple: arguments_key = arguments - elif arguments_type is dict: arguments_key = tuple(sorted(iteritems(arguments))) + arguments_key = HashableDict(arguments) if type(arguments) is dict else arguments try: hash(arguments_key) except: query_key = None # arguments are unhashable - else: query_key = sql_key + (arguments_key) + else: query_key = HashableDict(sql_key, arguments_key=arguments_key) else: query_key = None return sql, arguments, attr_offsets, query_key def get_sql(query): sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments() return sql - def _fetch(query, range=None): + def _actual_fetch(query, limit=None, offset=None): translator = query._translator - sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(range) - database = query._database - cache = database._get_cache() - if query._for_update: cache.immediate = True - cache.prepare_connection_for_query_execution() # may clear cache.query_results - try: result = cache.query_results[query_key] - except KeyError: - cursor = database._exec_sql(sql, arguments) - if isinstance(translator.expr_type, EntityMeta): - entity = translator.expr_type - result = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, - used_attrs=translator.get_used_attrs()) - elif len(translator.row_layout) == 1: - func, slice_or_offset, src = translator.row_layout[0] - result = list(starmap(func, cursor.fetchall())) + with query._prefetch_context: + sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(limit, offset) + database = query._database + cache = database._get_cache() + if query._for_update: cache.immediate = True + cache.prepare_connection_for_query_execution() # may clear cache.query_results + items = cache.query_results.get(query_key) + if items is None: + 
cursor = database._exec_sql(sql, arguments) + if isinstance(translator.expr_type, EntityMeta): + entity = translator.expr_type + items = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, + used_attrs=translator.get_used_attrs()) + elif len(translator.row_layout) == 1: + func, slice_or_offset, src = translator.row_layout[0] + items = list(starmap(func, cursor.fetchall())) + else: + items = [ tuple(func(sql_row[slice_or_offset]) + for func, slice_or_offset, src in translator.row_layout) + for sql_row in cursor.fetchall() ] + for i, t in enumerate(translator.expr_type): + if isinstance(t, EntityMeta) and t._subclasses_: t._load_many_(row[i] for row in items) + if query_key is not None: cache.query_results[query_key] = items else: - result = [ tuple(func(sql_row[slice_or_offset]) - for func, slice_or_offset, src in translator.row_layout) - for sql_row in cursor.fetchall() ] - for i, t in enumerate(translator.expr_type): - if isinstance(t, EntityMeta) and t._subclasses_: t._load_many_(row[i] for row in result) - if query_key is not None: cache.query_results[query_key] = result - else: - stats = database._dblocal.stats - stat = stats.get(sql) - if stat is not None: stat.cache_count += 1 - else: stats[sql] = QueryStat(sql) - - if query._prefetch: query._do_prefetch(result) - return QueryResult(result, query, translator.expr_type, translator.col_names) + stats = database._dblocal.stats + stat = stats.get(sql) + if stat is not None: stat.cache_count += 1 + else: stats[sql] = QueryStat(sql) + if query._prefetch: query._do_prefetch(items) + return items @cut_traceback def prefetch(query, *args): - query = query._clone(_entities_to_prefetch=query._entities_to_prefetch.copy(), - _attrs_to_prefetch_dict=query._attrs_to_prefetch_dict.copy()) + query = query._clone(_prefetch_context=query._prefetch_context.copy()) query._prefetch = True + prefetch_context = query._prefetch_context for arg in args: if isinstance(arg, EntityMeta): entity = arg if query._database is not entity._database_: throw(TypeError, 'Entity %s belongs to different database and cannot be prefetched' % entity.__name__) - query._entities_to_prefetch.add(entity) + prefetch_context.entities_to_prefetch.add(entity) elif isinstance(arg, Attribute): attr = arg entity = attr.entity if query._database is not entity._database_: throw(TypeError, 'Entity of attribute %s belongs to different database and cannot be prefetched' % attr) if isinstance(attr.py_type, EntityMeta) or attr.lazy: - query._attrs_to_prefetch_dict[entity].add(attr) + prefetch_context.attrs_to_prefetch_dict[entity].add(attr) else: throw(TypeError, 'Argument of prefetch() query method must be entity class or attribute. 
' 'Got: %r' % arg) return query - def _do_prefetch(query, result): + def _do_prefetch(query, query_result): expr_type = query._translator.expr_type - object_list = [] - object_set = set() - append_to_object_list = object_list.append - add_to_object_set = object_set.add + all_objects = set() + objects_to_process = set() + objects_to_prefetch = set() if isinstance(expr_type, EntityMeta): - for obj in result: - if obj not in object_set: - add_to_object_set(obj) - append_to_object_list(obj) + objects_to_process.update(query_result) + all_objects.update(query_result) elif type(expr_type) is tuple: - for i, t in enumerate(expr_type): - if not isinstance(t, EntityMeta): continue - for row in result: - obj = row[i] - if obj not in object_set: - add_to_object_set(obj) - append_to_object_list(obj) + obj_indexes = [ i for i, t in enumerate(expr_type) if isinstance(t, EntityMeta) ] + if obj_indexes: + for row in query_result: + objects_to_prefetch.update(row[i] for i in obj_indexes) + all_objects.update(objects_to_prefetch) + + prefetch_context = local.prefetch_context + assert prefetch_context + collection_prefetch_dict = defaultdict(set) + + objects_to_prefetch_dict = defaultdict(set) + while objects_to_process or objects_to_prefetch: + for obj in objects_to_process: + entity = obj.__class__ + relations_to_prefetch = prefetch_context.get_relations_to_prefetch(entity) + for attr in relations_to_prefetch: + if attr.is_collection: + collection_prefetch_dict[attr].add(obj) + else: + obj2 = attr.get(obj) + if obj2 is not None and obj2 not in all_objects: + all_objects.add(obj2) + objects_to_prefetch.add(obj2) + + next_objects_to_process = set() + for attr, objects in collection_prefetch_dict.items(): + items = attr.prefetch_load_all(objects) + if attr.reverse.is_collection: + objects_to_prefetch.update(items) + else: + next_objects_to_process.update(item for item in items if item not in all_objects) + collection_prefetch_dict.clear() - cache = query._database._get_cache() - entities_to_prefetch = query._entities_to_prefetch - attrs_to_prefetch_dict = query._attrs_to_prefetch_dict - prefetching_attrs_cache = {} - for obj in object_list: - entity = obj.__class__ - if obj in cache.seeds[entity._pk_attrs_]: obj._load_() + for obj in objects_to_prefetch: + objects_to_prefetch_dict[obj.__class__._root_].add(obj) + objects_to_prefetch.clear() - all_attrs_to_prefetch = prefetching_attrs_cache.get(entity) - if all_attrs_to_prefetch is None: - all_attrs_to_prefetch = [] - append = all_attrs_to_prefetch.append - attrs_to_prefetch = attrs_to_prefetch_dict[entity] - for attr in obj._attrs_: - if attr.is_collection: - if attr in attrs_to_prefetch: append(attr) - elif attr.is_relation: - if attr in attrs_to_prefetch or attr.py_type in entities_to_prefetch: append(attr) - elif attr.lazy: - if attr in attrs_to_prefetch: append(attr) - prefetching_attrs_cache[entity] = all_attrs_to_prefetch + for entity, objects in objects_to_prefetch_dict.items(): + next_objects_to_process.update(objects) + entity._prefetch_load_all_(objects) + objects_to_prefetch_dict.clear() - for attr in all_attrs_to_prefetch: - if attr.is_collection: - if not isinstance(attr, Set): throw(NotImplementedError) - setdata = obj._vals_.get(attr) - if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) - for obj2 in setdata: - if obj2 not in object_set: - add_to_object_set(obj2) - append_to_object_list(obj2) - elif attr.is_relation: - obj2 = attr.get(obj) - if obj2 is not None and obj2 not in object_set: - add_to_object_set(obj2) - 
append_to_object_list(obj2) - elif attr.lazy: attr.get(obj) - else: assert False # pragma: no cover + objects_to_process = next_objects_to_process @cut_traceback - def show(query, width=None): - query._fetch().show(width) + def show(query, width=None, stream=None): + query._fetch().show(width, stream) @cut_traceback def get(query): objects = query[:2] @@ -5200,11 +5961,11 @@ def delete(query, bulk=None): if not isinstance(query._translator.expr_type, EntityMeta): throw(TypeError, 'Delete query should be applied to a single entity. Got: %s' % ast2src(query._translator.tree.expr)) - objects = query._fetch() + objects = query._actual_fetch() for obj in objects: obj._delete_() return len(objects) translator = query._translator - sql_key = query._key + ('DELETE',) + sql_key = HashableDict(query._key, sql_command='DELETE') database = query._database cache = database._get_cache() cache_entry = database._constructed_sql_cache.get(sql_key) @@ -5217,29 +5978,36 @@ def delete(query, bulk=None): cache.immediate = True cache.prepare_connection_for_query_execution() # may clear cache.query_results cursor = database._exec_sql(sql, arguments) + cache.query_results.clear() return cursor.rowcount @cut_traceback def __len__(query): - return len(query._fetch()) + return len(query._actual_fetch()) @cut_traceback def __iter__(query): - return iter(query._fetch()) + return iter(query._fetch(lazy=True)) @cut_traceback def order_by(query, *args): - if not args: throw(TypeError, 'order_by() method requires at least one argument') + return query._order_by('order_by', *args) + @cut_traceback + def sort_by(query, *args): + return query._order_by('sort_by', *args) + def _order_by(query, method_name, *args): + if not args: throw(TypeError, '%s() method requires at least one argument' % method_name) if args[0] is None: - if len(args) > 1: throw(TypeError, 'When first argument of order_by() method is None, it must be the only argument') - tup = ((),) - new_key = query._key + tup + if len(args) > 1: throw(TypeError, 'When first argument of %s() method is None, it must be the only argument' % method_name) + tup = (('without_order',),) + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup - new_translator = query._database._translator_cache.get(new_key) + + new_translator, new_vars = query._get_translator(new_key, query._vars) if new_translator is None: new_translator = query._translator.without_order() query._database._translator_cache[new_key] = new_translator return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) if isinstance(args[0], (basestring, types.FunctionType)): - func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+2) return query._process_lambda(func, globals, locals, order_by=True) if isinstance(args[0], RawSQL): @@ -5253,15 +6021,18 @@ def order_by(query, *args): else: throw(TypeError, "order_by() method receive an argument of invalid type: %r" % arg) if numbers and attributes: throw(TypeError, 'order_by() method receive invalid combination of arguments') - new_key = query._key + ('order_by', args,) - new_filters = query._filters + ((numbers, args),) - new_translator = query._database._translator_cache.get(new_key) + + tup = (('order_by_numbers' if numbers else 'order_by_attributes', args),) + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) + new_filters = query._filters + tup + 
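+        # Reuse a cached translator for this ordering key when one exists;
+        # otherwise build it below and put it into the translator cache.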
+ new_translator, new_vars = query._get_translator(new_key, query._vars) if new_translator is None: if numbers: new_translator = query._translator.order_by_numbers(args) else: new_translator = query._translator.order_by_attributes(args) query._database._translator_cache[new_key] = new_translator return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) - def _process_lambda(query, func, globals, locals, order_by): + def _process_lambda(query, func, globals, locals, order_by=False, original_names=False): prev_translator = query._translator argnames = () if isinstance(func, basestring): @@ -5273,7 +6044,6 @@ def _process_lambda(query, func, globals, locals, order_by): cells = None elif type(func) is types.FunctionType: argnames = get_lambda_args(func) - subquery = prev_translator.subquery func_id = id(func.func_code if PY2 else func.__code__) func_ast, external_names, cells = decompile(func) elif not order_by: throw(TypeError, @@ -5281,57 +6051,58 @@ def _process_lambda(query, func, globals, locals, order_by): else: assert False # pragma: no cover if argnames: - expr_type = prev_translator.expr_type - expr_count = len(expr_type) if type(expr_type) is tuple else 1 - if len(argnames) != expr_count: - throw(TypeError, 'Incorrect number of lambda arguments. ' - 'Expected: %d, got: %d' % (expr_count, len(argnames))) - - filter_num = len(query._filters) + 1 - extractors, varnames, func_ast = create_extractors( - func_id, func_ast, filter_num, globals, locals, special_functions, const_functions, - argnames or prev_translator.subquery) + if original_names: + for name in argnames: + if name not in prev_translator.namespace: throw(TypeError, + 'Lambda argument `%s` does not correspond to any variable in original query' % name) + else: + expr_type = prev_translator.expr_type + expr_count = len(expr_type) if type(expr_type) is tuple else 1 + if len(argnames) != expr_count: + throw(TypeError, 'Incorrect number of lambda arguments. 
' + 'Expected: %d, got: %d' % (expr_count, len(argnames))) + else: + original_names = True + + new_filter_num = query._filter_num + 1 + func_ast, extractors = create_extractors( + func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.namespace) if extractors: - vars, vartypes = extract_vars(extractors, globals, locals, cells) + vars, vartypes = extract_vars(func_id, new_filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) - new_query_vars = query._vars.copy() - new_query_vars.update(vars) - sorted_vartypes = tuple(vartypes[name] for name in varnames) - else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () - - new_key = query._key + (('order_by' if order_by else 'filter', func_id, sorted_vartypes),) - new_filters = query._filters + ((order_by, func_ast, argnames, extractors, vartypes),) - new_translator = query._database._translator_cache.get(new_key) + new_vars = query._vars.copy() + new_vars.update(vars) + else: new_vars, vartypes = query._vars, HashableDict() + tup = (('order_by' if order_by else 'where' if original_names else 'filter', func_id, vartypes),) + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) + new_filters = query._filters + (('apply_lambda', func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) + + new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: prev_optimized = prev_translator.optimize - new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, extractors, vartypes) + new_translator = prev_translator.apply_lambda(func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: - tree = pickle.loads(prev_translator.pickled_tree) # tree = deepcopy(tree) - prev_extractors = prev_translator.extractors - prev_vartypes = prev_translator.vartypes + tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) translator_cls = prev_translator.__class__ - new_translator = translator_cls(tree, prev_extractors, prev_vartypes, - left_join=True, optimize=name_path) + try: + new_translator = translator_cls( + tree_copy, None, prev_translator.original_code_key, prev_translator.original_filter_num, + prev_translator.extractors, None, prev_translator.vartypes.copy(), + left_join=True, optimize=name_path) + except UseAnotherTranslator: + assert False new_translator = query._reapply_filters(new_translator) - new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, extractors, vartypes) + new_translator = new_translator.apply_lambda(func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) query._database._translator_cache[new_key] = new_translator - return query._clone(_vars=new_query_vars, _key=new_key, _filters=new_filters, _translator=new_translator) + return query._clone(_filter_num=new_filter_num, _vars=new_vars, _key=new_key, _filters=new_filters, + _translator=new_translator) def _reapply_filters(query, translator): - for i, tup in enumerate(query._filters): - if not tup: - translator = translator.without_order() - elif len(tup) == 1: - attrnames = tup[0] - translator.apply_kwfilters(attrnames) - elif len(tup) == 2: - numbers, args = tup - if numbers: translator = translator.order_by_numbers(args) - else: translator = 
translator.order_by_attributes(args) - else: - order_by, func_ast, argnames, extractors, vartypes = tup - translator = translator.apply_lambda(i+1, order_by, func_ast, argnames, extractors, vartypes) + for tup in query._filters: + method_name, args = tup[0], tup[1:] + translator_method = getattr(translator, method_name) + translator = translator_method(*args) return translator @cut_traceback def filter(query, *args, **kwargs): @@ -5339,14 +6110,36 @@ def filter(query, *args, **kwargs): if isinstance(args[0], RawSQL): raw = args[0] return query.filter(lambda: raw) - func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1) return query._process_lambda(func, globals, locals, order_by=False) if not kwargs: return query entity = query._translator.expr_type if not isinstance(entity, EntityMeta): throw(TypeError, 'Keyword arguments are not allowed: since query result type is not an entity, filter() method can accept only lambda') + return query._apply_kwargs(kwargs) + @cut_traceback + def where(query, *args, **kwargs): + if args: + if isinstance(args[0], RawSQL): + raw = args[0] + return query.where(lambda: raw) + func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1) + return query._process_lambda(func, globals, locals, order_by=False, original_names=True) + if not kwargs: return query + if len(query._translator.tree.quals) > 1: throw(TypeError, + 'Keyword arguments are not allowed: query iterates over more than one entity') + return query._apply_kwargs(kwargs, original_names=True) + def _apply_kwargs(query, kwargs, original_names=False): + translator = query._translator + if original_names: + tablerefs = translator.sqlquery.tablerefs + alias = translator.tree.quals[0].assign.name + tableref = tablerefs[alias] + entity = tableref.entity + else: + entity = translator.expr_type get_attr = entity._adict_.get filterattrs = [] value_dict = {} @@ -5365,44 +6158,51 @@ def filter(query, *args, **kwargs): value_dict[id] = val filterattrs = tuple(filterattrs) - new_key = query._key + ('filter', filterattrs) - new_filters = query._filters + ((filterattrs,),) - new_translator = query._database._translator_cache.get(new_key) + tup = (('apply_kwfilters', filterattrs, original_names),) + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) + new_filters = query._filters + tup + new_vars = query._vars.copy() + new_vars.update(value_dict) + new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: - new_translator = query._translator.apply_kwfilters(filterattrs) + new_translator = translator.apply_kwfilters(filterattrs, original_names) query._database._translator_cache[new_key] = new_translator - new_query = query._clone(_key=new_key, _filters=new_filters, _translator=new_translator, - _next_kwarg_id=next_id, _vars=query._vars.copy()) - new_query._vars.update(value_dict) - return new_query + return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator, + _next_kwarg_id=next_id, _vars=new_vars) @cut_traceback def __getitem__(query, key): - if isinstance(key, slice): - step = key.step - if step is not None and step != 1: throw(TypeError, "Parameter 'step' of slice object is not allowed here") - start = key.start - if start is None: start = 0 - elif start < 0: throw(TypeError, "Parameter 'start' of slice object cannot be negative") - stop = key.stop - if stop is None: - if not start: 
return query._fetch() - else: throw(TypeError, "Parameter 'stop' of slice object should be specified") - else: throw(TypeError, 'If you want apply index to query, convert it to list first') - if start >= stop: return [] - return query._fetch(range=(start, stop)) + if not isinstance(key, slice): + throw(TypeError, 'If you want apply index to a query, convert it to list first') + step = key.step + if step is not None and step != 1: throw(TypeError, "Parameter 'step' of slice object is not allowed here") + start = key.start + if start is None: start = 0 + elif start < 0: throw(TypeError, "Parameter 'start' of slice object cannot be negative") + stop = key.stop + if stop is None: + if not start: + return query._fetch() + else: + return query._fetch(limit=None, offset=start) + if start >= stop: + return query._fetch(limit=0) + return query._fetch(limit=stop-start, offset=start) + def _fetch(query, limit=None, offset=None, lazy=False): + return QueryResult(query, limit, offset, lazy=lazy) @cut_traceback - def limit(query, limit, offset=None): - start = offset or 0 - stop = start + limit - return query[start:stop] + def fetch(query, limit=None, offset=None): + return query._fetch(limit, offset) + @cut_traceback + def limit(query, limit=None, offset=None): + return query._fetch(limit, offset, lazy=True) @cut_traceback def page(query, pagenum, pagesize=10): - start = (pagenum - 1) * pagesize - stop = pagenum * pagesize - return query[start:stop] - def _aggregate(query, aggr_func_name): + offset = (pagenum - 1) * pagesize + return query._fetch(pagesize, offset, lazy=True) + def _aggregate(query, aggr_func_name, distinct=None, sep=None): translator = query._translator - sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(aggr_func_name=aggr_func_name) + sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments( + aggr_func_name=aggr_func_name, aggr_func_distinct=distinct, sep=sep) cache = query._database._get_cache() try: result = cache.query_results[query_key] except KeyError: @@ -5414,18 +6214,29 @@ def _aggregate(query, aggr_func_name): if result is None: pass elif aggr_func_name == 'COUNT': pass else: - expr_type = float if aggr_func_name == 'AVG' else translator.expr_type + if aggr_func_name == 'AVG': + expr_type = float + elif aggr_func_name == 'GROUP_CONCAT': + expr_type = basestring + else: + expr_type = translator.expr_type provider = query._database.provider converter = provider.get_converter_by_py_type(expr_type) result = converter.sql2py(result) if query_key is not None: cache.query_results[query_key] = result return result @cut_traceback - def sum(query): - return query._aggregate('SUM') + def sum(query, distinct=None): + return query._aggregate('SUM', distinct) + @cut_traceback + def avg(query, distinct=None): + return query._aggregate('AVG', distinct) @cut_traceback - def avg(query): - return query._aggregate('AVG') + def group_concat(query, sep=None, distinct=None): + if sep is not None: + if not isinstance(sep, basestring): + throw(TypeError, '`sep` option for `group_concat` should be of type str. 
Got: %s' % type(sep).__name__) + return query._aggregate('GROUP_CONCAT', distinct, sep) @cut_traceback def min(query): return query._aggregate('MIN') @@ -5433,44 +6244,139 @@ def min(query): def max(query): return query._aggregate('MAX') @cut_traceback - def count(query): - return query._aggregate('COUNT') + def count(query, distinct=None): + return query._aggregate('COUNT', distinct) @cut_traceback - def for_update(query, nowait=False): - provider = query._database.provider - if nowait and not provider.select_for_update_nowait_syntax: throw(TranslationError, - '%s provider does not support SELECT FOR UPDATE NOWAIT syntax' % provider.dialect) - return query._clone(_for_update=True, _nowait=nowait) + def for_update(query, nowait=False, skip_locked=False): + if nowait and skip_locked: + throw(TypeError, 'nowait and skip_locked options are mutually exclusive') + return query._clone(_for_update=True, _nowait=nowait, _skip_locked=skip_locked) def random(query, limit): return query.order_by('random()')[:limit] def to_json(query, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return query._database.to_json(query[:], include, exclude, converter, with_schema, schema_hash) -def strcut(s, width): - if len(s) <= width: - return s + ' ' * (width - len(s)) - else: - return s[:width-3] + '...' -class QueryResult(list): - __slots__ = '_query', '_expr_type', '_col_names' - def __init__(result, list, query, expr_type, col_names): - result[:] = list - result._query = query - result._expr_type = expr_type - result._col_names = col_names - def __getstate__(result): - return list(result), result._expr_type, result._col_names - def __setstate__(result, state): - result[:] = state[0] - result._expr_type = state[1] - result._col_names = state[2] +class QueryResultIterator(object): + __slots__ = '_query_result', '_position' + def __init__(self, query_result): + self._query_result = query_result + self._position = 0 + def _get_type_(self): + if self._position != 0: + throw(NotImplementedError, 'Cannot use partially exhausted iterator, please convert to list') + return self._query_result._get_type_() + def _normalize_var(self, query_type): + if self._position != 0: throw(NotImplementedError) + return self._query_result._normalize_var(query_type) + def next(self): + qr = self._query_result + if qr._items is None: + qr._items = qr._query._actual_fetch(qr._limit, qr._offset) + if self._position >= len(qr._items): + raise StopIteration + item = qr._items[self._position] + self._position += 1 + return item + __next__ = next + def __length_hint__(self): + return len(self._query_result) - self._position + + +def make_query_result_method_error_stub(name, title=None): + def func(self, *args, **kwargs): + throw(TypeError, 'In order to do %s, cast QueryResult to list first' % (title or name)) + return func + +class QueryResult(object): + __slots__ = '_query', '_limit', '_offset', '_items', '_expr_type', '_col_names' + def __init__(self, query, limit, offset, lazy): + translator = query._translator + self._query = query + self._limit = limit + self._offset = offset + self._items = None if lazy else self._query._actual_fetch(limit, offset) + self._expr_type = translator.expr_type + self._col_names = translator.col_names + def _get_query(self): + return self._query + def _get_type_(self): + if self._items is None: + return QueryType(self._query, self._limit, self._offset) + item_type = self._query._translator.expr_type + return tuple(item_type for item in self._items) + def _normalize_var(self, 
query_type): + if self._items is None: + return query_type, self._query + items = tuple(normalize(item) for item in self._items) + item_type = self._query._translator.expr_type + return tuple(item_type for item in items), items + def _get_items(self): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + return self._items + def __getstate__(self): + return self._get_items(), self._limit, self._offset, self._expr_type, self._col_names + def __setstate__(self, state): + self._query = None + self._items, self._limit, self._offset, self._expr_type, self._col_names = state + def __repr__(self): + if self._items is not None: + return self.__str__() + return '' % hex(id(self)) + def __str__(self): + return repr(self._get_items()) + def __iter__(self): + return QueryResultIterator(self) + def __len__(self): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + return len(self._items) + def __getitem__(self, key): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + return self._items[key] + def __contains__(self, item): + return item in self._get_items() + def index(self, item): + return self._get_items().index(item) + def _other_items(self, other): + return other._get_items() if isinstance(other, QueryResult) else other + def __eq__(self, other): + return self._get_items() == self._other_items(other) + def __ne__(self, other): + return self._get_items() != self._other_items(other) + def __lt__(self, other): + return self._get_items() < self._other_items(other) + def __le__(self, other): + return self._get_items() <= self._other_items(other) + def __gt__(self, other): + return self._get_items() > self._other_items(other) + def __ge__(self, other): + return self._get_items() >= self._other_items(other) + def __reversed__(self): + return reversed(self._get_items()) + def reverse(self): + self._get_items().reverse() + def sort(self, *args, **kwargs): + self._get_items().sort(*args, **kwargs) + def shuffle(self): + shuffle(self._get_items()) @cut_traceback - def show(result, width=None): + def show(self, width=None, stream=None): + if stream is None: + stream = sys.stdout + def writeln(s): + stream.write(s) + stream.write('\n') + + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + if not width: width = options.CONSOLE_WIDTH max_columns = width // 5 - expr_type = result._expr_type - col_names = result._col_names + expr_type = self._expr_type + col_names = self._col_names def to_str(x): return tostring(x).replace('\n', ' ') @@ -5483,11 +6389,11 @@ def to_str(x): col_name = col_names[0] row_maker = lambda obj: (getattr(obj, col_name),) else: row_maker = attrgetter(*col_names) - rows = [ tuple(to_str(value) for value in row_maker(obj)) for obj in result ] + rows = [tuple(to_str(value) for value in row_maker(obj)) for obj in self._items] elif len(col_names) == 1: - rows = [ (to_str(obj),) for obj in result ] + rows = [(to_str(obj),) for obj in self._items] else: - rows = [ tuple(to_str(value) for value in row) for row in result ] + rows = [tuple(to_str(value) for value in row) for row in self._items] remaining_columns = {} for col_num, colname in enumerate(col_names): @@ -5511,12 +6417,46 @@ def to_str(x): for col_num, max_len in remaining_columns.items(): width_dict[col_num] = base_len - print(strjoin('|', (strcut(colname, width_dict[i]) for i, colname in enumerate(col_names)))) - print(strjoin('+', ('-' * width_dict[i] for i in 
xrange(len(col_names))))) + writeln(strjoin('|', (strcut(colname, width_dict[i]) for i, colname in enumerate(col_names)))) + writeln(strjoin('+', ('-' * width_dict[i] for i in xrange(len(col_names))))) for row in rows: - print(strjoin('|', (strcut(item, width_dict[i]) for i, item in enumerate(row)))) - def to_json(result, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): - return result._query._database.to_json(result, include, exclude, converter, with_schema, schema_hash) + writeln(strjoin('|', (strcut(item, width_dict[i]) for i, item in enumerate(row)))) + stream.flush() + def to_json(self, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): + return self._query._database.to_json(self, include, exclude, converter, with_schema, schema_hash) + + def __add__(self, other): + result = [] + result.extend(self) + result.extend(other) + return result + def __radd__(self, other): + result = [] + result.extend(other) + result.extend(self) + return result + def to_list(self): + return list(self) + + __setitem__ = make_query_result_method_error_stub('__setitem__', 'item assignment') + __delitem__ = make_query_result_method_error_stub('__delitem__', 'item deletion') + __iadd__ = make_query_result_method_error_stub('__iadd__', '+=') + __imul__ = make_query_result_method_error_stub('__imul__', '*=') + __mul__ = make_query_result_method_error_stub('__mul__', '*') + __rmul__ = make_query_result_method_error_stub('__rmul__', '*') + append = make_query_result_method_error_stub('append', 'append') + clear = make_query_result_method_error_stub('clear', 'clear') + extend = make_query_result_method_error_stub('extend', 'extend') + insert = make_query_result_method_error_stub('insert', 'insert') + pop = make_query_result_method_error_stub('pop', 'pop') + remove = make_query_result_method_error_stub('remove', 'remove') + + +def strcut(s, width): + if len(s) <= width: + return s + ' ' * (width - len(s)) + else: + return s[:width-3] + '...' 
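# Illustrative usage of the lazy QueryResult introduced above (not part of the diff;
# the Person entity and db_session setup are assumed):
#
#     qr = select(p for p in Person).limit(10, offset=20)   # lazy: no SQL issued yet
#     first = qr[0]            # indexing triggers the actual fetch
#     people = qr.to_list()    # plain list; list-mutating methods on QueryResult raise TypeError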
@cut_traceback @@ -5541,5 +6481,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = set([ itertools.count, utils.count, count, random, raw_sql ]) -const_functions = set([ buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta ]) +special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr} +const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta} diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index f5a4b8bbf..a6612d24c 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, basestring, unicode, buffer, int_types +from pony.py23compat import PY2, basestring, unicode, buffer, int_types, iteritems -import os, re +import os, re, json from decimal import Decimal, InvalidOperation from datetime import datetime, date, time, timedelta from uuid import uuid4, UUID @@ -9,7 +9,7 @@ import pony from pony.utils import is_utf8, decorator, throw, localbase, deprecated from pony.converting import str2date, str2time, str2datetime, str2timedelta -from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType +from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedArray, Json, QueryType, Array class DBException(Exception): def __init__(exc, original_exc, *args): @@ -45,25 +45,48 @@ class NotSupportedError(DatabaseError): pass @decorator def wrap_dbapi_exceptions(func, provider, *args, **kwargs): dbapi_module = provider.dbapi_module - try: return func(provider, *args, **kwargs) - except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) - except dbapi_module.ProgrammingError as e: raise ProgrammingError(e) - except dbapi_module.InternalError as e: raise InternalError(e) - except dbapi_module.IntegrityError as e: raise IntegrityError(e) - except dbapi_module.OperationalError as e: raise OperationalError(e) - except dbapi_module.DataError as e: raise DataError(e) - except dbapi_module.DatabaseError as e: raise DatabaseError(e) - except dbapi_module.InterfaceError as e: - if e.args == (0, '') and getattr(dbapi_module, '__name__', None) == 'MySQLdb': - throw(InterfaceError, e, 'MySQL server misconfiguration') - raise InterfaceError(e) - except dbapi_module.Error as e: raise Error(e) - except dbapi_module.Warning as e: raise Warning(e) + should_retry = False + try: + try: + if provider.dialect != 'SQLite': + return func(provider, *args, **kwargs) + else: + provider.local_exceptions.keep_traceback = True + try: return func(provider, *args, **kwargs) + finally: provider.local_exceptions.keep_traceback = False + except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) + except dbapi_module.ProgrammingError as e: + if provider.dialect == 'PostgreSQL': + msg = str(e) + if msg.startswith('operator does not exist:') and ' json ' in msg: + msg += ' (Note: use column type `jsonb` instead of `json`)' + raise ProgrammingError(e, msg, *e.args[1:]) + raise ProgrammingError(e) + except dbapi_module.InternalError as e: raise InternalError(e) + except dbapi_module.IntegrityError as e: raise IntegrityError(e) + except dbapi_module.OperationalError as e: + if provider.dialect == 'PostgreSQL' and e.pgcode == '40001': + should_retry = True + if provider.dialect == 'SQLite': + provider.restore_exception() + raise OperationalError(e) + except dbapi_module.DataError as e: raise DataError(e) + except 
dbapi_module.DatabaseError as e: raise DatabaseError(e) + except dbapi_module.InterfaceError as e: + if e.args == (0, '') and getattr(dbapi_module, '__name__', None) == 'MySQLdb': + throw(InterfaceError, e, 'MySQL server misconfiguration') + raise InterfaceError(e) + except dbapi_module.Error as e: raise Error(e) + except dbapi_module.Warning as e: raise Warning(e) + except Exception as e: + if should_retry: + e.should_retry = True + raise def unexpected_args(attr, args): - throw(TypeError, - 'Unexpected positional argument%s for attribute %s: %r' - % ((args > 1 and 's' or ''), attr, ', '.join(repr(arg) for arg in args))) + throw(TypeError, 'Unexpected positional argument{} for attribute {}: {}'.format( + len(args) > 1 and 's' or '', attr, ', '.join(repr(arg) for arg in args)) + ) version_re = re.compile('[0-9\.]+') @@ -77,13 +100,12 @@ def get_version_tuple(s): class DBAPIProvider(object): paramstyle = 'qmark' quote_char = '"' - max_params_count = 200 + max_params_count = 999 max_name_len = 128 table_if_not_exists_syntax = True index_if_not_exists_syntax = True max_time_precision = default_time_precision = 6 uint64_support = False - select_for_update_nowait_syntax = True # SQLite and PostgreSQL does not limit varchar max length. varchar_default_max_len = None @@ -93,6 +115,7 @@ class DBAPIProvider(object): dbschema_cls = None translator_cls = None sqlbuilder_cls = None + array_converter_cls = None name_before_table = 'schema_name' default_schema_name = None @@ -101,9 +124,12 @@ class DBAPIProvider(object): def __init__(provider, *args, **kwargs): pool_mockup = kwargs.pop('pony_pool_mockup', None) + call_on_connect = kwargs.pop('pony_call_on_connect', None) if pool_mockup: provider.pool = pool_mockup else: provider.pool = provider.get_pool(*args, **kwargs) - connection = provider.connect() + connection, is_new_connection = provider.connect() + if call_on_connect: + call_on_connect(connection) provider.inspect_connection(connection) provider.release(connection) @@ -126,36 +152,36 @@ def get_default_m2m_table_name(provider, attr, reverse): return provider.normalize_name(name) def get_default_column_names(provider, attr, reverse_pk_columns=None): - normalize = provider.normalize_name + normalize_name = provider.normalize_name if reverse_pk_columns is None: - return [ normalize(attr.name) ] + return [ normalize_name(attr.name) ] elif len(reverse_pk_columns) == 1: - return [ normalize(attr.name) ] + return [ normalize_name(attr.name) ] else: prefix = attr.name + '_' - return [ normalize(prefix + column) for column in reverse_pk_columns ] + return [ normalize_name(prefix + column) for column in reverse_pk_columns ] def get_default_m2m_column_names(provider, entity): - normalize = provider.normalize_name + normalize_name = provider.normalize_name columns = entity._get_pk_columns_() if len(columns) == 1: - return [ normalize(entity.__name__.lower()) ] + return [ normalize_name(entity.__name__.lower()) ] else: prefix = entity.__name__.lower() + '_' - return [ normalize(prefix + column) for column in columns ] + return [ normalize_name(prefix + column) for column in columns ] def get_default_index_name(provider, table_name, column_names, is_pk=False, is_unique=False, m2m=False): - if is_pk: index_name = 'pk_%s' % table_name + if is_pk: index_name = 'pk_%s' % provider.base_name(table_name) else: if is_unique: template = 'unq_%(tname)s__%(cnames)s' elif m2m: template = 'idx_%(tname)s' else: template = 'idx_%(tname)s__%(cnames)s' - index_name = template % dict(tname=table_name, + index_name = template 
% dict(tname=provider.base_name(table_name), cnames='_'.join(name for name in column_names)) return provider.normalize_name(index_name.lower()) def get_default_fk_name(provider, child_table_name, parent_table_name, child_column_names): - fk_name = 'fk_%s__%s' % (child_table_name, '__'.join(child_column_names)) + fk_name = 'fk_%s__%s' % (provider.base_name(child_table_name), '__'.join(child_column_names)) return provider.normalize_name(fk_name.lower()) def split_table_name(provider, table_name): @@ -169,6 +195,13 @@ def split_table_name(provider, table_name): size, 's' if size != 1 else '', table_name)) return table_name[0], table_name[1] + def base_name(provider, name): + if not isinstance(name, basestring): + assert type(name) is tuple + name = name[-1] + assert isinstance(name, basestring) + return name + def quote_name(provider, name): quote_char = provider.quote_char if isinstance(name, basestring): @@ -176,8 +209,14 @@ def quote_name(provider, name): return quote_char + name + quote_char return '.'.join(provider.quote_name(item) for item in name) + def format_table_name(provider, name): + return provider.quote_name(name) + def normalize_vars(provider, vars, vartypes): - pass + for key, value in iteritems(vars): + vartype = vartypes[key] + if isinstance(vartype, QueryType): + vartypes[key], vars[key] = value._normalize_var(vartype) def ast2sql(provider, ast): builder = provider.sqlbuilder_cls(provider, ast) @@ -197,14 +236,14 @@ def set_transaction_mode(provider, connection, cache): @wrap_dbapi_exceptions def commit(provider, connection, cache=None): core = pony.orm.core - if core.debug: core.log_orm('COMMIT') + if core.local.debug: core.log_orm('COMMIT') connection.commit() if cache is not None: cache.in_transaction = False @wrap_dbapi_exceptions def rollback(provider, connection, cache=None): core = pony.orm.core - if core.debug: core.log_orm('ROLLBACK') + if core.local.debug: core.log_orm('ROLLBACK') connection.rollback() if cache is not None: cache.in_transaction = False @@ -214,20 +253,20 @@ def release(provider, connection, cache=None): if cache is not None and cache.db_session is not None and cache.db_session.ddl: provider.drop(connection, cache) else: - if core.debug: core.log_orm('RELEASE CONNECTION') + if core.local.debug: core.log_orm('RELEASE CONNECTION') provider.pool.release(connection) @wrap_dbapi_exceptions def drop(provider, connection, cache=None): core = pony.orm.core - if core.debug: core.log_orm('CLOSE CONNECTION') + if core.local.debug: core.log_orm('CLOSE CONNECTION') provider.pool.drop(connection) if cache is not None: cache.in_transaction = False @wrap_dbapi_exceptions def disconnect(provider): core = pony.orm.core - if core.debug: core.log_orm('DISCONNECT') + if core.local.debug: core.log_orm('DISCONNECT') provider.pool.disconnect() @wrap_dbapi_exceptions @@ -246,6 +285,11 @@ def _get_converter_type_by_py_type(provider, py_type): if isinstance(py_type, type): for t, converter_cls in provider.converter_classes: if issubclass(py_type, t): return converter_cls + if issubclass(py_type, Array): + converter_cls = provider.array_converter_cls + if converter_cls is None: + throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect) + return converter_cls if isinstance(py_type, RawSQLType): return Converter # for cases like select(raw_sql(...) 
for x in X) throw(TypeError, 'No database converter found for type %s' % py_type) @@ -272,9 +316,8 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): throw(NotImplementedError) def table_has_data(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - cursor.execute('SELECT 1 FROM %s LIMIT 1' % table_name) + cursor.execute('SELECT 1 FROM %s LIMIT 1' % provider.quote_name(table_name)) return cursor.fetchone() is not None def disable_fk_checks(provider, connection): @@ -284,9 +327,8 @@ def enable_fk_checks(provider, connection, prev_state): pass def drop_table(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - sql = 'DROP TABLE %s' % table_name + sql = 'DROP TABLE %s' % provider.quote_name(table_name) cursor.execute(sql) class Pool(localbase): @@ -302,12 +344,15 @@ def connect(pool): pool.forked_connections.append((pool.con, pool.pid)) pool.con = pool.pid = None core = pony.orm.core + is_new_connection = False if pool.con is None: - if core.debug: core.log_orm('GET NEW CONNECTION') + if core.local.debug: core.log_orm('GET NEW CONNECTION') + is_new_connection = True pool._connect() pool.pid = pid - elif core.debug: core.log_orm('GET CONNECTION FROM THE LOCAL POOL') - return pool.con + elif core.local.debug: + core.log_orm('GET CONNECTION FROM THE LOCAL POOL') + return pool.con, is_new_connection def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) def release(pool, con): @@ -326,6 +371,9 @@ def disconnect(pool): if con is not None: con.close() class Converter(object): + EQ = 'EQ' + NE = 'NE' + optimistic = True def __deepcopy__(converter, memo): return converter # Converter instances are "immutable" def __init__(converter, provider, py_type, attr=None): @@ -339,12 +387,18 @@ def __init__(converter, provider, py_type, attr=None): def init(converter, kwargs): attr = converter.attr if attr and attr.args: unexpected_args(attr, attr.args) - def validate(converter, val): + def validate(converter, val, obj=None): return val def py2sql(converter, val): return val def sql2py(converter, val): return val + def val2dbval(self, val, obj=None): + return val + def dbval2val(self, dbval, obj=None): + return dbval + def dbvals_equal(self, x, y): + return x == y def get_sql_type(converter, attr=None): if attr is not None and attr.sql_type is not None: return attr.sql_type @@ -376,7 +430,7 @@ def get_fk_type(converter, sql_type): assert False class BoolConverter(Converter): - def validate(converter, val): + def validate(converter, val, obj=None): return bool(val) def sql2py(converter, val): return bool(val) @@ -390,9 +444,12 @@ def __init__(converter, provider, py_type, attr=None): Converter.__init__(converter, provider, py_type, attr) def init(converter, kwargs): attr = converter.attr - if not attr.args: max_len = None - elif len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) - else: max_len = attr.args[0] + max_len = kwargs.pop('max_len', None) + if len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) + elif attr.args: + if max_len is not None: throw(TypeError, + 'Max length option specified twice: as a positional argument and as a `max_len` named argument') + max_len = attr.args[0] if issubclass(attr.py_type, (LongStr, LongUnicode)): if max_len is not None: throw(TypeError, 'Max length is not supported for CLOBs') elif max_len is None: max_len = converter.provider.varchar_default_max_len @@ -401,7 +458,7 @@ def 
init(converter, kwargs): converter.max_len = max_len converter.db_encoding = kwargs.pop('db_encoding', None) converter.autostrip = kwargs.pop('autostrip', True) - def validate(converter, val): + def validate(converter, val, obj=None): if PY2 and isinstance(val, str): val = val.decode('ascii') elif not isinstance(val, unicode): throw(TypeError, 'Value type for attribute %s must be %s. Got: %r' % (converter.attr, unicode.__name__, type(val))) @@ -472,8 +529,10 @@ def init(converter, kwargs): converter.max_val = max_val or highest converter.size = size converter.unsigned = unsigned - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, int_types): pass + elif hasattr(val, '__index__'): + val = val.__index__() elif isinstance(val, basestring): try: val = int(val) except ValueError: throw(ValueError, @@ -497,10 +556,13 @@ def sql_type(converter): return converter.unsigned_types.get(converter.size) class RealConverter(Converter): + EQ = 'FLOAT_EQ' + NE = 'FLOAT_NE' # The tolerance is necessary for Oracle, because it has different representation of float numbers. # For other databases the default tolerance is set because the precision can be lost during # Python -> JavaScript -> Python conversion default_tolerance = 1e-14 + optimistic = False def init(converter, kwargs): Converter.init(converter, kwargs) min_val = kwargs.pop('min', None) @@ -516,7 +578,7 @@ def init(converter, kwargs): converter.min_val = min_val converter.max_val = max_val converter.tolerance = kwargs.pop('tolerance', converter.default_tolerance) - def validate(converter, val): + def validate(converter, val, obj=None): try: val = float(val) except ValueError: throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val)) @@ -527,7 +589,7 @@ def validate(converter, val): throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' % (val, converter.attr, converter.max_val)) return val - def equals(converter, x, y): + def dbvals_equal(converter, x, y): tolerance = converter.tolerance if tolerance is None or x is None or y is None: return x == y denominator = max(abs(x), abs(y)) @@ -581,7 +643,7 @@ def init(converter, kwargs): converter.min_val = min_val converter.max_val = max_val - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, float): s = str(val) if float(s) != val: s = repr(val) @@ -602,18 +664,24 @@ def sql_type(converter): return 'DECIMAL(%d, %d)' % (converter.precision, converter.scale) class BlobConverter(Converter): - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, buffer): return val if isinstance(val, str): return buffer(val) throw(TypeError, "Attribute %r: expected type is 'buffer'. 
Got: %r" % (converter.attr, type(val))) def sql2py(converter, val): - if not isinstance(val, buffer): val = buffer(val) + if not isinstance(val, buffer): + try: val = buffer(val) + except: pass + elif PY2 and converter.attr is not None and converter.attr.is_part_of_unique_index: + try: hash(val) + except TypeError: + val = buffer(val) return val def sql_type(converter): return 'BLOB' class DateConverter(Converter): - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, datetime): return val.date() if isinstance(val, date): return val if isinstance(val, basestring): return str2date(val) @@ -663,7 +731,7 @@ def sql_type(converter): class TimeConverter(ConverterWithMicroseconds): sql_type_name = 'TIME' - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, time): pass elif isinstance(val, basestring): val = str2time(val) else: throw(TypeError, "Attribute %r: expected type is 'time'. Got: %r" % (converter.attr, val)) @@ -677,7 +745,7 @@ def sql2py(converter, val): class TimedeltaConverter(ConverterWithMicroseconds): sql_type_name = 'INTERVAL' - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, timedelta): pass elif isinstance(val, basestring): val = str2timedelta(val) else: throw(TypeError, "Attribute %r: expected type is 'timedelta'. Got: %r" % (converter.attr, val)) @@ -691,7 +759,7 @@ def sql2py(converter, val): class DatetimeConverter(ConverterWithMicroseconds): sql_type_name = 'DATETIME' - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, datetime): pass elif isinstance(val, basestring): val = str2datetime(val) else: throw(TypeError, "Attribute %r: expected type is 'datetime'. Got: %r" % (converter.attr, val)) @@ -709,7 +777,7 @@ def __init__(converter, provider, py_type, attr=None): attr.auto = False if not attr.default: attr.default = uuid4 Converter.__init__(converter, provider, py_type, attr) - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, UUID): return val if isinstance(val, buffer): return UUID(bytes=val) if isinstance(val, basestring): @@ -725,3 +793,80 @@ def py2sql(converter, val): sql2py = validate def sql_type(converter): return "UUID" + +class JsonConverter(Converter): + json_kwargs = {} + class JsonEncoder(json.JSONEncoder): + def default(converter, obj): + if isinstance(obj, Json): + return obj.wrapped + return json.JSONEncoder.default(converter, obj) + def validate(converter, val, obj=None): + if obj is None or converter.attr is None: + return val + if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: + return val + return TrackedValue.make(obj, converter.attr, val) + def val2dbval(converter, val, obj=None): + return json.dumps(val, cls=converter.JsonEncoder, **converter.json_kwargs) + def dbval2val(converter, dbval, obj=None): + if isinstance(dbval, (int, bool, float, type(None))): + return dbval + val = json.loads(dbval) + if obj is None: + return val + return TrackedValue.make(obj, converter.attr, val) + def dbvals_equal(converter, x, y): + if x == y: return True # optimization + if isinstance(x, basestring): x = json.loads(x) + if isinstance(y, basestring): y = json.loads(y) + return x == y + def sql_type(converter): + return "JSON" + +class ArrayConverter(Converter): + array_types = { + int: ('int', IntConverter), + unicode: ('text', StrConverter), + float: ('real', RealConverter) + } + + def __init__(converter, provider, py_type, 
attr=None): + Converter.__init__(converter, provider, py_type, attr) + converter.item_converter = converter.array_types[converter.py_type.item_type][1] + + def validate(converter, val, obj=None): + if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: + return val + + if isinstance(val, basestring) or not hasattr(val, '__len__'): + items = [val] + else: + items = list(val) + item_type = converter.py_type.item_type + if item_type == float: + item_type = (float, int) + for i, v in enumerate(items): + if PY2 and isinstance(v, str): + v = v.decode('ascii') + if not isinstance(v, item_type): + if hasattr(v, '__index__'): + items[i] = v.__index__() + else: + throw(TypeError, 'Cannot store %s item in array of %s' % + (type(v).__name__, converter.py_type.item_type.__name__)) + + if obj is None or converter.attr is None: + return items + return TrackedArray(obj, converter.attr, items) + + def dbval2val(converter, dbval, obj=None): + if obj is None or dbval is None: + return dbval + return TrackedArray(obj, converter.attr, dbval) + + def val2dbval(converter, val, obj=None): + return list(val) + + def sql_type(converter): + return '%s[]' % converter.array_types[converter.py_type.item_type][0] diff --git a/pony/orm/dbproviders/cockroach.py b/pony/orm/dbproviders/cockroach.py new file mode 100644 index 000000000..95c598ce4 --- /dev/null +++ b/pony/orm/dbproviders/cockroach.py @@ -0,0 +1,112 @@ +from __future__ import absolute_import +from pony.py23compat import PY2, basestring, unicode, buffer, int_types + +from decimal import Decimal +from datetime import datetime, date, time, timedelta +from uuid import UUID + +try: + import psycopg2 +except ImportError: + try: + from psycopg2cffi import compat + except ImportError: + raise ImportError('In order to use PonyORM with CockroachDB please install psycopg2 or psycopg2cffi') + else: + compat.register() + +from pony.orm.dbproviders.postgres import ( + PGSQLBuilder, PGColumn, PGSchema, PGTranslator, PGProvider, + PGStrConverter, PGIntConverter, PGRealConverter, + PGDatetimeConverter, PGTimedeltaConverter, + PGBlobConverter, PGJsonConverter, PGArrayConverter, +) + +from pony.orm import core, dbapiprovider, ormtypes +from pony.orm.core import log_orm +from pony.orm.dbapiprovider import wrap_dbapi_exceptions + +NoneType = type(None) + +class CRColumn(PGColumn): + auto_template = 'SERIAL PRIMARY KEY' + +class CRSchema(PGSchema): + column_class = CRColumn + +class CRTranslator(PGTranslator): + pass + +class CRSQLBuilder(PGSQLBuilder): + pass + +class CRIntConverter(PGIntConverter): + signed_types = {None: 'INT', 8: 'INT2', 16: 'INT2', 24: 'INT8', 32: 'INT8', 64: 'INT8'} + unsigned_types = {None: 'INT', 8: 'INT2', 16: 'INT4', 24: 'INT8', 32: 'INT8'} + # signed_types = {None: 'INT', 8: 'INT2', 16: 'INT2', 24: 'INT4', 32: 'INT4', 64: 'INT8'} + # unsigned_types = {None: 'INT', 8: 'INT2', 16: 'INT4', 24: 'INT4', 32: 'INT8'} + +class CRBlobConverter(PGBlobConverter): + def sql_type(converter): + return 'BYTES' + +class CRTimedeltaConverter(PGTimedeltaConverter): + sql_type_name = 'INTERVAL' + +class PGUuidConverter(dbapiprovider.UuidConverter): + def py2sql(converter, val): + return val + +class CRArrayConverter(PGArrayConverter): + array_types = { + int: ('INT', PGIntConverter), + unicode: ('STRING', PGStrConverter), + float: ('DOUBLE PRECISION', PGRealConverter) + } + +class CRProvider(PGProvider): + dbapi_module = psycopg2 + dbschema_cls = CRSchema + translator_cls = CRTranslator + sqlbuilder_cls = CRSQLBuilder + 
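+    # CockroachDB arrays reuse the PostgreSQL array converter machinery but emit
+    # Cockroach's element type names (INT, STRING, DOUBLE PRECISION).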
array_converter_cls = CRArrayConverter + + default_schema_name = 'public' + + fk_types = { 'SERIAL' : 'INT8' } + + def normalize_name(provider, name): + return name[:provider.max_name_len].lower() + + @wrap_dbapi_exceptions + def set_transaction_mode(provider, connection, cache): + assert not cache.in_transaction + db_session = cache.db_session + if db_session is not None and db_session.ddl: + cache.immediate = False + if cache.immediate and connection.autocommit: + connection.autocommit = False + if core.local.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') + elif not cache.immediate and not connection.autocommit: + connection.autocommit = True + if core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') + if db_session is not None and (db_session.serializable or db_session.ddl): + cache.in_transaction = True + + converter_classes = [ + (NoneType, dbapiprovider.NoneConverter), + (bool, dbapiprovider.BoolConverter), + (basestring, PGStrConverter), + (int_types, CRIntConverter), + (float, PGRealConverter), + (Decimal, dbapiprovider.DecimalConverter), + (datetime, PGDatetimeConverter), + (date, dbapiprovider.DateConverter), + (time, dbapiprovider.TimeConverter), + (timedelta, CRTimedeltaConverter), + (UUID, PGUuidConverter), + (buffer, CRBlobConverter), + (ormtypes.Json, PGJsonConverter), + ] + +provider_cls = CRProvider diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py new file mode 100644 index 000000000..37c923658 --- /dev/null +++ b/pony/orm/dbproviders/mssql.py @@ -0,0 +1,402 @@ +from __future__ import absolute_import +from pony.py23compat import PY2, imap, basestring, buffer, int_types + +import json +from decimal import Decimal +from datetime import datetime, date, time, timedelta +from uuid import UUID +import re +NoneType = type(None) + +import warnings +warnings.filterwarnings('ignore', '^Table.+already exists$', Warning, '^pony\\.orm\\.dbapiprovider$') + +try: + import pyodbc as mssql_module + MSSQL_module_name = 'pyodbc' +except ImportError: + raise ImportError('In order to use PonyORM with MSSQL please install pyodbc') + +from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation +from pony.orm.core import log_orm +from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions +from pony.orm.sqltranslation import SQLTranslator, TranslationError +from pony.orm.sqlbuilding import Value, Param, SQLBuilder, join +from pony.utils import throw +from pony.converting import str2timedelta, timedelta2str + +PYODBC_VAR_REGEX = re.compile(r'(? 
2: offset = last_section[2] sections = sections[:-1] - result = builder.subquery(*sections) + result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent if sections[0][0] == 'ROWID': @@ -174,27 +175,27 @@ def SELECT(builder, *sections): else: indent0 = '' x = 't.*' - - if not limit: pass + + if not limit and not offset: + pass elif not offset: result = [ indent0, 'SELECT * FROM (\n' ] builder.indent += 1 - result.extend(builder.subquery(*sections)) + result.extend(builder._subquery(*sections)) builder.indent -= 1 - result.extend((indent, ') WHERE ROWNUM <= ', builder(limit), '\n')) + result.extend((indent, ') WHERE ROWNUM <= %d\n' % limit)) else: indent2 = indent + builder.indent_spaces result = [ indent0, 'SELECT %s FROM (\n' % x, indent2, 'SELECT t.*, ROWNUM "row-num" FROM (\n' ] builder.indent += 2 - result.extend(builder.subquery(*sections)) + result.extend(builder._subquery(*sections)) builder.indent -= 2 - result.extend((indent2, ') t ')) - if limit[0] == 'VALUE' and offset[0] == 'VALUE' \ - and isinstance(limit[1], int) and isinstance(offset[1], int): - total_limit = [ 'VALUE', limit[1] + offset[1] ] - result.extend(('WHERE ROWNUM <= ', builder(total_limit), '\n')) - else: result.extend(('WHERE ROWNUM <= ', builder(limit), ' + ', builder(offset), '\n')) - result.extend((indent, ') t WHERE "row-num" > ', builder(offset), '\n')) + if limit is None: + result.append('%s) t\n' % indent2) + result.append('%s) t WHERE "row-num" > %d\n' % (indent, offset)) + else: + result.append('%s) t WHERE ROWNUM <= %d\n' % (indent2, limit + offset)) + result.append('%s) t WHERE "row-num" > %d\n' % (indent, offset)) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' @@ -205,29 +206,87 @@ def ROWID(builder, *expr_list): return builder.ALL(*expr_list) def LIMIT(builder, limit, offset=None): assert False # pragma: no cover + def TO_REAL(builder, expr): + return 'CAST(', builder(expr), ' AS NUMBER)' + def TO_STR(builder, expr): + return 'TO_CHAR(', builder(expr), ')' def DATE(builder, expr): return 'TRUNC(', builder(expr), ')' def RANDOM(builder): return 'dbms_random.value' + def MOD(builder, a, b): + return 'MOD(', builder(a), ', ', builder(b), ')' def DATE_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATETIME_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) + def build_json_path(builder, path): + path_sql, has_params, has_wildcards = SQLBuilder.build_json_path(builder, path) + if has_params: throw(TranslationError, "Oracle doesn't allow parameters in JSON paths") + return path_sql, has_params, has_wildcards + def 
JSON_QUERY(builder, expr, path): + expr_sql = builder(expr) + path_sql, has_params, has_wildcards = builder.build_json_path(path) + if has_wildcards: return 'JSON_QUERY(', expr_sql, ', ', path_sql, ' WITH WRAPPER)' + return 'REGEXP_REPLACE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '(^\\[|\\]$)', '')" + json_value_type_mapping = {bool: 'NUMBER', int: 'NUMBER', float: 'NUMBER'} + def JSON_VALUE(builder, expr, path, type): + if type is Json: return builder.JSON_QUERY(expr, path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) + type_name = builder.json_value_type_mapping.get(type, 'VARCHAR2') + return 'JSON_VALUE(', builder(expr), ', ', path_sql, ' RETURNING ', type_name, ')' + def JSON_NONZERO(builder, expr): + return 'COALESCE(', builder(expr), ''', 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' + def JSON_CONTAINS(builder, expr, path, key): + assert key[0] == 'VALUE' and isinstance(key[1], basestring) + path_sql, has_params, has_wildcards = builder.build_json_path(path) + path_with_key_sql, _, _ = builder.build_json_path(path + [ key ]) + expr_sql = builder(expr) + result = 'JSON_EXISTS(', expr_sql, ', ', path_with_key_sql, ')' + if json_item_re.match(key[1]): + item = r'"([^"]|\\")*"' + list_start = r'\[\s*(%s\s*,\s*)*' % item + list_end = r'\s*(,\s*%s\s*)*\]' % item + pattern = r'%s"%s"%s' % (list_start, key[1], list_end) + if has_wildcards: + sublist = r'\[[^]]*\]' + item_or_sublist = '(%s|%s)' % (item, sublist) + wrapper_list_start = r'^\[\s*(%s\s*,\s*)*' % item_or_sublist + wrapper_list_end = r'\s*(,\s*%s\s*)*\]$' % item_or_sublist + pattern = r'%s%s%s' % (wrapper_list_start, pattern, wrapper_list_end) + result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '%s')" % pattern + else: + pattern = '^%s$' % pattern + result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, "), '%s')" % pattern + return result + def JSON_ARRAY_LENGTH(builder, value): + throw(TranslationError, 'Oracle does not provide `length` function for JSON arrays') + def GROUP_CONCAT(builder, distinct, expr, sep=None): + assert distinct in (None, True, False) + if distinct and builder.provider.server_version >= (19,): + distinct = 'DISTINCT ' + else: + distinct = '' + result = 'LISTAGG(', distinct, builder(expr) + if sep is not None: + result = result, ', ', builder(sep) + else: + result = result, ", ','" + return result, ') WITHIN GROUP(ORDER BY 1)' + +json_item_re = re.compile('[\w\s]*') + class OraBoolConverter(dbapiprovider.BoolConverter): - if not PY2: + if not PY2: def py2sql(converter, val): # Fixes cx_Oracle 5.1.3 Python 3 bug: # "DatabaseError: OCI-22062: invalid input string [True]" @@ -238,7 +297,7 @@ def sql_type(converter): return "NUMBER(1)" class OraStrConverter(dbapiprovider.StrConverter): - def validate(converter, val): + def validate(converter, val, obj=None): if val == '': return None return dbapiprovider.StrConverter.validate(converter, val) def sql2py(converter, val): @@ -287,7 +346,7 @@ def __init__(converter, provider, py_type, attr=None): dbapiprovider.TimeConverter.__init__(converter, provider, py_type, attr) if attr is not None and converter.precision > 0: # cx_Oracle 5.1.3 corrupts microseconds for values of DAY TO SECOND type - converter.precision = 0 + converter.precision = 0 def sql2py(converter, val): if isinstance(val, timedelta): total_seconds = val.days * (24 * 60 * 60) + val.seconds @@ -308,7 +367,7 @@ def __init__(converter, provider, py_type, attr=None): 
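# --- Illustrative sketch, not part of the patch ---
# Why: the Oracle SELECT builder above emulates limit/offset with nested ROWNUM
# filters, since this Oracle dialect has no LIMIT/OFFSET clause.
# Assumptions: the Person entity, credentials and DSN below are invented.
from pony.orm import Database, Required, db_session, select

db = Database()

class Person(db.Entity):
    name = Required(str)

db.bind(provider='oracle', user='scott', password='tiger', dsn='localhost/xe')
db.generate_mapping(create_tables=False)

with db_session:
    page = select(p for p in Person).limit(10, offset=20)
    # The SQL produced by the builder above has roughly this shape:
    #   SELECT t.* FROM (
    #     SELECT t.*, ROWNUM "row-num" FROM ( SELECT ... FROM "PERSON" ) t
    #     WHERE ROWNUM <= 30
    #   ) t WHERE "row-num" > 20
    for p in page:
        print(p.name)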
dbapiprovider.TimedeltaConverter.__init__(converter, provider, py_type, attr) if attr is not None and converter.precision > 0: # cx_Oracle 5.1.3 corrupts microseconds for values of DAY TO SECOND type - converter.precision = 0 + converter.precision = 0 class OraDatetimeConverter(dbapiprovider.DatetimeConverter): sql_type_name = 'TIMESTAMP' @@ -317,6 +376,15 @@ class OraUuidConverter(dbapiprovider.UuidConverter): def sql_type(converter): return 'RAW(16)' +class OraJsonConverter(dbapiprovider.JsonConverter): + json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} + optimistic = False # CLOBs cannot be compared with strings, and TO_CHAR(CLOB) returns first 4000 chars only + def sql2py(converter, dbval): + if hasattr(dbval, 'read'): dbval = dbval.read() + return dbapiprovider.JsonConverter.sql2py(converter, dbval) + def sql_type(converter): + return 'CLOB' + class OraProvider(DBAPIProvider): dialect = 'Oracle' paramstyle = 'named' @@ -346,6 +414,7 @@ class OraProvider(DBAPIProvider): (timedelta, OraTimedeltaConverter), (UUID, OraUuidConverter), (buffer, OraBlobConverter), + (Json, OraJsonConverter), ] @wrap_dbapi_exceptions @@ -369,10 +438,11 @@ def normalize_name(provider, name): return name[:provider.max_name_len].upper() def normalize_vars(provider, vars, vartypes): - for name, value in iteritems(vars): + DBAPIProvider.normalize_vars(provider, vars, vartypes) + for key, value in iteritems(vars): if value == '': - vars[name] = None - vartypes[name] = NoneType + vars[key] = None + vartypes[key] = NoneType @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): @@ -381,7 +451,7 @@ def set_transaction_mode(provider, connection, cache): if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.immediate = True if db_session is not None and (db_session.serializable or db_session.ddl): @@ -400,7 +470,11 @@ def execute(provider, cursor, sql, arguments=None, returning_id=False): arguments['new_id'] = var if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) - return var.getvalue() + value = var.getvalue() + if isinstance(value, list): + assert len(value) == 1 + value = value[0] + return value if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) @@ -416,12 +490,16 @@ def get_pool(provider, *args, **kwargs): elif len(args) == 2: user, password = args elif len(args) == 3: user, password, dsn = args elif args: throw(ValueError, 'Invalid number of positional arguments') - if user != kwargs.setdefault('user', user): - throw(ValueError, 'Ambiguous value for user') - if password != kwargs.setdefault('password', password): - throw(ValueError, 'Ambiguous value for password') - if dsn != kwargs.setdefault('dsn', dsn): - throw(ValueError, 'Ambiguous value for dsn') + + def setdefault(kwargs, key, value): + kwargs_value = kwargs.setdefault(key, value) + if value is not None and value != kwargs_value: + throw(ValueError, 'Ambiguous value for ' + key) + + setdefault(kwargs, 'user', user) + setdefault(kwargs, 'password', password) + setdefault(kwargs, 'dsn', dsn) + kwargs.setdefault('threaded', True) kwargs.setdefault('min', 1) kwargs.setdefault('max', 10) @@ -463,15 +541,13 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): return row[0] if row is not None else None def table_has_data(provider, connection, 
table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - cursor.execute('SELECT 1 FROM %s WHERE ROWNUM = 1' % table_name) + cursor.execute('SELECT 1 FROM %s WHERE ROWNUM = 1' % provider.quote_name(table_name)) return cursor.fetchone() is not None def drop_table(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - sql = 'DROP TABLE %s CASCADE CONSTRAINTS' % table_name + sql = 'DROP TABLE %s CASCADE CONSTRAINTS' % provider.quote_name(table_name) cursor.execute(sql) provider_cls = OraProvider @@ -498,18 +574,19 @@ def output_type_handler(cursor, name, defaultType, size, precision, scale): class OraPool(object): forked_pools = [] def __init__(pool, **kwargs): + pool.kwargs = kwargs pool.cx_pool = cx_Oracle.SessionPool(**kwargs) pool.pid = os.getpid() def connect(pool): pid = os.getpid() if pool.pid != pid: pool.forked_pools.append((pool.cx_pool, pool.pid)) - pool.cx_pool = cx_Oracle.SessionPool(**kwargs) + pool.cx_pool = cx_Oracle.SessionPool(**pool.kwargs) pool.pid = os.getpid() - if core.debug: log_orm('GET CONNECTION') + if core.local.debug: log_orm('GET CONNECTION') con = pool.cx_pool.acquire() con.outputtypehandler = output_type_handler - return con + return con, True def release(pool, con): pool.cx_pool.release(con) def drop(pool, con): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 7a233fecd..81f6d7158 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -5,18 +5,31 @@ from datetime import datetime, date, time, timedelta from uuid import UUID -import psycopg2 +try: + import psycopg2 +except ImportError: + try: + from psycopg2cffi import compat + except ImportError: + raise ImportError('In order to use PonyORM with PostgreSQL please install psycopg2 or psycopg2cffi') + else: + compat.register() + from psycopg2 import extensions import psycopg2.extras psycopg2.extras.register_uuid() -from pony.orm import core, dbschema, dbapiprovider +psycopg2.extras.register_default_json(loads=lambda x: x) +psycopg2.extras.register_default_jsonb(loads=lambda x: x) + +from pony.orm import core, dbschema, dbapiprovider, sqltranslation, ormtypes from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator -from pony.orm.sqlbuilding import Value, SQLBuilder +from pony.orm.sqlbuilding import Value, SQLBuilder, join from pony.converting import timedelta2str +from pony.utils import is_ident NoneType = type(None) @@ -34,13 +47,15 @@ class PGValue(Value): __slots__ = [] def __unicode__(self): value = self.value - if isinstance(value, bool): return value and 'true' or 'false' + if isinstance(value, bool): + return value and 'true' or 'false' return Value.__unicode__(self) - if not PY2: __str__ = __unicode__ + if not PY2: + __str__ = __unicode__ class PGSQLBuilder(SQLBuilder): dialect = 'PostgreSQL' - make_value = PGValue + value_class = PGValue def INSERT(builder, table_name, columns, values, returning=None): if not values: result = [ 'INSERT INTO ', builder.quote_name(table_name) ,' DEFAULT VALUES' ] else: result = SQLBuilder.INSERT(builder, table_name, columns, values) @@ -48,26 +63,80 @@ def INSERT(builder, table_name, columns, values, returning=None): return result def TO_INT(builder, expr): return '(', builder(expr), ')::int' + def TO_STR(builder, expr): + return '(', builder(expr), ')::text' + def TO_REAL(builder, expr): + return '(', 
builder(expr), ')::double precision' def DATE(builder, expr): return '(', builder(expr), ')::date' def RANDOM(builder): return 'random()' def DATE_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return '((', builder(expr1), ' - ', builder(expr2), ") * interval '1 day')" def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATETIME_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) + def eval_json_path(builder, values): + result = [] + for value in values: + if isinstance(value, int): + result.append(str(value)) + elif isinstance(value, basestring): + result.append(value if is_ident(value) else '"%s"' % value.replace('"', '\\"')) + else: assert False, value + return '{%s}' % ','.join(result) + def JSON_QUERY(builder, expr, path): + path_sql, has_params, has_wildcards = builder.build_json_path(path) + return '(', builder(expr), " #> ", path_sql, ')' + json_value_type_mapping = {bool: 'boolean', int: 'int', float: 'double precision'} + def JSON_VALUE(builder, expr, path, type): + if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) + sql = '(', builder(expr), " #>> ", path_sql, ')' + type_name = builder.json_value_type_mapping.get(type, 'text') + return sql if type_name == 'text' else (sql, '::', type_name) + def JSON_NONZERO(builder, expr): + return 'coalesce(', builder(expr), ", 'null'::jsonb) NOT IN (" \ + "'null'::jsonb, 'false'::jsonb, '0'::jsonb, '\"\"'::jsonb, '[]'::jsonb, '{}'::jsonb)" + def JSON_CONCAT(builder, left, right): + return '(', builder(left), '||', builder(right), ')' + def JSON_CONTAINS(builder, expr, path, key): + return (builder.JSON_QUERY(expr, path) if path else builder(expr)), ' ? 
', builder(key) + def JSON_ARRAY_LENGTH(builder, value): + return 'jsonb_array_length(', builder(value), ')' + def GROUP_CONCAT(builder, distinct, expr, sep=None): + assert distinct in (None, True, False) + result = distinct and 'string_agg(distinct ' or 'string_agg(', builder(expr), '::text' + if sep is not None: + result = result, ', ', builder(sep) + else: + result = result, ", ','" + return result, ')' + def ARRAY_INDEX(builder, col, index): + return builder(col), '[', builder(index), ']' + def ARRAY_CONTAINS(builder, key, not_in, col): + if not_in: + return builder(key), ' <> ALL(', builder(col), ')' + return builder(key), ' = ANY(', builder(col), ')' + def ARRAY_SUBSET(builder, array1, not_in, array2): + result = builder(array1), ' <@ ', builder(array2) + if not_in: + result = 'NOT (', result, ')' + return result + def ARRAY_LENGTH(builder, array): + return 'COALESCE(ARRAY_LENGTH(', builder(array), ', 1), 0)' + def ARRAY_SLICE(builder, array, start, stop): + return builder(array), '[', builder(start) if start else '', ':', builder(stop) if stop else '', ']' + def MAKE_ARRAY(builder, *items): + return 'ARRAY[', join(', ', (builder(item) for item in items)), ']' + class PGStrConverter(dbapiprovider.StrConverter): if PY2: @@ -99,6 +168,17 @@ class PGUuidConverter(dbapiprovider.UuidConverter): def py2sql(converter, val): return val +class PGJsonConverter(dbapiprovider.JsonConverter): + def sql_type(self): + return "JSONB" + +class PGArrayConverter(dbapiprovider.ArrayConverter): + array_types = { + int: ('int', PGIntConverter), + unicode: ('text', PGStrConverter), + float: ('double precision', PGRealConverter) + } + class PGPool(Pool): def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) @@ -120,12 +200,14 @@ class PGProvider(DBAPIProvider): dialect = 'PostgreSQL' paramstyle = 'pyformat' max_name_len = 63 + max_params_count = 10000 index_if_not_exists_syntax = False dbapi_module = psycopg2 dbschema_cls = PGSchema translator_cls = PGTranslator sqlbuilder_cls = PGSQLBuilder + array_converter_cls = PGArrayConverter default_schema_name = 'public' @@ -140,8 +222,7 @@ def inspect_connection(provider, connection): provider.table_if_not_exists_syntax = provider.server_version >= 90100 def should_reconnect(provider, exc): - return isinstance(exc, psycopg2.OperationalError) \ - and exc.pgcode is exc.pgerror is exc.cursor is None + return isinstance(exc, psycopg2.OperationalError) and exc.pgcode is None def get_pool(provider, *args, **kwargs): return PGPool(provider.dbapi_module, *args, **kwargs) @@ -151,16 +232,16 @@ def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction if cache.immediate and connection.autocommit: connection.autocommit = False - if core.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') + if core.local.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') db_session = cache.db_session if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) elif not cache.immediate and not connection.autocommit: connection.autocommit = True - if core.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') + if core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') if db_session is not None and (db_session.serializable or db_session.ddl): cache.in_transaction = True @@ -185,7 +266,7 @@ def table_exists(provider, connection, table_name, case_sensitive=True): 
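# --- Illustrative sketch, not part of the patch ---
# Why: the ARRAY_* and JSON_* builders above back Pony's Array and Json attribute
# types on PostgreSQL (arrays as native array columns, Json stored as JSONB).
# Assumptions: the Article entity and the connection parameters are invented;
# the SQL in the trailing comments is only an approximation of the output.
from pony.orm import Database, Required, Optional, IntArray, Json, db_session, select

db = Database()

class Article(db.Entity):
    tags = Required(IntArray)   # native PostgreSQL array column
    meta = Optional(Json)       # stored as JSONB by PGJsonConverter

db.bind(provider='postgres', user='pony', password='pony', database='test')
db.generate_mapping(create_tables=True)

with db_session:
    select(a for a in Article if 10 in a.tags)[:]              # roughly: 10 = ANY("a"."tags")
    select(a for a in Article if a.meta['status'] == 'ok')[:]  # roughly: ("a"."meta" #>> '{status}') = 'ok'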
cursor.execute(sql, (schema_name, table_name)) row = cursor.fetchone() return row[0] if row is not None else None - + def index_exists(provider, connection, table_name, index_name, case_sensitive=True): schema_name, table_name = provider.split_table_name(table_name) cursor = connection.cursor() @@ -214,16 +295,9 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): row = cursor.fetchone() return row[0] if row is not None else None - def table_has_data(provider, connection, table_name): - table_name = provider.quote_name(table_name) - cursor = connection.cursor() - cursor.execute('SELECT 1 FROM %s LIMIT 1' % table_name) - return cursor.fetchone() is not None - def drop_table(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - sql = 'DROP TABLE %s CASCADE' % table_name + sql = 'DROP TABLE %s CASCADE' % provider.quote_name(table_name) cursor.execute(sql) converter_classes = [ @@ -239,6 +313,7 @@ def drop_table(provider, connection, table_name): (timedelta, PGTimedeltaConverter), (UUID, PGUuidConverter), (buffer, PGBlobConverter), + (ormtypes.Json, PGJsonConverter), ] provider_cls = PGProvider diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 7964b151e..ad6ae38f9 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from pony.py23compat import PY2, imap, basestring, buffer, int_types, unicode -import os.path +import os.path, sys, re, json import sqlite3 as sqlite from decimal import Decimal from datetime import datetime, date, time, timedelta @@ -10,12 +10,19 @@ from threading import Lock from uuid import UUID from binascii import hexlify +from functools import wraps -from pony.orm import core, dbschema, sqltranslation, dbapiprovider +from pony.orm import core, dbschema, dbapiprovider from pony.orm.core import log_orm -from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func +from pony.orm.ormtypes import Json, TrackedArray +from pony.orm.sqltranslation import SQLTranslator, StringExprMonad +from pony.orm.sqlbuilding import SQLBuilder, Value, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions -from pony.utils import localbase, datetime2timestamp, timestamp2datetime, decorator, absolutize_path, throw +from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise, \ + cut_traceback_depth + +class SqliteExtensionUnavailable(Exception): + pass NoneType = type(None) @@ -33,12 +40,12 @@ def func(translator, monad): sql = monad.getsql() assert len(sql) == 1 translator = monad.translator - return translator.StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) + return StringExprMonad(monad.type, [ sqlop, sql[0] ]) func.__name__ = sqlop return func -class SQLiteTranslator(sqltranslation.SQLTranslator): +class SQLiteTranslator(SQLTranslator): dialect = 'SQLite' sqlite_version = sqlite.sqlite_version_info row_value_syntax = False @@ -47,14 +54,39 @@ class SQLiteTranslator(sqltranslation.SQLTranslator): StringMixin_UPPER = make_overriden_string_func('PY_UPPER') StringMixin_LOWER = make_overriden_string_func('PY_LOWER') +class SQLiteValue(Value): + __slots__ = [] + def __unicode__(self): + value = self.value + if isinstance(value, datetime): + return self.quote_str(datetime2timestamp(value)) + if isinstance(value, date): + return self.quote_str(str(value)) + if isinstance(value, timedelta): + return 
repr(value.total_seconds() / (24 * 60 * 60)) + return Value.__unicode__(self) + if not PY2: __str__ = __unicode__ + class SQLiteBuilder(SQLBuilder): dialect = 'SQLite' - def SELECT_FOR_UPDATE(builder, nowait, *sections): - assert not builder.indent and not nowait + least_func_name = 'min' + greatest_func_name = 'max' + value_class = SQLiteValue + def __init__(builder, provider, ast): + builder.json1_available = provider.json1_available + SQLBuilder.__init__(builder, provider, ast) + def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): + assert not builder.indent return builder.SELECT(*sections) def INSERT(builder, table_name, columns, values, returning=None): if not values: return 'INSERT INTO %s DEFAULT VALUES' % builder.quote_name(table_name) return SQLBuilder.INSERT(builder, table_name, columns, values, returning) + def STRING_SLICE(builder, expr, start, stop): + if start is None: + start = [ 'VALUE', None ] + if stop is None: + stop = [ 'VALUE', None ] + return "py_string_slice(", builder(expr), ', ', builder(start), ', ', builder(stop), ")" def TODAY(builder): return "date('now', 'localtime')" def NOW(builder): @@ -94,35 +126,69 @@ def datetime_add(builder, funcname, expr, td): if not modifiers: return builder(expr) return funcname, '(', builder(expr), modifiers, ')' def DATE_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('date', expr, delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('date', expr, delta[1]) return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('date', expr, -delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('date', expr, -delta[1]) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')' def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('datetime', expr, delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('datetime', expr, delta[1]) return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('datetime', expr, -delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('datetime', expr, -delta[1]) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' - def MIN(builder, *args): - if len(args) == 0: assert False # pragma: no cover - elif len(args) == 1: fname = 'MIN' - else: fname = 'min' - return fname, '(', join(', ', imap(builder, args)), ')' - def MAX(builder, *args): - if len(args) == 0: assert False # pragma: no cover - elif len(args) == 1: fname = 'MAX' - else: fname = 'max' - return fname, '(', join(', ', imap(builder, args)), ')' + def DATETIME_DIFF(builder, expr1, expr2): + return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')' def RANDOM(builder): return 'rand()' # return '(random() / 9223372036854775807.0 + 1.0) / 2.0' PY_UPPER = make_unary_func('py_upper') PY_LOWER = make_unary_func('py_lower') + def FLOAT_EQ(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' + def 
FLOAT_NE(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14' + def JSON_QUERY(builder, expr, path): + fname = 'json_extract' if builder.json1_available else 'py_json_extract' + path_sql, has_params, has_wildcards = builder.build_json_path(path) + return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', path_sql, '))' + json_value_type_mapping = {unicode: 'text', bool: 'integer', int: 'integer', float: 'real'} + def JSON_VALUE(builder, expr, path, type): + func_name = 'json_extract' if builder.json1_available else 'py_json_extract' + path_sql, has_params, has_wildcards = builder.build_json_path(path) + type_name = builder.json_value_type_mapping.get(type) + result = func_name, '(', builder(expr), ', ', path_sql, ')' + if type_name is not None: result = 'CAST(', result, ' as ', type_name, ')' + return result + def JSON_NONZERO(builder, expr): + return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' + def JSON_ARRAY_LENGTH(builder, value): + func_name = 'json_array_length' if builder.json1_available else 'py_json_array_length' + return func_name, '(', builder(value), ')' + def JSON_CONTAINS(builder, expr, path, key): + path_sql, has_params, has_wildcards = builder.build_json_path(path) + return 'py_json_contains(', builder(expr), ', ', path_sql, ', ', builder(key), ')' + def ARRAY_INDEX(builder, col, index): + return 'py_array_index(', builder(col), ', ', builder(index), ')' + def ARRAY_CONTAINS(builder, key, not_in, col): + return ('NOT ' if not_in else ''), 'py_array_contains(', builder(col), ', ', builder(key), ')' + def ARRAY_SUBSET(builder, array1, not_in, array2): + return ('NOT ' if not_in else ''), 'py_array_subset(', builder(array2), ', ', builder(array1), ')' + def ARRAY_LENGTH(builder, array): + return 'py_array_length(', builder(array), ')' + def ARRAY_SLICE(builder, array, start, stop): + return 'py_array_slice(', builder(array), ', ', \ + builder(start) if start else 'null', ',',\ + builder(stop) if stop else 'null', ')' + def MAKE_ARRAY(builder, *items): + return 'py_make_array(', join(', ', (builder(item) for item in items)), ')' class SQLiteIntConverter(dbapiprovider.IntConverter): def sql_type(converter): @@ -131,6 +197,9 @@ def sql_type(converter): return dbapiprovider.IntConverter.sql_type(converter) class SQLiteDecimalConverter(dbapiprovider.DecimalConverter): + inf = Decimal('infinity') + neg_inf = Decimal('-infinity') + NaN = Decimal('NaN') def sql2py(converter, val): try: val = Decimal(str(val)) except: return val @@ -140,7 +209,10 @@ def sql2py(converter, val): def py2sql(converter, val): if type(val) is not Decimal: val = Decimal(val) exp = converter.exp - if exp is not None: val = val.quantize(exp) + if exp is not None: + if val in (converter.inf, converter.neg_inf, converter.NaN): + throw(ValueError, 'Cannot store %s Decimal value in database' % val) + val = val.quantize(exp) return str(val) class SQLiteDateConverter(dbapiprovider.DateConverter): @@ -175,15 +247,62 @@ def sql2py(converter, val): def py2sql(converter, val): return datetime2timestamp(val) +class SQLiteJsonConverter(dbapiprovider.JsonConverter): + json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} + +def dumps(items): + return json.dumps(items, **SQLiteJsonConverter.json_kwargs) + +class SQLiteArrayConverter(dbapiprovider.ArrayConverter): + array_types = { + int: ('int', SQLiteIntConverter), + unicode: ('text', dbapiprovider.StrConverter), + 
float: ('real', dbapiprovider.RealConverter) + } + + def dbval2val(converter, dbval, obj=None): + if not dbval: return None + items = json.loads(dbval) + if obj is None: + return items + return TrackedArray(obj, converter.attr, items) + + def val2dbval(converter, val, obj=None): + return dumps(val) + +class LocalExceptions(localbase): + def __init__(self): + self.exc_info = None + self.keep_traceback = False + +local_exceptions = LocalExceptions() + +def keep_exception(func): + @wraps(func) + def new_func(*args): + local_exceptions.exc_info = None + try: + return func(*args) + except Exception: + local_exceptions.exc_info = sys.exc_info() + if not local_exceptions.keep_traceback: + local_exceptions.exc_info = local_exceptions.exc_info[:2] + (None,) + raise + finally: + local_exceptions.keep_traceback = False + return new_func + + class SQLiteProvider(DBAPIProvider): dialect = 'SQLite' + local_exceptions = local_exceptions max_name_len = 1024 - select_for_update_nowait_syntax = False dbapi_module = sqlite dbschema_cls = SQLiteSchema translator_cls = SQLiteTranslator sqlbuilder_cls = SQLiteBuilder + array_converter_cls = SQLiteArrayConverter name_before_table = 'db_name' @@ -201,18 +320,40 @@ class SQLiteProvider(DBAPIProvider): (time, SQLiteTimeConverter), (timedelta, SQLiteTimedeltaConverter), (UUID, dbapiprovider.UuidConverter), - (buffer, dbapiprovider.BlobConverter), - ] + (buffer, dbapiprovider.BlobConverter), + (Json, SQLiteJsonConverter) + ] def __init__(provider, *args, **kwargs): DBAPIProvider.__init__(provider, *args, **kwargs) + provider.pre_transaction_lock = Lock() provider.transaction_lock = Lock() + @wrap_dbapi_exceptions + def inspect_connection(provider, conn): + DBAPIProvider.inspect_connection(provider, conn) + provider.json1_available = provider.check_json1(conn) + + def restore_exception(provider): + if provider.local_exceptions.exc_info is not None: + try: reraise(*provider.local_exceptions.exc_info) + finally: provider.local_exceptions.exc_info = None + + def acquire_lock(provider): + provider.pre_transaction_lock.acquire() + try: + provider.transaction_lock.acquire() + finally: + provider.pre_transaction_lock.release() + + def release_lock(provider): + provider.transaction_lock.release() + @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction if cache.immediate: - provider.transaction_lock.acquire() + provider.acquire_lock() try: cursor = connection.cursor() @@ -223,38 +364,47 @@ def set_transaction_mode(provider, connection, cache): if fk is not None: fk = fk[0] if fk: sql = 'PRAGMA foreign_keys = false' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.saved_fk_state = bool(fk) assert cache.immediate if cache.immediate: sql = 'BEGIN IMMEDIATE TRANSACTION' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.in_transaction = True - elif core.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') + elif core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') finally: if cache.immediate and not cache.in_transaction: - provider.transaction_lock.release() + provider.release_lock() def commit(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction - DBAPIProvider.commit(provider, connection, cache) - if in_transaction: - provider.transaction_lock.release() + try: + DBAPIProvider.commit(provider, connection, cache) + finally: + if in_transaction: + cache.in_transaction = False + provider.release_lock() def 
rollback(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction - DBAPIProvider.rollback(provider, connection, cache) - if in_transaction: - provider.transaction_lock.release() + try: + DBAPIProvider.rollback(provider, connection, cache) + finally: + if in_transaction: + cache.in_transaction = False + provider.release_lock() def drop(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction - DBAPIProvider.drop(provider, connection, cache) - if in_transaction: - provider.transaction_lock.release() + try: + DBAPIProvider.drop(provider, connection, cache) + finally: + if in_transaction: + cache.in_transaction = False + provider.release_lock() @wrap_dbapi_exceptions def release(provider, connection, cache=None): @@ -264,14 +414,14 @@ def release(provider, connection, cache=None): try: cursor = connection.cursor() sql = 'PRAGMA foreign_keys = true' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) except: provider.pool.drop(connection) raise DBAPIProvider.release(provider, connection, cache) - def get_pool(provider, filename, create_db=False): + def get_pool(provider, filename, create_db=False, **kwargs): if filename != ':memory:': # When relative filename is specified, it is considered # not relative to cwd, but to user module where @@ -286,8 +436,8 @@ def get_pool(provider, filename, create_db=False): # 2 - pony.dbapiprovider.DBAPIProvider.__init__() # 1 - SQLiteProvider.__init__() # 0 - pony.dbproviders.sqlite.get_pool() - filename = absolutize_path(filename, frame_depth=7) - return SQLitePool(filename, create_db) + filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5) + return SQLitePool(filename, create_db, **kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): return provider._exists(connection, table_name, None, case_sensitive) @@ -317,6 +467,16 @@ def _exists(provider, connection, table_name, index_name=None, case_sensitive=Tr def fk_exists(provider, connection, table_name, fk_name): assert False # pragma: no cover + def check_json1(provider, connection): + cursor = connection.cursor() + sql = ''' + select json('{"this": "is", "a": ["test"]}')''' + try: + cursor.execute(sql) + return True + except sqlite.OperationalError: + return False + provider_cls = SQLiteProvider def _text_factory(s): @@ -340,23 +500,177 @@ def func(value): py_upper = make_string_function('py_upper', unicode.upper) py_lower = make_string_function('py_lower', unicode.lower) +def py_json_unwrap(value): + # [null,some-value] -> some-value + if value is None: + return None + assert value.startswith('[null,'), value + return value[6:-1] + +path_cache = {} + +json_path_re = re.compile(r'\[(-?\d+)\]|\.(?:(\w+)|"([^"]*)")', re.UNICODE) + +def _parse_path(path): + if path in path_cache: + return path_cache[path] + keys = None + if isinstance(path, basestring) and path.startswith('$'): + keys = [] + pos = 1 + path_len = len(path) + while pos < path_len: + match = json_path_re.match(path, pos) + if match is not None: + g1, g2, g3 = match.groups() + keys.append(int(g1) if g1 else g2 or g3) + pos = match.end() + else: + keys = None + break + else: keys = tuple(keys) + path_cache[path] = keys + return keys + +def _traverse(obj, keys): + if keys is None: return None + list_or_dict = (list, dict) + for key in keys: + if type(obj) not in list_or_dict: return None + try: obj = obj[key] + except (KeyError, IndexError): return None + return obj + +def _extract(expr, *paths): + 
expr = json.loads(expr) if isinstance(expr, basestring) else expr + result = [] + for path in paths: + keys = _parse_path(path) + result.append(_traverse(expr, keys)) + return result[0] if len(paths) == 1 else result + +def py_json_extract(expr, *paths): + result = _extract(expr, *paths) + if type(result) in (list, dict): + result = json.dumps(result, **SQLiteJsonConverter.json_kwargs) + return result + +def py_json_query(expr, path, with_wrapper): + result = _extract(expr, path) + if type(result) not in (list, dict): + if not with_wrapper: return None + result = [result] + return json.dumps(result, **SQLiteJsonConverter.json_kwargs) + +def py_json_value(expr, path): + result = _extract(expr, path) + return result if type(result) not in (list, dict) else None + +def py_json_contains(expr, path, key): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + keys = _parse_path(path) + expr = _traverse(expr, keys) + return type(expr) in (list, dict) and key in expr + +def py_json_nonzero(expr, path): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + keys = _parse_path(path) + expr = _traverse(expr, keys) + return bool(expr) + +def py_json_array_length(expr, path=None): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + if path: + keys = _parse_path(path) + expr = _traverse(expr, keys) + return len(expr) if type(expr) is list else 0 + +def wrap_array_func(func): + @wraps(func) + def new_func(array, *args): + if array is None: + return None + array = json.loads(array) + return func(array, *args) + return new_func + +@wrap_array_func +def py_array_index(array, index): + try: + return array[index] + except IndexError: + return None + +@wrap_array_func +def py_array_contains(array, item): + return item in array + +@wrap_array_func +def py_array_subset(array, items): + if items is None: return None + items = json.loads(items) + return set(items).issubset(set(array)) + +@wrap_array_func +def py_array_length(array): + return len(array) + +@wrap_array_func +def py_array_slice(array, start, stop): + return dumps(array[start:stop]) + +def py_make_array(*items): + return dumps(items) + +def py_string_slice(s, start, end): + if s is None: + return None + if isinstance(start, basestring): + start = int(start) + if isinstance(end, basestring): + end = int(end) + return s[start:end] + class SQLitePool(Pool): - def __init__(pool, filename, create_db): # called separately in each thread + def __init__(pool, filename, create_db, **kwargs): # called separately in each thread pool.filename = filename pool.create_db = create_db + pool.kwargs = kwargs pool.con = None def _connect(pool): filename = pool.filename if filename != ':memory:' and not pool.create_db and not os.path.exists(filename): throw(IOError, "Database file is not found: %r" % filename) - pool.con = con = sqlite.connect(filename, isolation_level=None) + pool.con = con = sqlite.connect(filename, isolation_level=None, **pool.kwargs) con.text_factory = _text_factory - con.create_function('power', 2, pow) - con.create_function('rand', 0, random) - con.create_function('py_upper', 1, py_upper) - con.create_function('py_lower', 1, py_lower) + + def create_function(name, num_params, func): + func = keep_exception(func) + con.create_function(name, num_params, func) + + create_function('power', 2, pow) + create_function('rand', 0, random) + create_function('py_upper', 1, py_upper) + create_function('py_lower', 1, py_lower) + create_function('py_json_unwrap', 1, py_json_unwrap) + 
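# --- Illustrative sketch, not part of the patch (plain sqlite3, no Pony) ---
# Why: it shows the problem the keep_exception() wrapper above works around:
# sqlite3 turns any exception raised inside a user-defined function into a
# generic OperationalError, so the real exc_info is stashed in a thread-local
# and re-raised later via restore_exception().
import sqlite3

def bad(x):
    raise ValueError('details that would otherwise be lost: %r' % x)

con = sqlite3.connect(':memory:')
con.create_function('bad', 1, bad)
try:
    con.execute('select bad(1)').fetchone()
except sqlite3.OperationalError as exc:
    print(exc)  # "user-defined function raised exception" - the ValueError is gone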
create_function('py_json_extract', -1, py_json_extract) + create_function('py_json_contains', 3, py_json_contains) + create_function('py_json_nonzero', 2, py_json_nonzero) + create_function('py_json_array_length', -1, py_json_array_length) + + create_function('py_array_index', 2, py_array_index) + create_function('py_array_contains', 2, py_array_contains) + create_function('py_array_subset', 2, py_array_subset) + create_function('py_array_length', 1, py_array_length) + create_function('py_array_slice', 3, py_array_slice) + create_function('py_make_array', -1, py_make_array) + + create_function('py_string_slice', 3, py_string_slice) + if sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') + + con.execute('PRAGMA case_sensitive_like = true') def disconnect(pool): if pool.filename != ':memory:': Pool.disconnect(pool) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index a03a8bef7..9124a52b3 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import itervalues, basestring +from pony.py23compat import itervalues, basestring, int_types from operator import attrgetter from pony.orm import core -from pony.orm.core import log_sql, DBSchemaError +from pony.orm.core import log_sql, DBSchemaError, MappingError from pony.utils import throw class DBSchema(object): @@ -26,12 +26,13 @@ def case(schema, s): if schema.uppercase: return s.upper().replace('%S', '%s') \ .replace(')S', ')s').replace('%R', '%r').replace(')R', ')r') else: return s.lower() - def add_table(schema, table_name): - return schema.table_class(table_name, schema) + def add_table(schema, table_name, entity=None): + return schema.table_class(table_name, schema, entity) def order_tables_to_create(schema): tables = [] created_tables = set() - tables_to_create = sorted(itervalues(schema.tables), key=lambda table: table.name) + split = schema.provider.split_table_name + tables_to_create = sorted(itervalues(schema.tables), key=lambda table: split(table.name)) while tables_to_create: for table in tables_to_create: if table.parent_tables.issubset(created_tables): @@ -52,9 +53,10 @@ def create_tables(schema, provider, connection): created_tables = set() for table in schema.order_tables_to_create(): for db_object in table.get_objects_to_create(created_tables): + base_name = provider.base_name(db_object.name) name = db_object.exists(provider, connection, case_sensitive=False) if name is None: db_object.create(provider, connection) - elif name != db_object.name: + elif name != base_name: quote_name = schema.provider.quote_name n1, n2 = quote_name(db_object.name), quote_name(name) tn1, tn2 = db_object.typename, db_object.typename.lower() @@ -63,29 +65,28 @@ def create_tables(schema, provider, connection): 'Try to delete %s %s first.' 
% (tn1, n1, tn2, n2, n2, tn2)) def check_tables(schema, provider, connection): cursor = connection.cursor() - for table in sorted(itervalues(schema.tables), key=lambda table: table.name): - if isinstance(table.name, tuple): alias = table.name[-1] - elif isinstance(table.name, basestring): alias = table.name - else: assert False # pragma: no cover + split = provider.split_table_name + for table in sorted(itervalues(schema.tables), key=lambda table: split(table.name)): + alias = provider.base_name(table.name) sql_ast = [ 'SELECT', [ 'ALL', ] + [ [ 'COLUMN', alias, column.name ] for column in table.column_list ], [ 'FROM', [ alias, 'TABLE', table.name ] ], [ 'WHERE', [ 'EQ', [ 'VALUE', 0 ], [ 'VALUE', 1 ] ] ] ] sql, adapter = provider.ast2sql(sql_ast) - if core.debug: log_sql(sql) + if core.local.debug: log_sql(sql) provider.execute(cursor, sql) class DBObject(object): def create(table, provider, connection): sql = table.get_create_command() - if core.debug: log_sql(sql) + if core.local.debug: log_sql(sql) cursor = connection.cursor() provider.execute(cursor, sql) class Table(DBObject): typename = 'Table' - def __init__(table, name, schema): + def __init__(table, name, schema, entity=None): if name in schema.tables: throw(DBSchemaError, "Table %r already exists in database schema" % name) if name in schema.names: @@ -102,12 +103,21 @@ def __init__(table, name, schema): table.parent_tables = set() table.child_tables = set() table.entities = set() + table.options = {} + if entity is not None: + table.entities.add(entity) + table.options = entity._table_options_ table.m2m = set() def __repr__(table): - table_name = table.name - if isinstance(table_name, tuple): - table_name = '.'.join(table_name) - return '' % table_name + return '' % table.schema.provider.format_table_name(table.name) + def add_entity(table, entity): + for e in table.entities: + if e._root_ is not entity._root_: + throw(MappingError, "Entities %s and %s cannot be mapped to table %s " + "because they don't belong to the same hierarchy" + % (e, entity, table.name)) + assert '_table_options_' not in entity.__dict__ + table.entities.add(entity) def exists(table, provider, connection, case_sensitive=True): return provider.table_exists(connection, table.name, case_sensitive) def get_create_command(table): @@ -132,11 +142,31 @@ def get_create_command(table): for foreign_key in sorted(itervalues(table.foreign_keys), key=lambda fk: fk.name): if schema.inline_fk_syntax and len(foreign_key.child_columns) == 1: continue cmd.append(schema.indent+foreign_key.get_sql() + ',') - cmd[-1] = cmd[-1][:-1] - cmd.append(')') + interleave_fks = [ fk for fk in table.foreign_keys.values() if fk.interleave ] + if interleave_fks: + assert len(interleave_fks) == 1 + fk = interleave_fks[0] + cmd.append(schema.indent+fk.get_sql()) + cmd.append(case(') INTERLEAVE IN PARENT %s (%s)') % ( + quote_name(fk.parent_table.name), + ', '.join(quote_name(col.name) for col in fk.child_columns) + )) + else: + cmd[-1] = cmd[-1][:-1] + cmd.append(')') + for name, value in sorted(table.options.items()): + option = table.format_option(name, value) + if option: cmd.append(option) return '\n'.join(cmd) + def format_option(table, name, value): + if value is True: + return name + if value is False: + return None + return '%s %s' % (name, value) def get_objects_to_create(table, created_tables=None): if created_tables is None: created_tables = set() + created_tables.add(table) result = [ table ] indexes = [ index for index in itervalues(table.indexes) if not index.is_pk and not 
index.is_unique ] for index in indexes: assert index.name is not None @@ -152,7 +182,6 @@ def get_objects_to_create(table, created_tables=None): for foreign_key in sorted(itervalues(child_table.foreign_keys), key=lambda fk: fk.name): if foreign_key.parent_table is not table: continue result.append(foreign_key) - created_tables.add(table) return result def add_column(table, column_name, sql_type, converter, is_not_null=None, sql_default=None): return table.schema.column_class(column_name, table, sql_type, converter, is_not_null, sql_default) @@ -167,12 +196,14 @@ def add_index(table, index_name, columns, is_pk=False, is_unique=None, m2m=False if index and index.name == index_name and index.is_pk == is_pk and index.is_unique == is_unique: return index return table.schema.index_class(index_name, table, columns, is_pk, is_unique) - def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None): + def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None, on_delete=False, + interleave=False): if fk_name is None: provider = table.schema.provider child_column_names = tuple(column.name for column in child_columns) fk_name = provider.get_default_fk_name(table.name, parent_table.name, child_column_names) - return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name) + return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name, on_delete, + interleave=interleave) class Column(object): auto_template = '%(type)s PRIMARY KEY AUTOINCREMENT' @@ -200,19 +231,24 @@ def get_sql(column): result = [] append = result.append append(quote_name(column.name)) - if column.is_pk == 'auto' and column.auto_template: + + def add_default(): + if column.sql_default not in (None, True, False): + append(case('DEFAULT')) + append(column.sql_default) + + if column.is_pk == 'auto' and column.auto_template and column.converter.py_type in int_types: append(case(column.auto_template % dict(type=column.sql_type))) + add_default() else: append(case(column.sql_type)) + add_default() if column.is_pk: if schema.dialect == 'SQLite': append(case('NOT NULL')) append(case('PRIMARY KEY')) else: if column.is_unique: append(case('UNIQUE')) if column.is_not_null: append(case('NOT NULL')) - if column.sql_default not in (None, True, False): - append(case('DEFAULT')) - append(column.sql_default) if schema.inline_fk_syntax and not schema.named_foreign_keys: foreign_key = table.foreign_keys.get((column,)) if foreign_key is not None: @@ -220,6 +256,8 @@ def get_sql(column): append(case('REFERENCES')) append(quote_name(parent_table.name)) append(schema.column_list(foreign_key.parent_columns)) + if foreign_key.on_delete: + append('ON DELETE %s' % foreign_key.on_delete) return ' '.join(result) class Constraint(DBObject): @@ -257,9 +295,9 @@ def __init__(index, name, table, columns, is_pk=False, is_unique=None): throw(DBSchemaError, 'Index %s cannot be created, name is already in use' % name) Constraint.__init__(index, name, schema) for column in columns: - column.is_pk = len(columns) == 1 and is_pk - column.is_pk_part = bool(is_pk) - column.is_unique = is_unique and len(columns) == 1 + column.is_pk = column.is_pk or (len(columns) == 1 and is_pk) + column.is_pk_part = column.is_pk_part or bool(is_pk) + column.is_unique = column.is_unique or (is_unique and len(columns) == 1) table.indexes[columns] = index index.table = table index.columns = columns @@ -288,6 +326,9 @@ def _get_create_sql(index, 
inside_table): append(quote_name(index.name)) append(case('ON')) append(quote_name(index.table.name)) + converter = index.columns[0].converter + if isinstance(converter.py_type, core.Array) and converter.provider.dialect == 'PostgreSQL': + append(case('USING GIN')) else: if index.name: append(case('CONSTRAINT')) @@ -300,7 +341,8 @@ def _get_create_sql(index, inside_table): class ForeignKey(Constraint): typename = 'Foreign key' - def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name): + def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name, on_delete, + interleave=False): schema = parent_table.schema if schema is not child_table.schema: throw(DBSchemaError, 'Parent and child tables of foreign_key cannot belong to different schemata') @@ -326,13 +368,15 @@ def __init__(foreign_key, name, child_table, child_columns, parent_table, parent foreign_key.parent_columns = parent_columns foreign_key.child_table = child_table foreign_key.child_columns = child_columns + foreign_key.on_delete = on_delete + foreign_key.interleave = interleave if index_name is not False: child_columns_len = len(child_columns) - for columns in child_table.indexes: - if columns[:child_columns_len] == child_columns: break - else: child_table.add_index(index_name, child_columns, is_pk=False, - is_unique=False, m2m=bool(child_table.m2m)) + if all(columns[:child_columns_len] != child_columns for columns in child_table.indexes): + child_table.add_index(index_name, child_columns, is_pk=False, + is_unique=False, m2m=bool(child_table.m2m)) + def exists(foreign_key, provider, connection, case_sensitive=True): return provider.fk_exists(connection, foreign_key.child_table.name, foreign_key.name, case_sensitive) def get_sql(foreign_key): @@ -357,6 +401,8 @@ def _get_create_sql(foreign_key, inside_table): append(case('REFERENCES')) append(quote_name(foreign_key.parent_table.name)) append(schema.column_list(foreign_key.parent_columns)) + if foreign_key.on_delete: + append(case('ON DELETE %s' % foreign_key.on_delete)) return ' '.join(cmd) DBSchema.table_class = Table diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index d3927e4bc..068714754 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -1,20 +1,22 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, izip, xrange +from pony.py23compat import PY2, izip, xrange, PY37, PYPY -import types +import sys, types, inspect from opcode import opname as opnames, HAVE_ARGUMENT, EXTENDED_ARG, cmp_op from opcode import hasconst, hasname, hasjrel, haslocal, hascompare, hasfree +from collections import defaultdict from pony.thirdparty.compiler import ast, parse -from pony.utils import throw +from pony.utils import throw, get_codeobject_id ##ast.And.__repr__ = lambda self: "And(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),) ##ast.Or.__repr__ = lambda self: "Or(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),) -ast_cache = {} +class DecompileError(NotImplementedError): + pass -codeobjects = {} +ast_cache = {} def decompile(x): cells = {} @@ -28,10 +30,9 @@ def decompile(x): else: if x.__closure__: cells = dict(izip(codeobject.co_freevars, x.__closure__)) else: throw(TypeError) - key = id(codeobject) + key = get_codeobject_id(codeobject) result = ast_cache.get(key) if result is None: - codeobjects[key] = codeobject decompiler = Decompiler(codeobject) result = decompiler.ast, decompiler.external_names 
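# --- Illustrative sketch, not part of the patch ---
# Why: decompile() above is the entry point that turns the bytecode of a generator
# passed to select() back into an AST for translation to SQL, cached per code object.
# Assumption: the (ast, external_names, cells) result shape of the patched function.
from pony.orm.decompiling import decompile

gen = (x for x in [] if x > 0)
tree, external_names, cells = decompile(gen)
# 'tree' is a pony.thirdparty.compiler AST of "x for x in .0 if x > 0";
# 'external_names' holds names used but not bound inside the generator (empty here).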
ast_cache[key] = result @@ -50,8 +51,6 @@ def simplify(clause): class InvalidQuery(Exception): pass -class AstGenerated(Exception): pass - def binop(node_type, args_holder=tuple): def method(decompiler): oper2 = decompiler.stack.pop() @@ -68,48 +67,114 @@ def __init__(decompiler, code, start=0, end=None): if end is None: end = len(code.co_code) decompiler.end = end decompiler.stack = [] + decompiler.jump_map = defaultdict(list) decompiler.targets = {} decompiler.ast = None decompiler.names = set() decompiler.assnames = set() + decompiler.conditions_end = 0 + decompiler.instructions = [] + decompiler.instructions_map = {} + decompiler.or_jumps = set() + decompiler.get_instructions() + decompiler.analyze_jumps() decompiler.decompile() decompiler.ast = decompiler.stack.pop() - decompiler.external_names = set(decompiler.names - decompiler.assnames) + decompiler.external_names = decompiler.names - decompiler.assnames assert not decompiler.stack, decompiler.stack - def decompile(decompiler): + def get_instructions(decompiler): + PY36 = sys.version_info >= (3, 6) + before_yield = True code = decompiler.code co_code = code.co_code free = code.co_cellvars + code.co_freevars - try: - while decompiler.pos < decompiler.end: - i = decompiler.pos - if i in decompiler.targets: decompiler.process_target(i) - op = ord(code.co_code[i]) + decompiler.abs_jump_to_top = decompiler.for_iter_pos = -1 + while decompiler.pos < decompiler.end: + i = decompiler.pos + op = ord(code.co_code[i]) + if PY36: + extended_arg = 0 + oparg = ord(code.co_code[i+1]) + while op == EXTENDED_ARG: + extended_arg = (extended_arg | oparg) << 8 + i += 2 + op = ord(code.co_code[i]) + oparg = ord(code.co_code[i+1]) + oparg = None if op < HAVE_ARGUMENT else oparg | extended_arg + i += 2 + else: i += 1 if op >= HAVE_ARGUMENT: - oparg = ord(co_code[i]) + ord(co_code[i+1])*256 + oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 i += 2 if op == EXTENDED_ARG: op = ord(code.co_code[i]) i += 1 - oparg = ord(co_code[i]) + ord(co_code[i+1])*256 + oparg*65536 + oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 + oparg * 65536 i += 2 - if op in hasconst: arg = [code.co_consts[oparg]] - elif op in hasname: arg = [code.co_names[oparg]] - elif op in hasjrel: arg = [i + oparg] - elif op in haslocal: arg = [code.co_varnames[oparg]] - elif op in hascompare: arg = [cmp_op[oparg]] - elif op in hasfree: arg = [free[oparg]] - else: arg = [oparg] - else: arg = [] - opname = opnames[op].replace('+', '_') - # print(opname, arg, decompiler.stack) - method = getattr(decompiler, opname, None) - if method is None: throw(NotImplementedError('Unsupported operation: %s' % opname)) - decompiler.pos = i - x = method(*arg) - if x is not None: decompiler.stack.append(x) - except AstGenerated: pass + if op >= HAVE_ARGUMENT: + if op in hasconst: arg = [code.co_consts[oparg]] + elif op in hasname: arg = [code.co_names[oparg]] + elif op in hasjrel: arg = [i + oparg] + elif op in haslocal: arg = [code.co_varnames[oparg]] + elif op in hascompare: arg = [cmp_op[oparg]] + elif op in hasfree: arg = [free[oparg]] + else: arg = [oparg] + else: arg = [] + opname = opnames[op].replace('+', '_') + if opname == 'FOR_ITER': + decompiler.for_iter_pos = decompiler.pos + if opname == 'JUMP_ABSOLUTE' and arg[0] == decompiler.for_iter_pos: + decompiler.abs_jump_to_top = decompiler.pos + + if before_yield: + if 'JUMP' in opname: + endpos = arg[0] + if endpos < decompiler.pos: + decompiler.conditions_end = i + decompiler.jump_map[endpos].append(decompiler.pos) + 
decompiler.instructions_map[decompiler.pos] = len(decompiler.instructions) + decompiler.instructions.append((decompiler.pos, i, opname, arg)) + if opname == 'YIELD_VALUE': + before_yield = False + decompiler.pos = i + def analyze_jumps(decompiler): + if PYPY: + targets = decompiler.jump_map.pop(decompiler.abs_jump_to_top, []) + decompiler.jump_map[decompiler.for_iter_pos] = targets + for i, (x, y, opname, arg) in enumerate(decompiler.instructions): + if 'JUMP' in opname: + target = arg[0] + if target == decompiler.abs_jump_to_top: + decompiler.instructions[i] = (x, y, opname, [decompiler.for_iter_pos]) + decompiler.conditions_end = y + + i = decompiler.instructions_map[decompiler.conditions_end] + while i > 0: + pos, next_pos, opname, arg = decompiler.instructions[i] + if pos in decompiler.jump_map: + for jump_start_pos in decompiler.jump_map[pos]: + if jump_start_pos > pos: + continue + for or_jump_start_pos in decompiler.or_jumps: + if pos > or_jump_start_pos > jump_start_pos: + break # And jump + else: + decompiler.or_jumps.add(jump_start_pos) + i -= 1 + def decompile(decompiler): + for pos, next_pos, opname, arg in decompiler.instructions: + if pos in decompiler.targets: + decompiler.process_target(pos) + method = getattr(decompiler, opname, None) + if method is None: + throw(DecompileError('Unsupported operation: %s' % opname)) + decompiler.pos = pos + decompiler.next_pos = next_pos + x = method(*arg) + if x is not None: + decompiler.stack.append(x) + def pop_items(decompiler, size): if not size: return () result = decompiler.stack[-size:] @@ -144,15 +209,34 @@ def store(decompiler, node): def BINARY_SUBSCR(decompiler): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() - if isinstance(oper2, ast.Tuple): return ast.Subscript(oper1, 'OP_APPLY', list(oper2.nodes)) - else: return ast.Subscript(oper1, 'OP_APPLY', [ oper2 ]) + if isinstance(oper2, ast.Sliceobj) and len(oper2.nodes) == 2: + a, b = oper2.nodes + a = None if isinstance(a, ast.Const) and a.value == None else a + b = None if isinstance(b, ast.Const) and b.value == None else b + return ast.Slice(oper1, 'OP_APPLY', a, b) + elif isinstance(oper2, ast.Tuple): + return ast.Subscript(oper1, 'OP_APPLY', list(oper2.nodes)) + else: + return ast.Subscript(oper1, 'OP_APPLY', [ oper2 ]) + + def BUILD_CONST_KEY_MAP(decompiler, length): + keys = decompiler.stack.pop() + assert isinstance(keys, ast.Const) + keys = [ ast.Const(key) for key in keys.value ] + values = decompiler.pop_items(length) + pairs = list(izip(keys, values)) + return ast.Dict(pairs) def BUILD_LIST(decompiler, size): return ast.List(decompiler.pop_items(size)) - def BUILD_MAP(decompiler, not_used): - # Pushes a new empty dictionary object onto the stack. The argument is ignored and set to zero by the compiler - return ast.Dict(()) + def BUILD_MAP(decompiler, length): + if sys.version_info < (3, 5): + return ast.Dict(()) + data = decompiler.pop_items(2 * length) # [key1, value1, key2, value2, ...] + it = iter(data) + pairs = list(izip(it, it)) # [(key1, value1), (key2, value2), ...] 
+ return ast.Dict(tuple(pairs)) def BUILD_SET(decompiler, size): return ast.Set(decompiler.pop_items(size)) @@ -163,6 +247,10 @@ def BUILD_SLICE(decompiler, size): def BUILD_TUPLE(decompiler, size): return ast.Tuple(decompiler.pop_items(size)) + def BUILD_STRING(decompiler, count): + values = list(reversed([decompiler.stack.pop() for _ in range(count)])) + return ast.JoinedStr(values) + def CALL_FUNCTION(decompiler, argc, star=None, star2=None): pop = decompiler.stack.pop kwarg, posarg = divmod(argc, 256) @@ -173,7 +261,10 @@ def CALL_FUNCTION(decompiler, argc, star=None, star2=None): args.append(ast.Keyword(key, arg)) for i in xrange(posarg): args.append(pop()) args.reverse() - tos = pop() + return decompiler._call_function(args, star, star2) + + def _call_function(decompiler, args, star=None, star2=None): + tos = decompiler.stack.pop() if isinstance(tos, ast.GenExpr): assert len(args) == 1 and star is None and star2 is None genexpr = tos @@ -188,13 +279,49 @@ def CALL_FUNCTION_VAR(decompiler, argc): return decompiler.CALL_FUNCTION(argc, decompiler.stack.pop()) def CALL_FUNCTION_KW(decompiler, argc): - return decompiler.CALL_FUNCTION(argc, None, decompiler.stack.pop()) + if sys.version_info < (3, 6): + return decompiler.CALL_FUNCTION(argc, star2=decompiler.stack.pop()) + keys = decompiler.stack.pop() + assert isinstance(keys, ast.Const) + keys = keys.value + values = decompiler.pop_items(argc) + assert len(keys) <= len(values) + args = values[:-len(keys)] + for key, value in izip(keys, values[-len(keys):]): + args.append(ast.Keyword(key, value)) + return decompiler._call_function(args) def CALL_FUNCTION_VAR_KW(decompiler, argc): star2 = decompiler.stack.pop() star = decompiler.stack.pop() return decompiler.CALL_FUNCTION(argc, star, star2) + def CALL_FUNCTION_EX(decompiler, argc): + star2 = None + if argc: + if argc != 1: throw(DecompileError) + star2 = decompiler.stack.pop() + star = decompiler.stack.pop() + return decompiler._call_function([], star, star2) + + def CALL_METHOD(decompiler, argc): + pop = decompiler.stack.pop + args = [] + if argc >= 256: + kwargc = argc // 256 + argc = argc % 256 + for i in range(kwargc): + v = pop() + k = pop() + assert isinstance(k, ast.Const) + k = k.value # ast.Name(k.value) + args.append(ast.Keyword(k, v)) + for i in range(argc): + args.append(pop()) + args.reverse() + method = pop() + return ast.CallFunc(method, args) + def COMPARE_OP(decompiler, op): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() @@ -209,22 +336,61 @@ def FOR_ITER(decompiler, endpos): ifs = [] return ast.GenExprFor(assign, iter, ifs) + def FORMAT_VALUE(decompiler, flags): + if flags in (0, 1, 2, 3): + value = decompiler.stack.pop() + return ast.Str(value, flags) + elif flags == 4: + fmt_spec = decompiler.stack.pop() + value = decompiler.stack.pop() + return ast.FormattedValue(value, fmt_spec) + def GET_ITER(decompiler): pass def JUMP_IF_FALSE(decompiler, endpos): - return decompiler.conditional_jump(endpos, ast.And) + return decompiler.conditional_jump(endpos, False) JUMP_IF_FALSE_OR_POP = JUMP_IF_FALSE def JUMP_IF_TRUE(decompiler, endpos): - return decompiler.conditional_jump(endpos, ast.Or) + return decompiler.conditional_jump(endpos, True) JUMP_IF_TRUE_OR_POP = JUMP_IF_TRUE - def conditional_jump(decompiler, endpos, clausetype): - i = decompiler.pos # next instruction - if i in decompiler.targets: decompiler.process_target(i) + def conditional_jump(decompiler, endpos, if_true): + if PY37 or PYPY: + return decompiler.conditional_jump_new(endpos, if_true) + return 
decompiler.conditional_jump_old(endpos, if_true) + + def conditional_jump_old(decompiler, endpos, if_true): + i = decompiler.next_pos + if i in decompiler.targets: + decompiler.process_target(i) + expr = decompiler.stack.pop() + clausetype = ast.Or if if_true else ast.And + clause = clausetype([expr]) + clause.endpos = endpos + decompiler.targets.setdefault(endpos, clause) + return clause + + def conditional_jump_new(decompiler, endpos, if_true): + expr = decompiler.stack.pop() + if decompiler.pos >= decompiler.conditions_end: + clausetype = ast.Or if if_true else ast.And + elif decompiler.pos in decompiler.or_jumps: + clausetype = ast.Or + if not if_true: + expr = ast.Not(expr) + else: + clausetype = ast.And + if if_true: + expr = ast.Not(expr) + decompiler.stack.append(expr) + + if decompiler.next_pos in decompiler.targets: + decompiler.process_target(decompiler.next_pos) + expr = decompiler.stack.pop() clause = clausetype([ expr ]) clause.endpos = endpos @@ -240,7 +406,7 @@ def process_target(decompiler, pos, partial=False): top = simplify(top) if top is limit: break if isinstance(top, ast.GenExprFor): break - + if not decompiler.stack: break top2 = decompiler.stack[-1] if isinstance(top2, ast.GenExprFor): break if partial and hasattr(top2, 'endpos') and top2.endpos == pos: break @@ -253,13 +419,13 @@ def process_target(decompiler, pos, partial=False): if hasattr(top, 'endpos'): top2.endpos = top.endpos if decompiler.targets.get(top.endpos) is top: decompiler.targets[top.endpos] = top2 - else: throw(NotImplementedError('Expression is too complex to decompile, try to pass query as string, e.g. select("x for x in Something")')) + else: throw(DecompileError('Expression is too complex to decompile, try to pass query as string, e.g. select("x for x in Something")')) top2.endpos = max(top2.endpos, getattr(top, 'endpos', 0)) top = decompiler.stack.pop() decompiler.stack.append(top) def JUMP_FORWARD(decompiler, endpos): - i = decompiler.pos # next instruction + i = decompiler.next_pos # next instruction decompiler.process_target(i, True) then = decompiler.stack.pop() decompiler.process_target(i, False) @@ -270,8 +436,9 @@ def JUMP_FORWARD(decompiler, endpos): if decompiler.targets.get(endpos) is then: decompiler.targets[endpos] = if_exp return if_exp - def LIST_APPEND(decompiler): - throw(NotImplementedError) + def LIST_APPEND(decompiler, offset=None): + throw(InvalidQuery('Use generator expression (... for ... in ...) ' + 'instead of list comprehension [... for ... in ...] 
inside query')) def LOAD_ATTR(decompiler, attr_name): return ast.Getattr(decompiler.stack.pop(), attr_name) @@ -295,6 +462,11 @@ def LOAD_GLOBAL(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) + def LOAD_METHOD(decompiler, methname): + return decompiler.LOAD_ATTR(methname) + + LOOKUP_METHOD = LOAD_METHOD # For PyPy + def LOAD_NAME(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) @@ -305,17 +477,39 @@ def MAKE_CLOSURE(decompiler, argc): return decompiler.MAKE_FUNCTION(argc) def MAKE_FUNCTION(decompiler, argc): - if argc: throw(NotImplementedError) - tos = decompiler.stack.pop() - if not PY2: tos = decompiler.stack.pop() + defaults = [] + flags = 0 + if sys.version_info >= (3, 6): + qualname = decompiler.stack.pop() + tos = decompiler.stack.pop() + if argc & 0x08: + func_closure = decompiler.stack.pop() + if argc & 0x04: + annotations = decompiler.stack.pop() + if argc & 0x02: + kwonly_defaults = decompiler.stack.pop() + if argc & 0x01: + defaults = decompiler.stack.pop() + throw(DecompileError) + else: + if not PY2: + qualname = decompiler.stack.pop() + tos = decompiler.stack.pop() + if argc: + defaults = [ decompiler.stack.pop() for i in range(argc) ] + defaults.reverse() codeobject = tos.value func_decompiler = Decompiler(codeobject) # decompiler.names.update(decompiler.names) ??? if codeobject.co_varnames[:1] == ('.0',): return func_decompiler.ast # generator - argnames = codeobject.co_varnames[:codeobject.co_argcount] - defaults = [] # todo - flags = 0 # todo + argnames, varargs, keywords = inspect.getargs(codeobject) + if varargs: + argnames.append(varargs) + flags |= inspect.CO_VARARGS + if keywords: + argnames.append(keywords) + flags |= inspect.CO_VARKEYWORDS return ast.Lambda(argnames, defaults, flags, func_decompiler.ast) POP_JUMP_IF_FALSE = JUMP_IF_FALSE @@ -325,10 +519,9 @@ def POP_TOP(decompiler): pass def RETURN_VALUE(decompiler): - if decompiler.pos != decompiler.end: throw(NotImplementedError) + if decompiler.next_pos != decompiler.end: throw(DecompileError) expr = decompiler.stack.pop() - decompiler.stack.append(simplify(expr)) - raise AstGenerated() + return simplify(expr) def ROT_TWO(decompiler): tos = decompiler.stack.pop() @@ -375,7 +568,8 @@ def STORE_DEREF(decompiler, freevar): def STORE_FAST(decompiler, varname): if varname.startswith('_['): - throw(InvalidQuery('Use generator expression (... for ... in ...) instead of list comprehension [... for ... in ...] inside query')) + throw(InvalidQuery('Use generator expression (... for ... in ...) ' + 'instead of list comprehension [... for ... in ...] inside query')) decompiler.assnames.add(varname) decompiler.store(ast.AssName(varname, 'OP_ASSIGN')) @@ -429,8 +623,7 @@ def YIELD_VALUE(decompiler): fors.append(top) else: fors.append(top) fors.reverse() - decompiler.stack.append(ast.GenExpr(ast.GenExprInner(simplify(expr), fors))) - raise AstGenerated() + return ast.GenExpr(ast.GenExprInner(simplify(expr), fors)) test_lines = """ (a and b if c and d else e and f for i in T if (A and B if C and D else E and F)) @@ -445,8 +638,13 @@ def YIELD_VALUE(decompiler): (a for b in T if f == 5 and r or t) (a for b in T if f and r and t) - (a for b in T if f == 5 and +r or not t) - (a for b in T if -t and ~r or `f`) + # (a for b in T if f == 5 and +r or not t) + # (a for b in T if -t and ~r or `f`) + + (a for b in T if x and not y and z) + (a for b in T if not x and y) + (a for b in T if not x and y and z) + (a for b in T if not x and y or z) #FIXME! 
(a**2 for b in T if t * r > y / 3) (a + 2 for b in T if t + r > y // 3) @@ -480,10 +678,12 @@ def YIELD_VALUE(decompiler): (s for s in T if s.a > 20 and (s.x.y == 123 or 'ABC' in s.p.q.r)) (a for b in T1 if c > d for e in T2 if f < g) - (func1(a, a.attr, keyarg=123) for s in T) - (func1(a, a.attr, keyarg=123, *e) for s in T) - (func1(a, b, a.attr1, a.b.c, keyarg1=123, keyarg2='mx', *e, **f) for s in T) - (func(a, a.attr, keyarg=123) for a in T if a.method(x, *y, **z) == 4) + (func1(a, a.attr, x=123) for s in T) + # (func1(a, a.attr, *args) for s in T) + # (func1(a, a.attr, x=123, **kwargs) for s in T) + (func1(a, b, a.attr1, a.b.c, x=123, y='foo') for s in T) + # (func1(a, b, a.attr1, a.b.c, x=123, y='foo', **kwargs) for s in T) + # (func(a, a.attr, keyarg=123) for a in T if a.method(x, *args, **kwargs) == 4) ((x or y) and (p or q) for a in T if (a or b) and (c or d)) (x.y for x in T if (a and (b or (c and d))) or X) diff --git a/pony/orm/examples/estore.py b/pony/orm/examples/estore.py index e62d3374d..a7de96a77 100644 --- a/pony/orm/examples/estore.py +++ b/pony/orm/examples/estore.py @@ -229,7 +229,7 @@ def test_queries(): print('Three most valuable customers') print() result = select(c for c in Customer).order_by(lambda c: desc(sum(c.orders.total_price)))[:3] - + print(result) print() @@ -276,7 +276,7 @@ def test_queries(): for customer in Customer for product in customer.orders.items.product for category in product.categories - if count(product) > 1)[:] + if count(product) > 1)[:] print(result) print() diff --git a/pony/orm/examples/university1.py b/pony/orm/examples/university1.py index e5ec79335..b575743d2 100644 --- a/pony/orm/examples/university1.py +++ b/pony/orm/examples/university1.py @@ -42,11 +42,14 @@ class Student(db.Entity): sql_debug(True) # Output all SQL queries to stdout -db.bind('sqlite', 'university1.sqlite', create_db=True) -#db.bind('mysql', host="localhost", user="pony", passwd="pony", db="university1") -#db.bind('postgres', user='pony', password='pony', host='localhost', database='university1') -#db.bind('oracle', 'university1/pony@localhost') -#db.bind('mssqlserver', 'DRIVER={SQL Server};SERVER=mycomputer\SQLEXPRESS;DATABASE=university1') +params = dict( + sqlite=dict(provider='sqlite', filename='university1.sqlite', create_db=True), + mysql=dict(provider='mysql', host="localhost", user="pony", passwd="pony", db="pony"), + postgres=dict(provider='postgres', user='pony', password='pony', host='localhost', database='pony'), + cockroach=dict(provider='cockroach', user='root', host='localhost', port=26257, database='pony', sslmode='disable'), + oracle=dict(provider='oracle', user='c##pony', password='pony', dsn='localhost/orcl') +) +db.bind(**params['sqlite']) db.generate_mapping(create_tables=True) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 050a93a79..080b7f5cf 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -1,12 +1,13 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types +from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types, iteritems -import types +import sys, types, weakref from decimal import Decimal from datetime import date, time, datetime, timedelta +from functools import wraps, WRAPPER_ASSIGNMENTS from uuid import UUID -from pony.utils import throw, parse_expr +from pony.utils import throw, parse_expr, deref_proxy NoneType = type(None) @@ -44,6 +45,8 @@ def __ne__(self, other): 
return type(other) is not FuncType or self.func != other.func def __hash__(self): return hash(self.func) + 1 + def __repr__(self): + return 'FuncType(%s at %d)' % (self.func.__name__, id(self.func)) class MethodType(object): __slots__ = 'obj', 'func' @@ -59,7 +62,7 @@ def __init__(self, method): def __eq__(self, other): return type(other) is MethodType and self.obj == other.obj and self.func == other.func def __ne__(self, other): - return type(other) is not SetType or self.obj != other.obj or self.func != other.func + return type(other) is not MethodType or self.obj != other.obj or self.func != other.func def __hash__(self): return hash(self.obj) ^ hash(self.func) @@ -94,14 +97,18 @@ def parse_raw_sql(sql): raw_sql_cache[sql] = result return result +def raw_sql(sql, result_type=None): + globals = sys._getframe(1).f_globals + locals = sys._getframe(1).f_locals + return RawSQL(sql, globals, locals, result_type) + class RawSQL(object): def __deepcopy__(self, memo): assert False # should not attempt to deepcopy RawSQL instances, because of locals/globals def __init__(self, sql, globals=None, locals=None, result_type=None): self.sql = sql self.items, self.codes = parse_raw_sql(sql) - self.values = tuple(eval(code, globals, locals) for code in self.codes) - self.types = tuple(get_normalized_type_of(value) for value in self.values) + self.types, self.values = normalize(tuple(eval(code, globals, locals) for code in self.codes)) self.result_type = result_type def _get_type_(self): return RawSQLType(self.sql, self.items, self.types, self.result_type) @@ -121,49 +128,88 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) -numeric_types = set([ bool, int, float, Decimal ]) -comparable_types = set([ int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID ]) -primitive_types = comparable_types | set([ buffer ]) -function_types = set([type, types.FunctionType, types.BuiltinFunctionType]) -type_normalization_dict = { long : int } if PY2 else {} +class QueryType(object): + def __init__(self, query, limit=None, offset=None): + self.query_key = query._key + self.translator = query._translator + self.limit = limit + self.offset = offset + def __hash__(self): + result = hash(self.query_key) + if self.limit is not None: + result ^= hash(self.limit + 3) + if self.offset is not None: + result ^= hash(self.offset) + return result + def __eq__(self, other): + return type(other) is QueryType and self.query_key == other.query_key \ + and self.limit == other.limit and self.offset == other.offset + def __ne__(self, other): + return not self.__eq__(other) -def get_normalized_type_of(value): + +def normalize(value): + value = deref_proxy(value) t = type(value) - if t is tuple: return tuple(get_normalized_type_of(item) for item in value) - try: hash(value) # without this, cannot do tests like 'if value in special_fucntions...' 
- except TypeError: throw(TypeError, 'Unsupported type %r' % t.__name__) - if t.__name__ == 'EntityMeta': return SetType(value) - if t.__name__ == 'EntityIter': return SetType(value.entity) + if t is tuple: + item_types, item_values = [], [] + for item in value: + item_type, item_value = normalize(item) + item_values.append(item_value) + item_types.append(item_type) + return tuple(item_types), tuple(item_values) + + if t.__name__ == 'EntityMeta': + return SetType(value), value + + if t.__name__ == 'EntityIter': + entity = value.entity + return SetType(entity), entity + if PY2 and isinstance(value, str): - try: value.decode('ascii') - except UnicodeDecodeError: throw(TypeError, - 'The bytestring %r contains non-ascii symbols. Try to pass unicode string instead' % value) - else: return unicode - elif isinstance(value, unicode): return unicode - if t in function_types: return FuncType(value) - if t is types.MethodType: return MethodType(value) + try: + value.decode('ascii') + except UnicodeDecodeError: + throw(TypeError, 'The bytestring %r contains non-ascii symbols. Try to pass unicode string instead' % value) + else: + return unicode, value + elif isinstance(value, unicode): + return unicode, value + + if t in function_types: + return FuncType(value), value + + if t is types.MethodType: + return MethodType(value), value + if hasattr(value, '_get_type_'): - return value._get_type_() - return normalize_type(t) + return value._get_type_(), value + + return normalize_type(t), value def normalize_type(t): tt = type(t) if tt is tuple: return tuple(normalize_type(item) for item in t) + if not isinstance(t, type): + return t assert t.__name__ != 'EntityMeta' if tt.__name__ == 'EntityMeta': return t if t is NoneType: return t t = type_normalization_dict.get(t, t) if t in primitive_types: return t + if t in (slice, type(Ellipsis)): return t if issubclass(t, basestring): return unicode + if issubclass(t, (dict, Json)): return Json + if issubclass(t, Array): return t throw(TypeError, 'Unsupported type %r' % t.__name__) coercions = { - (int, float) : float, - (int, Decimal) : Decimal, - (date, datetime) : datetime, - (bool, int) : int, - (bool, float) : float, - (bool, Decimal) : Decimal + (int, float): float, + (int, Decimal): Decimal, + (date, datetime): datetime, + (bool, int): int, + (bool, float): float, + (bool, Decimal): Decimal } coercions.update(((t2, t1), t3) for ((t1, t2), t3) in items_list(coercions)) @@ -184,6 +230,10 @@ def are_comparable_types(t1, t2, op='=='): # types must be normalized already! 
tt1 = type(t1) tt2 = type(t2) + + t12 = {t1, t2} + if Json in t12 and t12 < {Json, str, unicode, int, bool, float}: + return True if op in ('in', 'not in'): if tt2 is RawSQLType: return True if tt2 is not SetType: return False @@ -214,3 +264,154 @@ def are_comparable_types(t1, t2, op='=='): return False if t1 is t2 and t1 in comparable_types: return True return (t1, t2) in coercions + +class TrackedValue(object): + def __init__(self, obj, attr): + self.obj_ref = weakref.ref(obj) + self.attr = attr + @classmethod + def make(cls, obj, attr, value): + if isinstance(value, dict): + return TrackedDict(obj, attr, value) + if isinstance(value, list): + return TrackedList(obj, attr, value) + return value + def _changed_(self): + obj = self.obj_ref() + if obj is not None: + obj._attr_changed_(self.attr) + def get_untracked(self): + assert False, 'Abstract method' # pragma: no cover + +def tracked_method(func): + @wraps(func, assigned=('__name__', '__doc__') if PY2 else WRAPPER_ASSIGNMENTS) + def new_func(self, *args, **kwargs): + obj = self.obj_ref() + attr = self.attr + if obj is not None: + args = tuple(TrackedValue.make(obj, attr, arg) for arg in args) + if kwargs: kwargs = {key: TrackedValue.make(obj, attr, value) for key, value in iteritems(kwargs)} + result = func(self, *args, **kwargs) + self._changed_() + return result + return new_func + +class TrackedDict(TrackedValue, dict): + def __init__(self, obj, attr, value): + TrackedValue.__init__(self, obj, attr) + dict.__init__(self, {key: self.make(obj, attr, val) for key, val in iteritems(value)}) + def __reduce__(self): + return dict, (dict(self),) + __setitem__ = tracked_method(dict.__setitem__) + __delitem__ = tracked_method(dict.__delitem__) + _update = tracked_method(dict.update) + def update(self, *args, **kwargs): + args = [ arg if isinstance(arg, dict) else dict(arg) for arg in args ] + return self._update(*args, **kwargs) + setdefault = tracked_method(dict.setdefault) + pop = tracked_method(dict.pop) + popitem = tracked_method(dict.popitem) + clear = tracked_method(dict.clear) + def get_untracked(self): + return {key: val.get_untracked() if isinstance(val, TrackedValue) else val + for key, val in self.items()} + +class TrackedList(TrackedValue, list): + def __init__(self, obj, attr, value): + TrackedValue.__init__(self, obj, attr) + list.__init__(self, (self.make(obj, attr, val) for val in value)) + def __reduce__(self): + return list, (list(self),) + __setitem__ = tracked_method(list.__setitem__) + __delitem__ = tracked_method(list.__delitem__) + extend = tracked_method(list.extend) + append = tracked_method(list.append) + pop = tracked_method(list.pop) + remove = tracked_method(list.remove) + insert = tracked_method(list.insert) + reverse = tracked_method(list.reverse) + sort = tracked_method(list.sort) + if PY2: + __setslice__ = tracked_method(list.__setslice__) + else: + clear = tracked_method(list.clear) + def get_untracked(self): + return [val.get_untracked() if isinstance(val, TrackedValue) else val for val in self] + +def validate_item(item_type, item): + if PY2 and isinstance(item, str): + item = item.decode('ascii') + if not isinstance(item, item_type): + if item_type is not unicode and hasattr(item, '__index__'): + return item.__index__() + throw(TypeError, 'Cannot store %r item in array of %r' % (type(item).__name__, item_type.__name__)) + return item + +class TrackedArray(TrackedList): + def __init__(self, obj, attr, value): + TrackedList.__init__(self, obj, attr, value) + self.item_type = attr.py_type.item_type + def 
extend(self, items): + items = [validate_item(self.item_type, item) for item in items] + TrackedList.extend(self, items) + def append(self, item): + item = validate_item(self.item_type, item) + TrackedList.append(self, item) + def insert(self, index, item): + item = validate_item(self.item_type, item) + TrackedList.insert(self, index, item) + def __setitem__(self, index, item): + item = validate_item(self.item_type, item) + TrackedList.__setitem__(self, index, item) + + def __contains__(self, item): + if not isinstance(item, basestring) and hasattr(item, '__iter__'): + return all(it in set(self) for it in item) + return list.__contains__(self, item) + + +class Json(object): + """A wrapper over a dict or list + """ + @classmethod + def default_empty_value(cls): + return {} + + def __init__(self, wrapped): + self.wrapped = wrapped + + def __repr__(self): + return '<Json %r>' % self.wrapped + +class Array(object): + item_type = None # Should be overridden in subclass + + @classmethod + def default_empty_value(cls): + return [] + + +class IntArray(Array): + item_type = int + + +class StrArray(Array): + item_type = unicode + + +class FloatArray(Array): + item_type = float + + +numeric_types = {bool, int, float, Decimal} +comparable_types = {int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID, IntArray, StrArray, FloatArray} +primitive_types = comparable_types | {buffer} +function_types = {type, types.FunctionType, types.BuiltinFunctionType} +type_normalization_dict = { long : int } if PY2 else {} + +array_types = { + int: IntArray, + float: FloatArray, + unicode: StrArray +} + diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 0b61a8641..671ba5e37 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -1,24 +1,44 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, izip, imap, itervalues, basestring, unicode, buffer +from pony.py23compat import PY2, izip, imap, itervalues, basestring, unicode, buffer, int_types from operator import attrgetter from decimal import Decimal -from datetime import date, datetime +from datetime import date, datetime, timedelta from binascii import hexlify from pony import options -from pony.utils import datetime2timestamp, throw -from pony.orm.ormtypes import RawSQL +from pony.utils import datetime2timestamp, throw, is_ident +from pony.converting import timedelta2str +from pony.orm.ormtypes import RawSQL, Json class AstError(Exception): pass class Param(object): - __slots__ = 'style', 'id', 'paramkey', 'py2sql' - def __init__(param, paramstyle, id, paramkey, converter=None): + __slots__ = 'style', 'id', 'paramkey', 'converter', 'optimistic' + def __init__(param, paramstyle, paramkey, converter=None, optimistic=False): param.style = paramstyle - param.id = id + param.id = None param.paramkey = paramkey - param.py2sql = converter.py2sql if converter else (lambda val: val) + param.converter = converter + param.optimistic = optimistic + def eval(param, values): + varkey, i, j = param.paramkey + value = values[varkey] + if i is not None: + t = type(value) + if t is tuple: value = value[i] + elif t is RawSQL: value = value.values[i] + elif hasattr(value, '_get_items'): value = value._get_items()[i] + else: assert False, t + if j is not None: + assert type(type(value)).__name__ == 'EntityMeta' + value = value._get_raw_pkval_()[j] + converter = param.converter + if value is not None and converter is not None: + if converter.attr is None: + value = converter.val2dbval(value) + value 
= converter.py2sql(value) + return value def __unicode__(param): paramstyle = param.style if paramstyle == 'qmark': return u'?' @@ -31,6 +51,17 @@ def __unicode__(param): def __repr__(param): return '%s(%r)' % (param.__class__.__name__, param.paramkey) +class CompositeParam(Param): + __slots__ = 'items', 'func' + def __init__(param, paramstyle, paramkey, items, func): + for item in items: assert isinstance(item, (Param, Value)), item + Param.__init__(param, paramstyle, paramkey) + param.items = items + param.func = func + def eval(param, values): + args = [ item.eval(values) if isinstance(item, Param) else item.value for item in param.items ] + return param.func(args) + class Value(object): __slots__ = 'paramstyle', 'value' def __init__(self, paramstyle, value): @@ -38,19 +69,31 @@ def __init__(self, paramstyle, value): self.value = value def __unicode__(self): value = self.value - if value is None: return 'null' - if isinstance(value, bool): return value and '1' or '0' - if isinstance(value, basestring): return self.quote_str(value) - if isinstance(value, datetime): return self.quote_str(datetime2timestamp(value)) - if isinstance(value, date): return self.quote_str(str(value)) + if value is None: + return 'null' + if isinstance(value, bool): + return value and '1' or '0' + if isinstance(value, basestring): + return self.quote_str(value) + if isinstance(value, datetime): + return 'TIMESTAMP ' + self.quote_str(datetime2timestamp(value)) + if isinstance(value, date): + return 'DATE ' + self.quote_str(str(value)) + if isinstance(value, timedelta): + return "INTERVAL '%s' HOUR TO SECOND" % timedelta2str(value) if PY2: - if isinstance(value, (int, long, float, Decimal)): return str(value) - if isinstance(value, buffer): return "X'%s'" % hexlify(value) + if isinstance(value, (int, long, float, Decimal)): + return str(value) + if isinstance(value, buffer): + return "X'%s'" % hexlify(value) else: - if isinstance(value, (int, float, Decimal)): return str(value) - if isinstance(value, bytes): return "X'%s'" % hexlify(value).decode('ascii') - assert False, value # pragma: no cover - if not PY2: __str__ = __unicode__ + if isinstance(value, (int, float, Decimal)): + return str(value) + if isinstance(value, bytes): + return "X'%s'" % hexlify(value).decode('ascii') + assert False, repr(value) # pragma: no cover + if not PY2: + __str__ = __unicode__ def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.value) def quote_str(self, s): @@ -67,9 +110,9 @@ def flat(tree): x = stack_pop() if isinstance(x, basestring): result_append(x) else: - try: stack_extend(reversed(x)) + try: stack_extend(x) except TypeError: result_append(x) - return result + return result[::-1] def flat_conditions(conditions): result = [] @@ -127,27 +170,14 @@ def new_method(builder, *args, **kwargs): new_method.__name__ = method.__name__ return new_method -def convert(values, params): - for param in params: - varkey, i, j = param.paramkey - value = values[varkey] - t = type(value) - if i is not None: - if t is tuple: value = value[i] - elif t is RawSQL: value = value.values[i] - else: assert False - if j is not None: - assert type(type(value)).__name__ == 'EntityMeta' - value = value._get_raw_pkval_()[j] - if value is not None: # can value be None at all? 
- value = param.py2sql(value) - yield value - class SQLBuilder(object): dialect = None - make_param = Param - make_value = Value + param_class = Param + composite_param_class = CompositeParam + value_class = Value indent_spaces = " " * 4 + least_func_name = 'least' + greatest_func_name = 'greatest' def __init__(builder, provider, ast): builder.provider = provider builder.quote_name = provider.quote_name @@ -158,22 +188,24 @@ def __init__(builder, provider, ast): builder.inner_join_syntax = options.INNER_JOIN_SYNTAX builder.suppress_aliases = False builder.result = flat(builder(ast)) + params = tuple(x for x in builder.result if isinstance(x, Param)) + layout = [] + for i, param in enumerate(params): + if param.id is None: param.id = i + 1 + layout.append(param.paramkey) + builder.layout = layout builder.sql = u''.join(imap(unicode, builder.result)).rstrip('\n') if paramstyle in ('qmark', 'format'): - params = tuple(x for x in builder.result if isinstance(x, Param)) def adapter(values): - return tuple(convert(values, params)) + return tuple(param.eval(values) for param in params) elif paramstyle == 'numeric': - params = tuple(param for param in sorted(itervalues(builder.keys), key=attrgetter('id'))) def adapter(values): - return tuple(convert(values, params)) + return tuple(param.eval(values) for param in params) elif paramstyle in ('named', 'pyformat'): - params = tuple(param for param in sorted(itervalues(builder.keys), key=attrgetter('id'))) def adapter(values): - return dict(('p%d' % param.id, value) for param, value in izip(params, convert(values, params))) + return {'p%d' % param.id: param.eval(values) for param in params} else: throw(NotImplementedError, paramstyle) builder.params = params - builder.layout = tuple(param.paramkey for param in params) builder.adapter = adapter def __call__(builder, ast): if isinstance(ast, basestring): @@ -217,7 +249,7 @@ def DELETE(builder, alias, from_ast, where=None): if alias is not None: builder.suppress_aliases = True if not where: return 'DELETE ', builder(from_ast) return 'DELETE ', builder(from_ast), builder(where) - def subquery(builder, *sections): + def _subquery(builder, *sections): builder.indent += 1 if not builder.inner_join_syntax: sections = move_conditions_from_inner_join_to_where(sections) @@ -228,19 +260,21 @@ def SELECT(builder, *sections): prev_suppress_aliases = builder.suppress_aliases builder.suppress_aliases = False try: - result = builder.subquery(*sections) + result = builder._subquery(*sections) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' return result finally: builder.suppress_aliases = prev_suppress_aliases - def SELECT_FOR_UPDATE(builder, nowait, *sections): + def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): assert not builder.indent result = builder.SELECT(*sections) - return result, 'FOR UPDATE NOWAIT\n' if nowait else 'FOR UPDATE\n' + nowait = ' NOWAIT' if nowait else '' + skip_locked = ' SKIP LOCKED' if skip_locked else '' + return result, 'FOR UPDATE', nowait, skip_locked, '\n' def EXISTS(builder, *sections): - result = builder.subquery(*sections) + result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent return 'EXISTS (\n', indent, 'SELECT 1\n', result, indent, ')' def NOT_EXISTS(builder, *sections): @@ -340,23 +374,36 @@ def DESC(builder, expr): return builder(expr), ' DESC' @indentable def LIMIT(builder, limit, offset=None): - if not offset: return 'LIMIT ', builder(limit), '\n' - else: return 'LIMIT ', 
builder(limit), ' OFFSET ', builder(offset), '\n' + if limit is None: + limit = 'null' + else: + assert isinstance(limit, int_types) + assert offset is None or isinstance(offset, int) + if offset: + return 'LIMIT %s OFFSET %d\n' % (limit, offset) + else: + return 'LIMIT %s\n' % limit def COLUMN(builder, table_alias, col_name): if builder.suppress_aliases or not table_alias: return [ '%s' % builder.quote_name(col_name) ] return [ '%s.%s' % (builder.quote_name(table_alias), builder.quote_name(col_name)) ] - def PARAM(builder, paramkey, converter=None): + def PARAM(builder, paramkey, converter=None, optimistic=False): + return builder.make_param(builder.param_class, paramkey, converter, optimistic) + def make_param(builder, param_class, paramkey, *args): keys = builder.keys param = keys.get(paramkey) if param is None: - param = Param(builder.paramstyle, len(keys) + 1, paramkey, converter) + param = param_class(builder.paramstyle, paramkey, *args) keys[paramkey] = param - return [ param ] + return param + def make_composite_param(builder, paramkey, items, func): + return builder.make_param(builder.composite_param_class, paramkey, items, func) + def STAR(builder, table_alias): + return builder.quote_name(table_alias), '.*' def ROW(builder, *items): return '(', join(', ', imap(builder, items)), ')' def VALUE(builder, value): - return [ builder.make_value(builder.paramstyle, value) ] + return builder.value_class(builder.paramstyle, value) def AND(builder, *cond_list): cond_list = [ builder(condition) for condition in cond_list ] return join(' AND ', cond_list) @@ -380,6 +427,15 @@ def POW(builder, expr1, expr2): DIV = make_binary_op(' / ', True) FLOORDIV = make_binary_op(' / ', True) + def MOD(builder, a, b): + symbol = ' %% ' if builder.paramstyle in ('format', 'pyformat') else ' % ' + return '(', builder(a), symbol, builder(b), ')' + def FLOAT_EQ(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' + def FLOAT_NE(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14' def CONCAT(builder, *args): return '(', join(' || ', imap(builder, args)), ')' def NEG(builder, expr): @@ -412,24 +468,39 @@ def NOT_IN(builder, expr1, x): return builder(expr1), ' NOT IN ', builder(x) expr_list = [ builder(expr) for expr in x ] return builder(expr1), ' NOT IN (', join(', ', expr_list), ')' - def COUNT(builder, kind, *expr_list): - if kind == 'ALL': + def COUNT(builder, distinct, *expr_list): + assert distinct in (None, True, False) + if not distinct: if not expr_list: return ['COUNT(*)'] - return 'COUNT(', join(', ', imap(builder, expr_list)), ')' - elif kind == 'DISTINCT': - if not expr_list: throw(AstError, 'COUNT(DISTINCT) without argument') - if len(expr_list) == 1: return 'COUNT(DISTINCT ', builder(expr_list[0]), ')' if builder.dialect == 'PostgreSQL': - return 'COUNT(DISTINCT ', builder.ROW(*expr_list), ')' - elif builder.dialect == 'MySQL': - return 'COUNT(DISTINCT ', join(', ', imap(builder, expr_list)), ')' - # Oracle and SQLite queries translated to completely different subquery syntax - else: throw(NotImplementedError) # This line must not be executed - throw(AstError, 'Invalid COUNT kind (must be ALL or DISTINCT)') - def SUM(builder, expr, distinct=False): + return 'COUNT(', builder.ROW(*expr_list), ')' + else: + return 'COUNT(', join(', ', imap(builder, expr_list)), ')' + if not expr_list: throw(AstError, 
'COUNT(DISTINCT) without argument') + if len(expr_list) == 1: + return 'COUNT(DISTINCT ', builder(expr_list[0]), ')' + + if builder.dialect == 'PostgreSQL': + return 'COUNT(DISTINCT ', builder.ROW(*expr_list), ')' + elif builder.dialect == 'MySQL': + return 'COUNT(DISTINCT ', join(', ', imap(builder, expr_list)), ')' + # Oracle and SQLite queries translated to completely different subquery syntax + else: throw(NotImplementedError) # This line must not be executed + def SUM(builder, distinct, expr): + assert distinct in (None, True, False) return distinct and 'coalesce(SUM(DISTINCT ' or 'coalesce(SUM(', builder(expr), '), 0)' - def AVG(builder, expr, distinct=False): + def AVG(builder, distinct, expr): + assert distinct in (None, True, False) return distinct and 'AVG(DISTINCT ' or 'AVG(', builder(expr), ')' + def GROUP_CONCAT(builder, distinct, expr, sep=None): + assert distinct in (None, True, False) + result = distinct and 'GROUP_CONCAT(DISTINCT ' or 'GROUP_CONCAT(', builder(expr) + if sep is not None: + if builder.provider.dialect == 'MySQL': + result = result, ' SEPARATOR ', builder(sep) + else: + result = result, ', ', builder(sep) + return result, ')' UPPER = make_unary_func('upper') LOWER = make_unary_func('lower') LENGTH = make_unary_func('length') @@ -437,20 +508,112 @@ def AVG(builder, expr, distinct=False): def COALESCE(builder, *args): if len(args) < 2: assert False # pragma: no cover return 'coalesce(', join(', ', imap(builder, args)), ')' - def MIN(builder, *args): + def MIN(builder, distinct, *args): + assert not distinct, distinct if len(args) == 0: assert False # pragma: no cover elif len(args) == 1: fname = 'MIN' - else: fname = 'least' + else: fname = builder.least_func_name return fname, '(', join(', ', imap(builder, args)), ')' - def MAX(builder, *args): + def MAX(builder, distinct, *args): + assert not distinct, distinct if len(args) == 0: assert False # pragma: no cover elif len(args) == 1: fname = 'MAX' - else: fname = 'greatest' + else: fname = builder.greatest_func_name return fname, '(', join(', ', imap(builder, args)), ')' def SUBSTR(builder, expr, start, len=None): if len is None: return 'substr(', builder(expr), ', ', builder(start), ')' return 'substr(', builder(expr), ', ', builder(start), ', ', builder(len), ')' + def STRING_SLICE(builder, expr, start, stop): + if start is None: + start = [ 'VALUE', 0 ] + + if start[0] == 'VALUE': + start_value = start[1] + if builder.dialect == 'PostgreSQL' and start_value < 0: + index_sql = [ 'LENGTH', expr ] + if start_value < -1: + index_sql = [ 'SUB', index_sql, [ 'VALUE', -(start_value + 1) ] ] + else: + if start_value >= 0: start_value += 1 + index_sql = [ 'VALUE', start_value ] + else: + inner_sql = start + then = [ 'ADD', inner_sql, [ 'VALUE', 1 ] ] + else_ = [ 'ADD', [ 'LENGTH', expr ], then ] if builder.dialect == 'PostgreSQL' else inner_sql + index_sql = [ 'IF', [ 'GE', inner_sql, [ 'VALUE', 0 ] ], then, else_ ] + + if stop is None: + len_sql = None + elif stop[0] == 'VALUE': + stop_value = stop[1] + if start[0] == 'VALUE': + start_value = start[1] + if start_value >= 0 and stop_value >= 0: + len_sql = [ 'VALUE', stop_value - start_value ] + elif start_value < 0 and stop_value < 0: + len_sql = [ 'VALUE', stop_value - start_value ] + elif start_value >= 0 and stop_value < 0: + len_sql = [ 'SUB', [ 'LENGTH', expr ], [ 'VALUE', start_value - stop_value ]] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + elif start_value < 0 and stop_value >= 0: + len_sql = [ 'SUB', [ 'VALUE', stop_value + 1 ], index_sql ] + 
len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + else: + assert False # pragma: nocover1 + else: + start_sql = [ 'COALESCE', start, [ 'VALUE', 0 ] ] + if stop_value >= 0: + start_positive = [ 'SUB', stop, start_sql ] + start_negative = [ 'SUB', [ 'VALUE', stop_value + 1 ], index_sql ] + else: + start_positive = [ 'SUB', [ 'LENGTH', expr ], [ 'ADD', start_sql, [ 'VALUE', -stop_value ] ] ] + start_negative = [ 'SUB', stop, start_sql] + len_sql = [ 'IF', [ 'GE', start_sql, [ 'VALUE', 0 ] ], start_positive, start_negative ] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + else: + stop_sql = [ 'COALESCE', stop, [ 'VALUE', -1 ] ] + if start[0] == 'VALUE': + start_value = start[1] + start_sql = [ 'VALUE', start_value ] + if start_value >= 0: + stop_positive = [ 'SUB', stop_sql, start_sql ] + stop_negative = [ 'SUB', [ 'LENGTH', expr ], [ 'SUB', start_sql, stop_sql ] ] + else: + stop_positive = [ 'SUB', [ 'ADD', stop_sql, [ 'VALUE', 1 ] ], index_sql ] + stop_negative = [ 'SUB', stop_sql, start_sql] + len_sql = [ 'IF', [ 'GE', stop_sql, [ 'VALUE', 0 ] ], stop_positive, stop_negative ] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + else: + start_sql = [ 'COALESCE', start, [ 'VALUE', 0 ] ] + both_positive = [ 'SUB', stop_sql, start_sql ] + both_negative = both_positive + start_positive = [ 'SUB', [ 'LENGTH', expr ], [ 'SUB', start_sql, stop_sql ] ] + stop_positive = [ 'SUB', [ 'ADD', stop_sql, [ 'VALUE', 1 ] ], index_sql ] + len_sql = [ 'CASE', None, [ + ( + [ 'AND', [ 'GE', start_sql, [ 'VALUE', 0 ] ], [ 'GE', stop_sql, [ 'VALUE', 0 ] ] ], + both_positive + ), + ( + [ 'AND', [ 'LT', start_sql, [ 'VALUE', 0 ] ], [ 'LT', stop_sql, [ 'VALUE', 0 ] ] ], + both_negative + ), + ( + [ 'AND', [ 'GE', start_sql, [ 'VALUE', 0 ] ], [ 'LT', stop_sql, [ 'VALUE', 0 ] ] ], + start_positive + ), + ( + [ 'AND', [ 'LT', start_sql, [ 'VALUE', 0 ] ], [ 'GE', stop_sql, [ 'VALUE', 0 ] ] ], + stop_positive + ), + ]] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + sql = [ 'SUBSTR', expr, index_sql, len_sql ] + return builder(sql) def CASE(builder, expr, cases, default=None): + if expr is None and default is not None and default[0] == 'CASE' and default[1] is None: + cases2, default2 = default[2:] + return builder.CASE(None, tuple(cases) + tuple(cases2), default2) result = [ 'case' ] if expr is not None: result.append(' ') @@ -461,6 +624,8 @@ def CASE(builder, expr, cases, default=None): result.extend((' else ', builder(default))) result.append(' end') return result + def IF(builder, cond, then, else_): + return builder.CASE(None, [(cond, then)], else_) def TRIM(builder, expr, chars=None): if chars is None: return 'trim(', builder(expr), ')' return 'trim(', builder(expr), ', ', builder(chars), ')' @@ -474,6 +639,10 @@ def REPLACE(builder, str, from_, to): return 'replace(', builder(str), ', ', builder(from_), ', ', builder(to), ')' def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS integer)' + def TO_STR(builder, expr): + return 'CAST(', builder(expr), ' AS text)' + def TO_REAL(builder, expr): + return 'CAST(', builder(expr), ' AS real)' def TODAY(builder): return 'CURRENT_DATE' def NOW(builder): @@ -497,3 +666,64 @@ def RANDOM(builder): def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] + def build_json_path(builder, path): + empty_slice = slice(None, None, None) + has_params = False + has_wildcards = False + items = [ builder(element) for element in path ] + for item in items: + if 
isinstance(item, Param): + has_params = True + elif isinstance(item, Value): + value = item.value + if value is Ellipsis or value == empty_slice: has_wildcards = True + else: assert isinstance(value, (int, basestring)), value + else: assert False, item + if has_params: + paramkey = tuple(item.paramkey if isinstance(item, Param) else + None if type(item.value) is slice else item.value + for item in items) + path_sql = builder.make_composite_param(paramkey, items, builder.eval_json_path) + else: + result_value = builder.eval_json_path(item.value for item in items) + path_sql = builder.value_class(builder.paramstyle, result_value) + return path_sql, has_params, has_wildcards + @classmethod + def eval_json_path(cls, values): + result = ['$'] + append = result.append + empty_slice = slice(None, None, None) + for value in values: + if isinstance(value, int): append('[%d]' % value) + elif isinstance(value, basestring): + append('.' + value if is_ident(value) else '."%s"' % value.replace('"', '\\"')) + elif value is Ellipsis: append('.*') + elif value == empty_slice: append('[*]') + else: assert False, value + return ''.join(result) + def JSON_QUERY(builder, expr, path): + throw(NotImplementedError) + def JSON_VALUE(builder, expr, path, type): + throw(NotImplementedError) + def JSON_NONZERO(builder, expr): + throw(NotImplementedError) + def JSON_CONCAT(builder, left, right): + throw(NotImplementedError) + def JSON_CONTAINS(builder, expr, path, key): + throw(NotImplementedError) + def JSON_ARRAY_LENGTH(builder, value): + throw(NotImplementedError) + def JSON_PARAM(builder, expr): + return builder(expr) + def ARRAY_INDEX(builder, col, index): + throw(NotImplementedError) + def ARRAY_CONTAINS(builder, key, not_in, col): + throw(NotImplementedError) + def ARRAY_SUBSET(builder, array1, not_in, array2): + throw(NotImplementedError) + def ARRAY_LENGTH(builder, array): + throw(NotImplementedError) + def ARRAY_SLICE(builder, array, start, stop): + throw(NotImplementedError) + def MAKE_ARRAY(builder, *items): + throw(NotImplementedError) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 72742f1ba..be78709eb 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, pickle, with_metaclass +from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, with_metaclass, int_types -import types, sys, re, itertools +import types, sys, re, itertools, inspect from decimal import Decimal from datetime import date, time, datetime, timedelta from random import random @@ -12,13 +12,16 @@ from pony.thirdparty.compiler import ast from pony import options, utils -from pony.utils import is_ident, throw, reraise, concat -from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError +from pony.utils import localbase, is_ident, throw, reraise, copy_ast, between, concat, coalesce +from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError, create_extractors +from pony.orm.decompiling import decompile, DecompileError from pony.orm.ormtypes import \ - numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ - get_normalized_type_of, normalize_type, coerce_types, are_comparable_types + numeric_types, comparable_types, SetType, FuncType, MethodType, raw_sql, RawSQLType, \ + normalize, normalize_type, coerce_types, are_comparable_types, \ + Json, 
QueryType, Array, array_types from pony.orm import core -from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper +from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper, \ + special_functions, const_functions, extract_vars, Query, UseAnotherTranslator NoneType = type(None) @@ -65,11 +68,33 @@ def type2str(t): try: return t.__name__ except: return str(t) +class Local(localbase): + def __init__(local): + local.translators = [] + + @property + def translator(self): + return local.translators[-1] + +local = Local() + class SQLTranslator(ASTTranslator): dialect = None row_value_syntax = True + json_path_wildcard_syntax = False + json_values_are_comparable = True rowid_support = False + def __enter__(translator): + local.translators.append(translator) + + def __exit__(translator, exc_type, exc_val, exc_tb): + t = local.translators.pop() + if isinstance(exc_val, UseAnotherTranslator): + assert t is exc_val.translator + else: + assert t is translator + def default_post(translator, node): throw(NotImplementedError) # pragma: no cover @@ -77,46 +102,83 @@ def dispatch(translator, node): if hasattr(node, 'monad'): return # monad already assigned somehow if not getattr(node, 'external', False) or getattr(node, 'constant', False): return ASTTranslator.dispatch(translator, node) # default route - varkey = translator.filter_num, node.src - t = translator.vartypes[varkey] + translator.call(translator.__class__.dispatch_external, node) + + def dispatch_external(translator, node): + varkey = translator.filter_num, node.src, translator.code_key + t = translator.root_translator.vartypes[varkey] tt = type(t) if t is NoneType: - monad = translator.ConstMonad.new(translator, None) + monad = ConstMonad.new(None) elif tt is SetType: if isinstance(t.item_type, EntityMeta): - monad = translator.EntityMonad(translator, t.item_type) + monad = EntityMonad(t.item_type) else: throw(NotImplementedError) # pragma: no cover + elif tt is QueryType: + prev_translator = deepcopy(t.translator) + prev_translator.parent = translator + prev_translator.injected = True + if translator.database is not prev_translator.database: + throw(TranslationError, 'Mixing queries from different databases') + monad = QuerySetMonad(prev_translator) + if t.limit is not None or t.offset is not None: + monad = monad.call_limit(t.limit, t.offset) elif tt is FuncType: func = t.func - func_monad_class = translator.registered_functions.get(func, translator.ErrorSpecialFuncMonad) - monad = func_monad_class(translator, func) + func_monad_class = translator.registered_functions.get(func) + if func_monad_class is not None: + monad = func_monad_class(func) + else: + monad = HybridFuncMonad(t, func.__name__) elif tt is MethodType: obj, func = t.obj, t.func - if not isinstance(obj, EntityMeta): throw(NotImplementedError) - entity_monad = translator.EntityMonad(translator, obj) - if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) - monad = translator.MethodMonad(translator, entity_monad, func.__name__) + if isinstance(obj, EntityMeta): + entity_monad = EntityMonad(obj) + if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) + monad = MethodMonad(entity_monad, func.__name__) + elif node.src == 'random': # For PyPy + monad = FuncRandomMonad(t) + else: throw(NotImplementedError) elif isinstance(node, ast.Name) and node.name in ('True', 'False'): value = True if node.name == 'True' else False - monad = 
translator.ConstMonad.new(translator, value) + monad = ConstMonad.new(value) elif tt is tuple: params = [] + is_array = False + if t and translator.database.provider.array_converter_cls is not None: + types = set(t) + if len(types) == 1 and unicode in types: + item_type = unicode + is_array = True + else: + item_type = int + for type_ in types: + if type_ is float: + item_type = float + if type_ not in (float, int) or not hasattr(type_, '__index__'): + break + else: + is_array = True + for i, item_type in enumerate(t): if item_type is NoneType: throw(TypeError, 'Expression `%s` should not contain None values' % node.src) - param = translator.ParamMonad.new(translator, item_type, (varkey, i, None)) + param = ParamMonad.new(item_type, (varkey, i, None)) params.append(param) - monad = translator.ListMonad(translator, params) + monad = ListMonad(params) + if is_array: + array_type = array_types.get(item_type, None) + monad = ArrayParamMonad(array_type, (varkey, None, None), list_monad=monad) elif isinstance(t, RawSQLType): - monad = translator.RawSQLMonad(translator, t, varkey) + monad = RawSQLMonad(t, varkey) else: - monad = translator.ParamMonad.new(translator, t, (varkey, None, None)) + monad = ParamMonad.new(t, (varkey, None, None)) node.monad = monad monad.node = node monad.aggregated = monad.nogroup = False def call(translator, method, node): - try: monad = method(node) + try: monad = method(translator, node) except Exception: exc_class, exc, tb = sys.exc_info() try: @@ -154,105 +216,195 @@ def call(translator, method, node): else: throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: %s' % ast2src(node)) return monad - def __init__(translator, tree, extractors, vartypes, parent_translator=None, left_join=False, optimize=None): + def __init__(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): + local.translators.append(translator) + try: + translator.init(tree, parent_translator, code_key, filter_num, extractors, vars, vartypes, left_join, optimize) + except UseAnotherTranslator as e: + assert local.translators + t = local.translators.pop() + assert t is e.translator + raise + else: + assert local.translators + t = local.translators.pop() + assert t is translator + + def init(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): + this = translator assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) - translator.database = None - translator.argnames = None - translator.filter_num = parent_translator.filter_num if parent_translator is not None else 0 + translator.can_be_cached = True + translator.parent = parent_translator + translator.injected = False + if parent_translator is None: + translator.root_translator = translator + translator.database = None + translator.sqlquery = SqlQuery(translator, left_join=left_join) + assert code_key is not None and filter_num is not None + translator.code_key = translator.original_code_key = code_key + translator.filter_num = translator.original_filter_num = filter_num + else: + translator.root_translator = parent_translator.root_translator + translator.database = parent_translator.database + translator.sqlquery = SqlQuery(translator, parent_translator.sqlquery, left_join=left_join) + assert code_key is None and filter_num is None + translator.code_key = parent_translator.code_key + translator.filter_num = 
parent_translator.filter_num + translator.original_code_key = translator.original_filter_num = None translator.extractors = extractors + translator.vars = vars translator.vartypes = vartypes - translator.parent = parent_translator + translator.namespace_stack = [{}] if not parent_translator else [ parent_translator.namespace.copy() ] + translator.func_extractors_map = {} + translator.fixed_param_values = {} + translator.func_vartypes = {} translator.left_join = left_join translator.optimize = optimize translator.from_optimized = False translator.optimization_failed = False - if not parent_translator: subquery = Subquery(left_join=left_join) - else: subquery = Subquery(parent_translator.subquery, left_join=left_join) - translator.subquery = subquery - tablerefs = subquery.tablerefs translator.distinct = False - translator.conditions = subquery.conditions + translator.conditions = translator.sqlquery.conditions translator.having_conditions = [] translator.order = [] + translator.limit = translator.offset = None + translator.inside_order_by = False translator.aggregated = False if not optimize else True - translator.inside_expr = False - translator.inside_not = False translator.hint_join = False translator.query_result_is_cacheable = True translator.aggregated_subquery_paths = set() for i, qual in enumerate(tree.quals): assign = qual.assign - if not isinstance(assign, ast.AssName): throw(NotImplementedError, ast2src(assign)) - if assign.flags != 'OP_ASSIGN': throw(TypeError, ast2src(assign)) + if isinstance(assign, ast.AssTuple): + ass_names = tuple(assign.nodes) + elif isinstance(assign, ast.AssName): + ass_names = (assign,) + else: + throw(NotImplementedError, ast2src(assign)) - name = assign.name - if name in tablerefs: throw(TranslationError, 'Duplicate name: %r' % name) - if name.startswith('__'): throw(TranslationError, 'Illegal name: %r' % name) + for ass_name in ass_names: + if not isinstance(ass_name, ast.AssName): + throw(NotImplementedError, ast2src(ass_name)) + if ass_name.flags != 'OP_ASSIGN': + throw(TypeError, ast2src(ass_name)) + + names = tuple(ass_name.name for ass_name in ass_names) + for name in names: + if name in translator.namespace and name in translator.sqlquery.tablerefs: + throw(TranslationError, 'Duplicate name: %r' % name) + if name.startswith('__'): throw(TranslationError, 'Illegal name: %r' % name) + + name = names[0] if len(names) == 1 else None + + def check_name_is_single(): + if len(names) > 1: throw(TypeError, 'Single variable name expected. Got: %s' % ast2src(assign)) + + database = entity = None node = qual.iter monad = getattr(node, 'monad', None) - src = getattr(node, 'src', None) + if monad: # Lambda was encountered inside generator - assert isinstance(monad, EntityMonad) + check_name_is_single() + assert parent_translator and i == 0 entity = monad.type.item_type - tablerefs[name] = TableRef(subquery, name, entity) - elif src: - iterable = translator.vartypes[translator.filter_num, src] - if not isinstance(iterable, SetType): throw(TranslationError, - 'Inside declarative query, iterator must be entity. ' - 'Got: for %s in %s' % (name, ast2src(qual.iter))) - entity = iterable.item_type - if not isinstance(entity, EntityMeta): - throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) - if i > 0: - if translator.left_join: throw(TranslationError, - 'Collection expected inside left join query. 
' - 'Got: for %s in %s' % (name, ast2src(qual.iter))) - translator.distinct = True - tableref = TableRef(subquery, name, entity) - tablerefs[name] = tableref - tableref.make_join() - else: - attr_names = [] - while isinstance(node, ast.Getattr): - attr_names.append(node.attrname) - node = node.expr - if not isinstance(node, ast.Name) or not attr_names: - throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) - node_name = node.name - attr_names.reverse() - name_path = node_name - parent_tableref = subquery.get_tableref(node_name) - if parent_tableref is None: throw(TranslationError, "Name %r must be defined in query" % node_name) - parent_entity = parent_tableref.entity - last_index = len(attr_names) - 1 - for j, attrname in enumerate(attr_names): - attr = parent_entity._adict_.get(attrname) - if attr is None: throw(AttributeError, attrname) - entity = attr.py_type + if isinstance(monad, EntityMonad): + tableref = TableRef(translator.sqlquery, name, entity) + translator.sqlquery.tablerefs[name] = tableref + elif isinstance(monad, AttrSetMonad): + translator.sqlquery = monad._subselect(translator.sqlquery, extract_outer_conditions=False) + tableref = monad.tableref + else: assert False # pragma: no cover + translator.namespace[name] = ObjectIterMonad(tableref, entity) + elif node.external: + varkey = translator.filter_num, node.src, translator.code_key + iterable = translator.root_translator.vartypes[varkey] + if isinstance(iterable, SetType): + check_name_is_single() + entity = iterable.item_type if not isinstance(entity, EntityMeta): + throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) + if i > 0: + if translator.left_join: throw(TranslationError, + 'Collection expected inside left join query. ' + 'Got: for %s in %s' % (name, ast2src(qual.iter))) + translator.distinct = True + tableref = TableRef(translator.sqlquery, name, entity) + translator.sqlquery.tablerefs[name] = tableref + tableref.make_join() + translator.namespace[name] = node.monad = ObjectIterMonad(tableref, entity) + elif isinstance(iterable, QueryType): + prev_translator = deepcopy(iterable.translator) + prev_limit = iterable.limit + prev_offset = iterable.offset + database = prev_translator.database + try: + translator.process_query_qual(prev_translator, prev_limit, prev_offset, + names, try_extend_prev_query=not i) + except UseAnotherTranslator as e: + assert local.translators and local.translators[-1] is translator + translator = e.translator + local.translators[-1] = translator + else: throw(TranslationError, 'Inside declarative query, iterator must be entity or query. 
' + 'Got: for %s in %s' % (name, ast2src(qual.iter))) + + else: + translator.dispatch(node) + monad = node.monad + + if isinstance(monad, QuerySetMonad): + subtranslator = monad.subtranslator + database = subtranslator.database + try: + translator.process_query_qual(subtranslator, monad.limit, monad.offset, names) + except UseAnotherTranslator: + assert False + else: + check_name_is_single() + attr_names = [] + while isinstance(monad, (AttrMonad, AttrSetMonad)) and monad.parent is not None: + attr_names.append(monad.attr.name) + monad = monad.parent + attr_names.reverse() + + if not isinstance(monad, ObjectIterMonad): throw(NotImplementedError, 'for %s in %s' % (name, ast2src(qual.iter))) - can_affect_distinct = None - if attr.is_collection: - if not isinstance(attr, Set): throw(NotImplementedError, ast2src(qual.iter)) - reverse = attr.reverse - if reverse.is_collection: - if not isinstance(reverse, Set): throw(NotImplementedError, ast2src(qual.iter)) - translator.distinct = True - elif parent_tableref.alias != tree.quals[i-1].assign.name: - translator.distinct = True - else: can_affect_distinct = True - if j == last_index: name_path = name - else: name_path += '-' + attr.name - tableref = JoinedTableRef(subquery, name_path, parent_tableref, attr) - if can_affect_distinct is not None: - tableref.can_affect_distinct = can_affect_distinct - tablerefs[name_path] = tableref - parent_tableref = tableref - parent_entity = entity - - database = entity._database_ + name_path = monad.tableref.alias # or name_path, it is the same + + parent_tableref = monad.tableref + parent_entity = parent_tableref.entity + + last_index = len(attr_names) - 1 + for j, attrname in enumerate(attr_names): + attr = parent_entity._adict_.get(attrname) + if attr is None: throw(AttributeError, attrname) + entity = attr.py_type + if not isinstance(entity, EntityMeta): + throw(NotImplementedError, 'for %s in %s' % (name, ast2src(qual.iter))) + can_affect_distinct = None + if attr.is_collection: + if not isinstance(attr, Set): throw(NotImplementedError, ast2src(qual.iter)) + reverse = attr.reverse + if reverse.is_collection: + if not isinstance(reverse, Set): throw(NotImplementedError, ast2src(qual.iter)) + translator.distinct = True + elif parent_tableref.alias != tree.quals[i-1].assign.name: + translator.distinct = True + else: can_affect_distinct = True + if j == last_index: name_path = name + else: name_path += '-' + attr.name + tableref = translator.sqlquery.add_tableref(name_path, parent_tableref, attr) + tableref.make_join(pk_only=True) + if j == last_index: + translator.namespace[name] = ObjectIterMonad(tableref, tableref.entity) + if can_affect_distinct is not None: + tableref.can_affect_distinct = can_affect_distinct + parent_tableref = tableref + parent_entity = entity + + if database is None: + assert entity is not None + database = entity._database_ assert database.schema is not None if translator.database is None: translator.database = database elif translator.database is not database: throw(TranslationError, @@ -261,40 +413,41 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef for if_ in qual.ifs: assert isinstance(if_, ast.GenExprIf) translator.dispatch(if_) - if isinstance(if_.monad, translator.AndMonad): cond_monads = if_.monad.operands + if isinstance(if_.monad, AndMonad): cond_monads = if_.monad.operands else: cond_monads = [ if_.monad ] for m in cond_monads: - if not m.aggregated: translator.conditions.extend(m.getsql()) + if not getattr(m, 'aggregated', False): 
translator.conditions.extend(m.getsql()) else: translator.having_conditions.extend(m.getsql()) - translator.inside_expr = True translator.dispatch(tree.expr) assert not translator.hint_join - assert not translator.inside_not monad = tree.expr.monad - if isinstance(monad, translator.ParamMonad): throw(TranslationError, + if isinstance(monad, ParamMonad): throw(TranslationError, "External parameter '%s' cannot be used as query result" % ast2src(tree.expr)) - translator.expr_monads = monad.items if isinstance(monad, translator.ListMonad) else [ monad ] + translator.expr_monads = monad.items if isinstance(monad, ListMonad) else [ monad ] translator.groupby_monads = None expr_type = monad.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): - monad.orderby_columns = list(xrange(1, len(expr_type._pk_columns_)+1)) + entity = expr_type + translator.expr_type = entity + monad.orderby_columns = list(xrange(1, len(entity._pk_columns_)+1)) if monad.aggregated: throw(TranslationError) - if translator.aggregated: translator.groupby_monads = [ monad ] - else: translator.distinct |= monad.requires_distinct() - if isinstance(monad, translator.ObjectMixin): - entity = monad.type + if isinstance(monad, QuerySetMonad): + throw(NotImplementedError) + elif isinstance(monad, ObjectMixin): tableref = monad.tableref - elif isinstance(monad, translator.AttrSetMonad): - entity = monad.type.item_type - tableref = monad.make_tableref(translator.subquery) + elif isinstance(monad, AttrSetMonad): + tableref = monad.make_tableref(translator.sqlquery) else: assert False # pragma: no cover + if translator.aggregated: + translator.groupby_monads = [ monad ] + else: + translator.distinct |= monad.requires_distinct() translator.tableref = tableref pk_only = parent_translator is not None or translator.aggregated alias, pk_columns = tableref.make_join(pk_only=pk_only) translator.alias = alias - translator.expr_type = entity translator.expr_columns = [ [ 'COLUMN', alias, column ] for column in pk_columns ] translator.row_layout = None translator.col_names = [ attr.name for attr in entity._attrs_ @@ -319,18 +472,19 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef expr_set.add(m.tableref.name_path) elif isinstance(m, AttrMonad) and isinstance(m.parent, ObjectIterMonad): expr_set.add((m.parent.tableref.name_path, m.attr)) - for tr in tablerefs.values(): + for tr in translator.sqlquery.tablerefs.values(): + if tr.entity is None: continue if not tr.can_affect_distinct: continue if tr.name_path in expr_set: continue - for attr in tr.entity._pk_attrs_: - if (tr.name_path, attr) not in expr_set: break - else: continue - translator.distinct = True - break + if any((tr.name_path, attr) not in expr_set for attr in tr.entity._pk_attrs_): + translator.distinct = True + break row_layout = [] offset = 0 provider = translator.database.provider for m in expr_monads: + if m.disable_distinct: + translator.distinct = False expr_type = m.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): @@ -343,24 +497,144 @@ def func(values, constructor=expr_type._get_by_raw_pkval_): offset = next_offset else: converter = provider.get_converter_by_py_type(expr_type) - def func(value, sql2py=converter.sql2py): + def func(value, converter=converter): if value is None: return None - return sql2py(value) + value = converter.sql2py(value) + value = converter.dbval2val(value) + return value row_layout.append((func, offset, 
ast2src(m.node))) - m.orderby_columns = (offset+1,) + m.orderby_columns = (offset+1,) if not m.disable_ordering else () offset += 1 translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in translator.row_layout ] - def shallow_copy_of_subquery_ast(translator, move_outer_conditions=True, is_not_null_checks=False): - subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=False, is_not_null_checks=is_not_null_checks) - assert attr_offsets is None + if translator.aggregated: + translator.distinct = False + translator.vars = None + if translator is not this: + raise UseAnotherTranslator(translator) + @property + def namespace(translator): + return translator.namespace_stack[-1] + def can_be_optimized(translator): + if translator.groupby_monads: return False + if len(translator.aggregated_subquery_paths) != 1: return False + aggr_path = next(iter(translator.aggregated_subquery_paths)) + for tableref in translator.sqlquery.tablerefs.values(): + if tableref.joined and not aggr_path.startswith(tableref.name_path): + return False + return aggr_path + def process_query_qual(translator, prev_translator, prev_limit, prev_offset, names, try_extend_prev_query=False): + sqlquery = translator.sqlquery + tablerefs = sqlquery.tablerefs + expr_types = prev_translator.expr_type + if not isinstance(expr_types, tuple): expr_types = (expr_types,) + expr_count = len(expr_types) + + if expr_count > 1 and len(names) == 1: + throw(NotImplementedError, + 'Please unpack a tuple of (%s) in for-loop to individual variables (like: "for x, y in ...")' + % (', '.join(ast2src(m.node) for m in prev_translator.expr_monads))) + elif expr_count > len(names): + throw(TranslationError, + 'Not enough values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' + % (', '.join(names), + ', '.join(ast2src(m.node) for m in prev_translator.expr_monads), + len(names), expr_count)) + elif expr_count < len(names): + throw(TranslationError, + 'Too many values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' + % (', '.join(names), + ', '.join(ast2src(m.node) for m in prev_translator.expr_monads), + len(names), expr_count)) + + if try_extend_prev_query: + if prev_translator.aggregated: pass + elif prev_translator.left_join: pass + else: + assert translator.parent is None + assert prev_translator.vars is None + prev_translator.code_key = translator.code_key + prev_translator.filter_num = translator.filter_num + prev_translator.extractors.update(translator.extractors) + prev_translator.vars = translator.vars + prev_translator.vartypes.update(translator.vartypes) + prev_translator.left_join = translator.left_join + prev_translator.optimize = translator.optimize + prev_translator.namespace_stack = [ + {name: expr for name, expr in izip(names, prev_translator.expr_monads)} + ] + prev_translator.limit, prev_translator.offset = combine_limit_and_offset( + prev_translator.limit, prev_translator.offset, prev_limit, prev_offset) + raise UseAnotherTranslator(prev_translator) + + + if len(names) == 1 and isinstance(prev_translator.expr_type, EntityMeta) \ + and not prev_translator.aggregated and not prev_translator.distinct: + name = names[0] + entity = prev_translator.expr_type + [expr_monad] = prev_translator.expr_monads + entity_alias = expr_monad.tableref.alias + subquery_ast = prev_translator.construct_subquery_ast(prev_limit, prev_offset, star=entity_alias) + tableref = StarTableRef(sqlquery, name, entity, subquery_ast) + tablerefs[name] = tableref + tableref.make_join() 
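# --- Editor's note: illustrative sketch, not part of the patch. ---
# process_query_qual() is what lets a generator iterate over the result of
# another query: the previous query is either merged into the new translator
# (try_extend_prev_query) or embedded as a derived table in FROM via
# StarTableRef/ExprTableRef.  A hypothetical end-user view of that feature,
# assuming a Pony build that already contains this translator change; the
# entity and data below are invented for illustration only.
from pony.orm import Database, Required, db_session, select

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    name = Required(str)
    age = Required(int)

db.generate_mapping(create_tables=True)

with db_session:
    Person(name='Ann', age=25)
    Person(name='Bob', age=17)
    adults = select(p for p in Person if p.age >= 18)
    # A query used as the iterable of another query; the translator inlines
    # it as a subquery (or extends it in place) instead of fetching it first.
    names = select(p.name for p in adults)[:]
    print(names)  # expected: ['Ann']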
+ translator.namespace[name] = ObjectIterMonad(tableref, entity) + else: + aliases = [] + aliases_dict = {} + for name, base_expr_monad in izip(names, prev_translator.expr_monads): + t = base_expr_monad.type + if isinstance(t, EntityMeta): + t_aliases = [] + for suffix in t._pk_paths_: + alias = '%s-%s' % (name, suffix) + t_aliases.append(alias) + aliases.extend(t_aliases) + aliases_dict[base_expr_monad] = t_aliases + else: + aliases.append(name) + aliases_dict[base_expr_monad] = name + + subquery_ast = prev_translator.construct_subquery_ast(prev_limit, prev_offset, aliases=aliases) + tableref = ExprTableRef(sqlquery, 't', subquery_ast, names, aliases) + for name in names: + tablerefs[name] = tableref + tableref.make_join() + + for name, base_expr_monad in izip(names, prev_translator.expr_monads): + t = base_expr_monad.type + if isinstance(t, EntityMeta): + columns = aliases_dict[base_expr_monad] + expr_tableref = ExprJoinedTableRef(sqlquery, tableref, columns, name, t) + expr_monad = ObjectIterMonad(expr_tableref, t) + else: + column = aliases_dict[base_expr_monad] + expr_ast = ['COLUMN', tableref.alias, column] + expr_monad = ExprMonad.new(t, expr_ast, base_expr_monad.nullable) + assert name not in translator.namespace + translator.namespace[name] = expr_monad + def construct_subquery_ast(translator, limit=None, offset=None, aliases=None, star=None, + distinct=None, is_not_null_checks=False): + subquery_ast, attr_offsets = translator.construct_sql_ast( + limit, offset, distinct, is_not_null_checks=is_not_null_checks) assert len(subquery_ast) >= 3 and subquery_ast[0] == 'SELECT' select_ast = subquery_ast[1][:] - assert select_ast[0] == 'ALL' + assert select_ast[0] in ('ALL', 'DISTINCT', 'AGGREGATES'), select_ast + if aliases: + assert not star and len(aliases) == len(select_ast) - 1 + for i, alias in enumerate(aliases): + expr = select_ast[i+1] + if expr[0] == 'AS': expr = expr[1] + select_ast[i+1] = [ 'AS', expr, alias ] + elif star is not None: + assert isinstance(star, basestring) + for section in subquery_ast: + assert section[0] not in ('GROUP_BY', 'HAVING'), subquery_ast + select_ast[1:] = [ [ 'STAR', star ] ] from_ast = subquery_ast[2][:] - assert from_ast[0] == 'FROM' + assert from_ast[0] in ('FROM', 'LEFT_JOIN') if len(subquery_ast) == 3: where_ast = [ 'WHERE' ] @@ -372,83 +646,92 @@ def shallow_copy_of_subquery_ast(translator, move_outer_conditions=True, is_not_ where_ast = subquery_ast[3][:] other_ast = subquery_ast[4:] - if move_outer_conditions and len(from_ast[1]) == 4: + if len(from_ast[1]) == 4: outer_conditions = from_ast[1][-1] from_ast[1] = from_ast[1][:-1] if outer_conditions[0] == 'AND': where_ast[1:1] = outer_conditions[1:] else: where_ast.insert(1, outer_conditions) return [ 'SELECT', select_ast, from_ast, where_ast ] + other_ast - def can_be_optimized(translator): - if translator.groupby_monads: return False - if len(translator.aggregated_subquery_paths) != 1: return False - return next(iter(translator.aggregated_subquery_paths)) - def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None, for_update=False, nowait=False, - attrs_to_prefetch=(), is_not_null_checks=False): + def construct_sql_ast(translator, limit=None, offset=None, distinct=None, + aggr_func_name=None, aggr_func_distinct=None, sep=None, + for_update=False, nowait=False, skip_locked=False, is_not_null_checks=False): attr_offsets = None - if distinct is None: distinct = translator.distinct + if distinct is None: + if not translator.order: + distinct = translator.distinct 
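# --- Editor's note: illustrative sketch, not part of the patch. ---
# construct_sql_ast() now takes limit/offset arguments and merges them with
# the translator's own limit/offset through combine_limit_and_offset(),
# defined further down in this file.  A self-contained restatement of that
# helper with worked examples of how successive slices compose:
def combine_limit_and_offset(limit, offset, limit2, offset2):
    assert limit is None or limit >= 0
    assert limit2 is None or limit2 >= 0
    if offset2 is not None:
        if limit is not None:
            limit = max(0, limit - offset2)  # the new offset eats into the old limit
        offset = (offset or 0) + offset2     # offsets accumulate
    if limit2 is not None:
        limit = limit2 if limit is None else min(limit, limit2)
    if limit == 0:
        offset = None                        # empty result, offset is irrelevant
    return limit, offset

assert combine_limit_and_offset(None, None, 10, 20) == (10, 20)   # q[20:30] on an unsliced query
assert combine_limit_and_offset(10, 20, 5, 3) == (5, 23)          # q[20:30][3:8] -> rows 23..27
assert combine_limit_and_offset(10, None, None, 50) == (0, None)  # sliced past the end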
ast_transformer = lambda ast: ast if for_update: - sql_ast = [ 'SELECT_FOR_UPDATE', nowait ] + sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked ] translator.query_result_is_cacheable = False else: sql_ast = [ 'SELECT' ] - groupby_monads = translator.groupby_monads - if distinct and translator.aggregated and not groupby_monads: - distinct = False - groupby_monads = translator.expr_monads - select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + translator.expr_columns if aggr_func_name: expr_type = translator.expr_type if isinstance(expr_type, EntityMeta): - if aggr_func_name is not 'COUNT': throw(TypeError, + if aggr_func_name == 'GROUP_CONCAT': + if expr_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") + elif aggr_func_name != 'COUNT': throw(TypeError, 'Attribute should be specified for %r aggregate function' % aggr_func_name.lower()) elif isinstance(expr_type, tuple): - if aggr_func_name is not 'COUNT': throw(TypeError, + if aggr_func_name != 'COUNT': throw(TypeError, 'Single attribute should be specified for %r aggregate function' % aggr_func_name.lower()) else: if aggr_func_name in ('SUM', 'AVG') and expr_type not in numeric_types: throw(TypeError, '%r is valid for numeric attributes only' % aggr_func_name.lower()) assert len(translator.expr_columns) == 1 aggr_ast = None - if groupby_monads or (aggr_func_name == 'COUNT' and distinct - and isinstance(translator.expr_type, EntityMeta) - and len(translator.expr_columns) > 1): + if translator.groupby_monads or ( + aggr_func_name == 'COUNT' and distinct + and isinstance(translator.expr_type, EntityMeta) + and len(translator.expr_columns) > 1): outer_alias = 't' - if aggr_func_name == 'COUNT': - outer_aggr_ast = [ 'COUNT', 'ALL' ] + if aggr_func_name == 'COUNT' and not aggr_func_distinct: + outer_aggr_ast = [ 'COUNT', None ] else: assert len(translator.expr_columns) == 1 expr_ast = translator.expr_columns[0] if expr_ast[0] == 'COLUMN': outer_alias, column_name = expr_ast[1:] - outer_aggr_ast = [ aggr_func_name, [ 'COLUMN', outer_alias, column_name ] ] + outer_aggr_ast = [aggr_func_name, aggr_func_distinct, ['COLUMN', outer_alias, column_name]] + if aggr_func_name == 'GROUP_CONCAT' and sep is not None: + outer_aggr_ast.append(['VALUE', sep]) else: select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + [ [ 'AS', expr_ast, 'expr' ] ] - outer_aggr_ast = [ aggr_func_name, [ 'COLUMN', 't', 'expr' ] ] + outer_aggr_ast = [ aggr_func_name, aggr_func_distinct, [ 'COLUMN', 't', 'expr' ] ] + if aggr_func_name == 'GROUP_CONCAT' and sep is not None: + outer_aggr_ast.append(['VALUE', sep]) def ast_transformer(ast): return [ 'SELECT', [ 'AGGREGATES', outer_aggr_ast ], [ 'FROM', [ outer_alias, 'SELECT', ast[1:] ] ] ] else: if aggr_func_name == 'COUNT': - if isinstance(expr_type, (tuple, EntityMeta)) and not distinct: aggr_ast = [ 'COUNT', 'ALL' ] - else: aggr_ast = [ 'COUNT', 'DISTINCT', translator.expr_columns[0] ] - else: aggr_ast = [ aggr_func_name, translator.expr_columns[0] ] + if isinstance(expr_type, (tuple, EntityMeta)) and not distinct and not aggr_func_distinct: + aggr_ast = [ 'COUNT', aggr_func_distinct ] + else: + aggr_ast = [ 'COUNT', True if aggr_func_distinct is None else aggr_func_distinct, + translator.expr_columns[0] ] + else: + aggr_ast = [ aggr_func_name, aggr_func_distinct, translator.expr_columns[0] ] + if aggr_func_name == 'GROUP_CONCAT' and sep is not None: + aggr_ast.append(['VALUE', sep]) + if aggr_ast: select_ast = [ 'AGGREGATES', aggr_ast ] elif 
isinstance(translator.expr_type, EntityMeta) and not translator.parent \ and not translator.aggregated and not translator.optimize: select_ast, attr_offsets = translator.expr_type._construct_select_clause_( - translator.alias, distinct, translator.tableref.used_attrs, attrs_to_prefetch) + translator.alias, distinct, translator.tableref.used_attrs) sql_ast.append(select_ast) - sql_ast.append(translator.subquery.from_ast) + sql_ast.append(translator.sqlquery.from_ast) conditions = translator.conditions[:] having_conditions = translator.having_conditions[:] if is_not_null_checks: for monad in translator.expr_monads: - if isinstance(monad, translator.ObjectIterMonad): pass - elif isinstance(monad, translator.AttrMonad) and not monad.attr.nullable: pass + if isinstance(monad, ObjectIterMonad): pass + elif not monad.nullable: pass else: notnull_conditions = [ [ 'IS_NOT_NULL', column_ast ] for column_ast in monad.getsql() ] if monad.aggregated: having_conditions.extend(notnull_conditions) @@ -456,9 +739,9 @@ def ast_transformer(ast): if conditions: sql_ast.append([ 'WHERE' ] + conditions) - if groupby_monads: + if translator.groupby_monads: group_by = [ 'GROUP_BY' ] - for m in groupby_monads: group_by.extend(m.getsql()) + for m in translator.groupby_monads: group_by.extend(m.getsql()) sql_ast.append(group_by) else: group_by = None @@ -473,15 +756,18 @@ def ast_transformer(ast): if translator.order and not aggr_func_name: sql_ast.append([ 'ORDER_BY' ] + translator.order) - if range: + limit, offset = combine_limit_and_offset(translator.limit, translator.offset, limit, offset) + if limit is not None or offset is not None: assert not aggr_func_name - start, stop = range - limit = stop - start - offset = start - assert limit is not None - limit_section = [ 'LIMIT', [ 'VALUE', limit ]] - if offset: limit_section.append([ 'VALUE', offset ]) - sql_ast = sql_ast + [ limit_section ] + provider = translator.database.provider + if limit is None: + if provider.dialect == 'SQLite': + limit = -1 + elif provider.dialect == 'MySQL': + limit = 18446744073709551615 + limit_section = [ 'LIMIT', limit ] + if offset: limit_section.append(offset) + sql_ast.append(limit_section) sql_ast = ast_transformer(sql_ast) return sql_ast, attr_offsets @@ -490,17 +776,21 @@ def construct_delete_sql_ast(translator): expr_monad = translator.tree.expr.monad if not isinstance(entity, EntityMeta): throw(TranslationError, 'Delete query should be applied to a single entity. 
Got: %s' % ast2src(translator.tree.expr)) - if translator.groupby_monads: throw(TranslationError, - 'Delete query cannot contains GROUP BY section or aggregate functions') - assert not translator.having_conditions + force_in = False + if translator.groupby_monads: + force_in = True + else: + assert not translator.having_conditions tableref = expr_monad.tableref - from_ast = translator.subquery.from_ast - assert from_ast[0] == 'FROM' - if len(from_ast) == 2 and not translator.subquery.used_from_subquery: + from_ast = translator.sqlquery.from_ast + if from_ast[0] != 'FROM': + force_in = True + + if not force_in and len(from_ast) == 2 and not translator.sqlquery.used_from_subquery: sql_ast = [ 'DELETE', None, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) - elif translator.dialect == 'MySQL': + elif not force_in and translator.dialect == 'MySQL': sql_ast = [ 'DELETE', tableref.alias, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) @@ -574,56 +864,86 @@ def order_by_attributes(translator, attrs): new_order.append(desc_wrapper([ 'COLUMN', alias, column])) order[:0] = new_order return translator - def apply_kwfilters(translator, filterattrs): - entity = translator.expr_type - if not isinstance(entity, EntityMeta): - throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') + def apply_kwfilters(translator, filterattrs, original_names=False): translator = deepcopy(translator) - expr_monad = translator.tree.expr.monad - monads = [] - none_monad = translator.NoneMonad(translator) - for attr, id, is_none in filterattrs: - attr_monad = expr_monad.getattr(attr.name) - if is_none: monads.append(CmpMonad('is', attr_monad, none_monad)) + with translator: + if original_names: + object_monad = translator.tree.quals[0].iter.monad + assert isinstance(object_monad.type, EntityMeta) else: - param_monad = translator.ParamMonad.new(translator, attr.py_type, (id, None, None)) - monads.append(CmpMonad('==', attr_monad, param_monad)) - for m in monads: translator.conditions.extend(m.getsql()) - return translator - def apply_lambda(translator, filter_num, order_by, func_ast, argnames, extractors, vartypes): + object_monad = translator.tree.expr.monad + if not isinstance(object_monad.type, EntityMeta): + throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') + + monads = [] + none_monad = NoneMonad() + for attr, id, is_none in filterattrs: + attr_monad = object_monad.getattr(attr.name) + if is_none: monads.append(CmpMonad('is', attr_monad, none_monad)) + else: + param_monad = ParamMonad.new(attr.py_type, (id, None, None)) + monads.append(CmpMonad('==', attr_monad, param_monad)) + for m in monads: translator.conditions.extend(m.getsql()) + return translator + def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, original_names, extractors, vars, vartypes): translator = deepcopy(translator) - pickled_func_ast = pickle.dumps(func_ast, 2) - func_ast = pickle.loads(pickled_func_ast) # func_ast = deepcopy(func_ast) + func_ast = copy_ast(func_ast) # func_ast = deepcopy(func_ast) + translator.code_key = func_id translator.filter_num = filter_num translator.extractors.update(extractors) + translator.vars = vars + translator.vartypes = translator.vartypes.copy() # make HashableDict mutable again translator.vartypes.update(vartypes) - translator.argnames = list(argnames) - translator.dispatch(func_ast) - if isinstance(func_ast, ast.Tuple): nodes = 
func_ast.nodes - else: nodes = (func_ast,) - if order_by: - new_order = [] - for node in nodes: - if isinstance(node.monad, translator.SetMixin): - t = node.monad.type.item_type - if isinstance(type(t), type): t = t.__name__ - throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' - % (t, ast2src(node))) - new_order.extend(node.monad.getsql()) - translator.order[:0] = new_order + + if not original_names: + assert argnames + namespace = {name: monad for name, monad in izip(argnames, translator.expr_monads)} + elif argnames: + namespace = {name: translator.namespace[name] for name in argnames} else: - for node in nodes: - monad = node.monad - if isinstance(monad, translator.AndMonad): cond_monads = monad.operands - else: cond_monads = [ monad ] - for m in cond_monads: - if not m.aggregated: translator.conditions.extend(m.getsql()) - else: translator.having_conditions.extend(m.getsql()) - return translator + namespace = None + if namespace is not None: + translator.namespace_stack.append(namespace) + + with translator: + try: + translator.dispatch(func_ast) + if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes + else: nodes = (func_ast,) + if order_by: + translator.inside_order_by = True + new_order = [] + for node in nodes: + if isinstance(node.monad, SetMixin): + t = node.monad.type.item_type + if isinstance(type(t), type): t = t.__name__ + throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' + % (t, ast2src(node))) + new_order.extend(node.monad.getsql()) + translator.order[:0] = new_order + translator.inside_order_by = False + else: + for node in nodes: + monad = node.monad + if isinstance(monad, AndMonad): cond_monads = monad.operands + else: cond_monads = [ monad ] + for m in cond_monads: + if not m.aggregated: translator.conditions.extend(m.getsql()) + else: translator.having_conditions.extend(m.getsql()) + translator.vars = None + return translator + finally: + if namespace is not None: + ns = translator.namespace_stack.pop() + assert ns is namespace def preGenExpr(translator, node): inner_tree = node.code - subtranslator = translator.__class__(inner_tree, translator.extractors, translator.vartypes, translator) - return translator.QuerySetMonad(translator, subtranslator) + translator_cls = translator.__class__ + try: + subtranslator = translator_cls(inner_tree, translator) + except UseAnotherTranslator: + assert False + return QuerySetMonad(subtranslator) def postGenExprIf(translator, node): monad = node.test.monad if monad.type is not bool: monad = monad.nonzero() @@ -633,12 +953,9 @@ def preCompare(translator, node): ops = node.ops left = node.expr translator.dispatch(left) - inside_not = translator.inside_not # op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '==' # | 'in' | 'not in' | 'is' | 'is not' for op, right in node.ops: - translator.inside_not = inside_not - if op == 'not in': translator.inside_not = not inside_not translator.dispatch(right) if op.endswith('in'): monad = right.monad.contains(left.monad, op == 'not in') else: monad = left.monad.cmp(op, right.monad) @@ -650,29 +967,34 @@ def preCompare(translator, node): 'Too complex aggregation, expressions cannot be combined: {EXPR}') monads.append(monad) left = right - translator.inside_not = inside_not if len(monads) == 1: return monads[0] - return translator.AndMonad(monads) + return AndMonad(monads) def postConst(translator, node): value = node.value + if type(value) is frozenset: + value = tuple(sorted(value)) if type(value) is not tuple: - return 
translator.ConstMonad.new(translator, value) + return ConstMonad.new(value) else: - return translator.ListMonad(translator, [ translator.ConstMonad.new(translator, item) for item in value ]) + return ListMonad([ ConstMonad.new(item) for item in value ]) + def postEllipsis(translator, node): + return ConstMonad.new(Ellipsis) def postList(translator, node): - return translator.ListMonad(translator, [ item.monad for item in node.nodes ]) + return ListMonad([ item.monad for item in node.nodes ]) def postTuple(translator, node): - return translator.ListMonad(translator, [ item.monad for item in node.nodes ]) + return ListMonad([ item.monad for item in node.nodes ]) def postName(translator, node): - name = node.name - argnames = translator.argnames - if translator.argnames and name in translator.argnames: - i = translator.argnames.index(name) - return translator.expr_monads[i] - tableref = translator.subquery.get_tableref(name) - if tableref is not None: - return translator.ObjectIterMonad(translator, tableref, tableref.entity) - else: assert False, name # pragma: no cover + monad = translator.resolve_name(node.name) + assert monad is not None + return monad + def resolve_name(translator, name): + if name not in translator.namespace: + throw(TranslationError, 'Name %s is not found in %s' % (name, translator.namespace)) + monad = translator.namespace[name] + assert isinstance(monad, Monad) + if monad.translator is not translator: + monad.translator.sqlquery.used_from_subquery = True + return monad def postAdd(translator, node): return node.left.monad + node.right.monad def postSub(translator, node): @@ -683,6 +1005,8 @@ def postDiv(translator, node): return node.left.monad / node.right.monad def postFloorDiv(translator, node): return node.left.monad // node.right.monad + def postMod(translator, node): + return node.left.monad % node.right.monad def postPower(translator, node): return node.left.monad ** node.right.monad def postUnarySub(translator, node): @@ -690,36 +1014,46 @@ def postUnarySub(translator, node): def postGetattr(translator, node): return node.expr.monad.getattr(node.attrname) def postAnd(translator, node): - return translator.AndMonad([ subnode.monad for subnode in node.nodes ]) + return AndMonad([ subnode.monad for subnode in node.nodes ]) def postOr(translator, node): - return translator.OrMonad([ subnode.monad for subnode in node.nodes ]) - def preNot(translator, node): - translator.inside_not = not translator.inside_not + return OrMonad([ subnode.monad for subnode in node.nodes ]) + def postBitor(translator, node): + left, right = (subnode.monad for subnode in node.nodes) + return left | right + def postBitand(translator, node): + left, right = (subnode.monad for subnode in node.nodes) + return left & right + def postBitxor(translator, node): + left, right = (subnode.monad for subnode in node.nodes) + return left ^ right def postNot(translator, node): - translator.inside_not = not translator.inside_not return node.expr.monad.negate() def preCallFunc(translator, node): if node.star_args is not None: throw(NotImplementedError, '*%s is not supported' % ast2src(node.star_args)) if node.dstar_args is not None: throw(NotImplementedError, '**%s is not supported' % ast2src(node.dstar_args)) - if not isinstance(node.node, (ast.Name, ast.Getattr)): throw(NotImplementedError) + func_node = node.node + if isinstance(func_node, ast.CallFunc): + if isinstance(func_node.node, ast.Name) and func_node.node.name == 'getattr': return + if not isinstance(func_node, (ast.Name, ast.Getattr)): 
throw(NotImplementedError) if len(node.args) > 1: return if not node.args: return arg = node.args[0] if isinstance(arg, ast.GenExpr): - translator.dispatch(node.node) - func_monad = node.node.monad + translator.dispatch(func_node) + func_monad = func_node.monad translator.dispatch(arg) query_set_monad = arg.monad return func_monad(query_set_monad) if not isinstance(arg, ast.Lambda): return lambda_expr = arg - translator.dispatch(node.node) - method_monad = node.node.monad + translator.dispatch(func_node) + method_monad = func_node.monad if not isinstance(method_monad, MethodMonad): throw(NotImplementedError) entity_monad = method_monad.parent - if not isinstance(entity_monad, EntityMonad): throw(NotImplementedError) + if not isinstance(entity_monad, (EntityMonad, AttrSetMonad)): throw(NotImplementedError) entity = entity_monad.type.item_type - if method_monad.attrname != 'select': throw(TypeError) + method_name = method_monad.attrname + if method_name not in ('select', 'filter', 'exists'): throw(TypeError) if len(lambda_expr.argnames) != 1: throw(TypeError) if lambda_expr.varargs: throw(TypeError) if lambda_expr.kwargs: throw(TypeError) @@ -731,8 +1065,15 @@ def preCallFunc(translator, node): name_ast.monad = entity_monad for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) - subtranslator = translator.__class__(inner_expr, translator.extractors, translator.vartypes, translator) - return translator.QuerySetMonad(translator, subtranslator) + translator_cls = translator.__class__ + try: + subtranslator = translator_cls(inner_expr, translator) + except UseAnotherTranslator: + assert False + monad = QuerySetMonad(subtranslator) + if method_name == 'exists': + monad = monad.nonzero() + return monad def postCallFunc(translator, node): args = [] kwargs = {} @@ -741,8 +1082,6 @@ def postCallFunc(translator, node): kwargs[arg.name] = arg.expr.monad else: args.append(arg.monad) func_monad = node.node.monad - if isinstance(func_monad, ErrorSpecialFuncMonad): - 'Function %r cannot be used in this way: %s' % (func_monad.func.__name__, ast2src(node)) return func_monad(*args, **kwargs) def postKeyword(translator, node): pass # this node will be processed by postCallFunc @@ -752,14 +1091,17 @@ def postSubscript(translator, node): if len(node.subs) > 1: for x in node.subs: if isinstance(x, ast.Sliceobj): throw(TypeError) - key = translator.ListMonad(translator, [ item.monad for item in node.subs ]) + key = ListMonad([ item.monad for item in node.subs ]) return node.expr.monad[key] sub = node.subs[0] if isinstance(sub, ast.Sliceobj): start, stop, step = (sub.nodes+[None])[:3] if start is not None: start = start.monad + if isinstance(start, NoneMonad): start = None if stop is not None: stop = stop.monad + if isinstance(stop, NoneMonad): stop = None if step is not None: step = step.monad + if isinstance(step, NoneMonad): step = None return node.expr.monad[start:stop:step] else: return node.expr.monad[sub.monad] def postSlice(translator, node): @@ -767,8 +1109,10 @@ def postSlice(translator, node): expr_monad = node.expr.monad upper = node.upper if upper is not None: upper = upper.monad + if isinstance(upper, NoneMonad): upper = None lower = node.lower if lower is not None: lower = lower.monad + if isinstance(lower, NoneMonad): lower = None return expr_monad[lower:upper] def postSliceobj(translator, node): pass @@ -781,72 +1125,120 @@ def postIfExp(translator, node): elif not translator.row_value_syntax: 
throw(NotImplementedError) else: then_sql, else_sql = [ 'ROW' ] + then_sql, [ 'ROW' ] + else_sql expr = [ 'CASE', None, [ [ test_sql, then_sql ] ], else_sql ] - result = translator.ExprMonad.new(translator, result_type, expr) + result = ExprMonad.new(result_type, expr, + nullable=test_monad.nullable or then_monad.nullable or else_monad.nullable) result.aggregated = test_monad.aggregated or then_monad.aggregated or else_monad.aggregated return result + def postStr(translator, node): + val_monad = node.value.monad + if isinstance(val_monad, StringMixin): + return val_monad + sql = ['TO_STR', val_monad.getsql()[0] ] + return StringExprMonad(unicode, sql, nullable=val_monad.nullable) + def postJoinedStr(translator, node): + nullable = False + for subnode in node.values: + assert isinstance(subnode.monad, StringMixin), (subnode.monad, subnode) + if subnode.monad.nullable: + nullable = True + sql = [ 'CONCAT' ] + [ value.monad.getsql()[0] for value in node.values ] + return StringExprMonad(unicode, sql, nullable=nullable) + def postFormattedValue(translator, node): + throw(NotImplementedError, 'You cannot set width and precision markers in query') + +def combine_limit_and_offset(limit, offset, limit2, offset2): + assert limit is None or limit >= 0 + assert limit2 is None or limit2 >= 0 + + if offset2 is not None: + if limit is not None: + limit = max(0, limit - offset2) + offset = (offset or 0) + offset2 + + if limit2 is not None: + if limit is not None: + limit = min(limit, limit2) + else: + limit = limit2 -def coerce_monads(m1, m2): + if limit == 0: + offset = None + + return limit, offset + +def coerce_monads(m1, m2, for_comparison=False): result_type = coerce_types(m1.type, m2.type) - if result_type in numeric_types and bool in (m1.type, m2.type) and result_type is not bool: + if result_type in numeric_types and bool in (m1.type, m2.type) and ( + result_type is not bool or not for_comparison): translator = m1.translator if translator.dialect == 'PostgreSQL': + if result_type is bool: + result_type = int if m1.type is bool: - new_m1 = NumericExprMonad(translator, int, [ 'TO_INT', m1.getsql()[0] ]) + new_m1 = NumericExprMonad(int, [ 'TO_INT', m1.getsql()[0] ], nullable=m1.nullable) new_m1.aggregated = m1.aggregated m1 = new_m1 if m2.type is bool: - new_m2 = NumericExprMonad(translator, int, [ 'TO_INT', m2.getsql()[0] ]) + new_m2 = NumericExprMonad(int, [ 'TO_INT', m2.getsql()[0] ], nullable=m2.nullable) new_m2.aggregated = m2.aggregated m2 = new_m2 return result_type, m1, m2 max_alias_length = 30 -class Subquery(object): - def __init__(subquery, parent_subquery=None, left_join=False): - subquery.parent_subquery = parent_subquery - subquery.left_join = left_join - subquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] - subquery.conditions = [] - subquery.tablerefs = {} - if parent_subquery is None: - subquery.alias_counters = {} - subquery.expr_counter = itertools.count(1) +class SqlQuery(object): + def __init__(sqlquery, translator, parent_sqlquery=None, left_join=False): + sqlquery.translator = translator + sqlquery.parent_sqlquery = parent_sqlquery + sqlquery.left_join = left_join + sqlquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] + sqlquery.conditions = [] + sqlquery.outer_conditions = [] + sqlquery.tablerefs = {} + if parent_sqlquery is None: + sqlquery.alias_counters = {} + sqlquery.expr_counter = itertools.count(1) else: - subquery.alias_counters = parent_subquery.alias_counters.copy() - subquery.expr_counter = parent_subquery.expr_counter - subquery.used_from_subquery 
= False - def get_tableref(subquery, name_path, from_subquery=False): - tableref = subquery.tablerefs.get(name_path) - if tableref is not None: - if from_subquery and subquery.parent_subquery is None: - subquery.used_from_subquery = True - return tableref - if subquery.parent_subquery: - return subquery.parent_subquery.get_tableref(name_path, from_subquery=True) - return None - __contains__ = get_tableref - def add_tableref(subquery, name_path, parent_tableref, attr): - tablerefs = subquery.tablerefs - assert name_path not in tablerefs - tableref = JoinedTableRef(subquery, name_path, parent_tableref, attr) - tablerefs[name_path] = tableref + sqlquery.alias_counters = parent_sqlquery.alias_counters.copy() + sqlquery.expr_counter = parent_sqlquery.expr_counter + sqlquery.used_from_subquery = False + def get_tableref(sqlquery, name_path): + tableref = sqlquery.tablerefs.get(name_path) + parent_sqlquery = sqlquery.parent_sqlquery + if tableref is None and parent_sqlquery: + tableref = parent_sqlquery.get_tableref(name_path) + if tableref is not None: + parent_sqlquery.used_from_subquery = True return tableref - def get_short_alias(subquery, name_path, entity_name): - if name_path: - if is_ident(name_path): return name_path - if not options.SIMPLE_ALIASES and len(name_path) <= max_alias_length: - return name_path - name = entity_name[:max_alias_length-3].lower() - i = subquery.alias_counters.setdefault(name, 0) + 1 - alias = '%s-%d' % (name, i) - subquery.alias_counters[name] = i + def add_tableref(sqlquery, name_path, parent_tableref, attr): + assert name_path not in sqlquery.tablerefs + if parent_tableref.sqlquery is not sqlquery: + parent_tableref.sqlquery.used_from_subquery = True + tableref = JoinedTableRef(sqlquery, name_path, parent_tableref, attr) + sqlquery.tablerefs[name_path] = tableref + return tableref + def make_alias(sqlquery, name): + name = name[:max_alias_length-3].lower() + i = sqlquery.alias_counters.setdefault(name, 0) + 1 + alias = name if i == 1 and name != 't' else '%s-%d' % (name, i) + sqlquery.alias_counters[name] = i return alias + def join_table(sqlquery, parent_alias, alias, table_name, join_cond): + new_item = [alias, 'TABLE', table_name, join_cond] + from_ast = sqlquery.from_ast + for i in xrange(1, len(from_ast)): + if from_ast[i][0] == parent_alias: + for j in xrange(i+1, len(from_ast)): + if len(from_ast[j]) < 4: # item without join condition + from_ast.insert(j, new_item) + return + from_ast.append(new_item) class TableRef(object): - def __init__(tableref, subquery, name, entity): - tableref.subquery = subquery - tableref.alias = tableref.name_path = name + def __init__(tableref, sqlquery, name, entity): + tableref.sqlquery = sqlquery + tableref.alias = sqlquery.make_alias(name) + tableref.name_path = tableref.alias tableref.entity = entity tableref.joined = False tableref.can_affect_distinct = True @@ -854,19 +1246,79 @@ def __init__(tableref, subquery, name, entity): def make_join(tableref, pk_only=False): entity = tableref.entity if not tableref.joined: - subquery = tableref.subquery - subquery.from_ast.append([ tableref.alias, 'TABLE', entity._table_ ]) + sqlquery = tableref.sqlquery + sqlquery.from_ast.append([ tableref.alias, 'TABLE', entity._table_ ]) if entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(tableref.alias) assert discr_criteria is not None - subquery.conditions.append(discr_criteria) + sqlquery.conditions.append(discr_criteria) tableref.joined = True return tableref.alias, entity._pk_columns_ +class 
ExprTableRef(TableRef): + def __init__(tableref, sqlquery, name, subquery_ast, expr_names, expr_aliases): + TableRef.__init__(tableref, sqlquery, name, None) + tableref.subquery_ast = subquery_ast + tableref.expr_names = expr_names + tableref.expr_aliases = expr_aliases + def make_join(tableref, pk_only=False): + assert tableref.subquery_ast[0] == 'SELECT' + if not tableref.joined: + sqlquery = tableref.sqlquery + sqlquery.from_ast.append([tableref.alias, 'SELECT', tableref.subquery_ast[1:]]) + tableref.joined = True + return tableref.alias, None + +class StarTableRef(TableRef): + def __init__(tableref, sqlquery, name, entity, subquery_ast): + TableRef.__init__(tableref, sqlquery, name, entity) + tableref.subquery_ast = subquery_ast + def make_join(tableref, pk_only=False): + entity = tableref.entity + assert tableref.subquery_ast[0] == 'SELECT' + if not tableref.joined: + sqlquery = tableref.sqlquery + sqlquery.from_ast.append([ tableref.alias, 'SELECT', tableref.subquery_ast[1:] ]) + if entity._discriminator_attr_: # ??? + discr_criteria = entity._construct_discriminator_criteria_(tableref.alias) + assert discr_criteria is not None + sqlquery.conditions.append(discr_criteria) + tableref.joined = True + return tableref.alias, entity._pk_columns_ + +class ExprJoinedTableRef(object): + def __init__(tableref, sqlquery, parent_tableref, parent_columns, name, entity): + tableref.sqlquery = sqlquery + tableref.parent_tableref = parent_tableref + tableref.parent_columns = parent_columns + tableref.name = tableref.name_path = name + tableref.entity = entity + tableref.alias = None + tableref.joined = False + tableref.can_affect_distinct = False + tableref.used_attrs = set() + def make_join(tableref, pk_only=False): + entity = tableref.entity + if tableref.joined: + return tableref.alias, tableref.pk_columns + sqlquery = tableref.sqlquery + parent_alias, left_pk_columns = tableref.parent_tableref.make_join() + if pk_only: + tableref.alias = parent_alias + tableref.pk_columns = tableref.parent_columns + return tableref.alias, tableref.pk_columns + tableref.alias = sqlquery.make_alias(tableref.name) + tableref.pk_columns = entity._pk_columns_ + join_cond = join_tables(parent_alias, tableref.alias, tableref.parent_columns, tableref.pk_columns) + sqlquery.join_table(parent_alias, tableref.alias, entity._table_, join_cond) + tableref.joined = True + return tableref.alias, tableref.pk_columns + class JoinedTableRef(object): - def __init__(tableref, subquery, name_path, parent_tableref, attr): - tableref.subquery = subquery + def __init__(tableref, sqlquery, name_path, parent_tableref, attr): + tableref.sqlquery = sqlquery tableref.name_path = name_path + tableref.var_name = name_path if is_ident(name_path) else None tableref.alias = None tableref.optimized = None tableref.parent_tableref = parent_tableref @@ -881,7 +1333,7 @@ def make_join(tableref, pk_only=False): if tableref.joined: if pk_only or not tableref.optimized: return tableref.alias, tableref.pk_columns - subquery = tableref.subquery + sqlquery = tableref.sqlquery attr = tableref.attr parent_pk_only = attr.pk_offset is not None or attr.is_collection parent_alias, left_pk_columns = tableref.parent_tableref.make_join(parent_pk_only) @@ -889,13 +1341,15 @@ def make_join(tableref, pk_only=False): pk_columns = entity._pk_columns_ if not attr.is_collection: if not attr.columns: + # one-to-one relationship with foreign key column on the right side reverse = attr.reverse assert reverse.columns and not reverse.is_collection rentity = reverse.entity 
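# --- Editor's note: illustrative sketch, not part of the patch. ---
# The table refs above obtain their SQL aliases from SqlQuery.make_alias()
# (which replaces the old get_short_alias): the first use of a name keeps
# the lowercased, truncated name itself, later uses get a numeric suffix,
# and the generic name 't' is always numbered.  A standalone model of that
# rule, with the constant mirroring max_alias_length in this module:
MAX_ALIAS_LENGTH = 30

def make_alias(alias_counters, name):
    name = name[:MAX_ALIAS_LENGTH - 3].lower()
    i = alias_counters.setdefault(name, 0) + 1
    alias_counters[name] = i
    return name if i == 1 and name != 't' else '%s-%d' % (name, i)

counters = {}
assert make_alias(counters, 'Person') == 'person'
assert make_alias(counters, 'Person') == 'person-2'
assert make_alias(counters, 't') == 't-1'  # m2m link tables always get a number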
pk_columns = rentity._pk_columns_ - alias = subquery.get_short_alias(tableref.name_path, rentity.__name__) + alias = sqlquery.make_alias(tableref.var_name or rentity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, reverse.columns) else: + # one-to-one or many-to-one relationship with foreign key column on the left side if attr.pk_offset is not None: offset = attr.pk_columns_offset left_columns = left_pk_columns[offset:offset+len(attr.columns)] @@ -904,21 +1358,23 @@ def make_join(tableref, pk_only=False): tableref.alias = parent_alias tableref.pk_columns = left_columns tableref.optimized = True - tableref.joined = True + # tableref.joined = True return parent_alias, left_columns - alias = subquery.get_short_alias(tableref.name_path, entity.__name__) + alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_columns, pk_columns) elif not attr.reverse.is_collection: - alias = subquery.get_short_alias(tableref.name_path, entity.__name__) + # many-to-one relationship + alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, attr.reverse.columns) else: + # many-to-many relationship right_m2m_columns = attr.reverse_columns if attr.symmetric else attr.columns if not tableref.joined: m2m_table = attr.table - m2m_alias = subquery.get_short_alias(None, 't') + m2m_alias = sqlquery.make_alias('t') reverse_columns = attr.columns if attr.symmetric else attr.reverse.columns m2m_join_cond = join_tables(parent_alias, m2m_alias, left_pk_columns, reverse_columns) - subquery.from_ast.append([ m2m_alias, 'TABLE', m2m_table, m2m_join_cond ]) + sqlquery.join_table(parent_alias, m2m_alias, m2m_table, m2m_join_cond) if pk_only: tableref.alias = m2m_alias tableref.pk_columns = right_m2m_columns @@ -928,13 +1384,18 @@ def make_join(tableref, pk_only=False): elif tableref.optimized: assert not pk_only m2m_alias = tableref.alias - alias = subquery.get_short_alias(tableref.name_path, entity.__name__) + alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(m2m_alias, alias, right_m2m_columns, pk_columns) if not pk_only and entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(alias) assert discr_criteria is not None join_cond.append(discr_criteria) - subquery.from_ast.append([ alias, 'TABLE', entity._table_, join_cond ]) + + translator = tableref.sqlquery.translator.root_translator + if translator.optimize == tableref.name_path and translator.from_optimized and tableref.sqlquery is translator.sqlquery: + pass + else: + sqlquery.join_table(parent_alias, alias, entity._table_, join_cond) tableref.alias = alias tableref.pk_columns = pk_columns tableref.optimized = False @@ -960,35 +1421,40 @@ class MonadMixin(with_metaclass(MonadMeta)): pass class Monad(with_metaclass(MonadMeta)): - def __init__(monad, translator, type): - monad.translator = translator + disable_distinct = False + disable_ordering = False + def __init__(monad, type, nullable=True): + monad.node = None + monad.translator = local.translator monad.type = type + monad.nullable = nullable monad.mixin_init() def mixin_init(monad): pass def cmp(monad, op, monad2): - return monad.translator.CmpMonad(op, monad, monad2) + return CmpMonad(op, monad, monad2) def contains(monad, item, not_in=False): throw(TypeError) - def nonzero(monad): throw(TypeError) + def nonzero(monad): + return CmpMonad('is not', monad, NoneMonad()) def negate(monad): - 
return monad.translator.NotMonad(monad) + return NotMonad(monad) def getattr(monad, attrname): try: property_method = getattr(monad, 'attr_' + attrname) except AttributeError: if not hasattr(monad, 'call_' + attrname): - throw(AttributeError, '%r object has no attribute %r' % (type2str(monad.type), attrname)) - translator = monad.translator - return translator.MethodMonad(translator, monad, attrname) + throw(AttributeError, '%r object has no attribute %r: {EXPR}' % (type2str(monad.type), attrname)) + return MethodMonad(monad, attrname) return property_method() def len(monad): throw(TypeError) - def count(monad): + def count(monad, distinct=None): + distinct = distinct_from_monad(distinct, default=True) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. Got: {EXPR}') expr = monad.getsql() - count_kind = 'DISTINCT' + if monad.type is bool: expr = [ 'CASE', None, [ [ expr[0], [ 'VALUE', 1 ] ] ], [ 'VALUE', None ] ] - count_kind = 'ALL' + distinct = None elif len(expr) == 1: expr = expr[0] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr @@ -1005,37 +1471,49 @@ def count(monad): '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. Got: {EXPR}' % translator.dialect) - result = translator.ExprMonad.new(translator, int, [ 'COUNT', count_kind, expr ]) + result = ExprMonad.new(int, [ 'COUNT', distinct, expr ], nullable=False) result.aggregated = True return result - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None, sep=None): + distinct = distinct_from_monad(distinct) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. Got: {EXPR}') expr_type = monad.type # if isinstance(expr_type, SetType): expr_type = expr_type.item_type if func_name in ('SUM', 'AVG'): if expr_type not in numeric_types: - throw(TypeError, "Function '%s' expects argument of numeric type, got %r in {EXPR}" - % (func_name, type2str(expr_type))) + if expr_type is Json: monad = monad.to_real() + else: throw(TypeError, "Function '%s' expects argument of numeric type, got %r in {EXPR}" + % (func_name, type2str(expr_type))) elif func_name in ('MIN', 'MAX'): if expr_type not in comparable_types: throw(TypeError, "Function '%s' cannot be applied to type %r in {EXPR}" % (func_name, type2str(expr_type))) + elif func_name == 'GROUP_CONCAT': + if isinstance(expr_type, EntityMeta) and expr_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover expr = monad.getsql() if len(expr) == 1: expr = expr[0] - elif translator.row_value_syntax == True: expr = ['ROW'] + expr + elif translator.row_value_syntax: expr = ['ROW'] + expr else: throw(NotImplementedError, '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. 
Got: {EXPR} ' '(you can suggest us how to write SQL for this query)' % translator.dialect) - if func_name == 'AVG': result_type = float - else: result_type = expr_type - aggr_ast = [ func_name, expr ] - if getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG'): - aggr_ast.append(True) - result = translator.ExprMonad.new(translator, result_type, aggr_ast) + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = expr_type + if distinct is None: + distinct = getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG') + aggr_ast = [ func_name, distinct, expr ] + if func_name == 'GROUP_CONCAT': + if sep is not None: + aggr_ast.append(['VALUE', sep]) + result = ExprMonad.new(result_type, aggr_ast, nullable=True) result.aggregated = True return result def __call__(monad, *args, **kwargs): throw(TypeError) @@ -1047,13 +1525,30 @@ def __truediv__(monad, monad2): throw(TypeError) def __floordiv__(monad, monad2): throw(TypeError) def __pow__(monad, monad2): throw(TypeError) def __neg__(monad): throw(TypeError) + def __or__(monad): throw(TypeError) + def __and__(monad): throw(TypeError) + def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) + def cast_from_json(monad, type): assert False, monad + def to_int(monad): + return NumericExprMonad(int, [ 'TO_INT', monad.getsql()[0] ], nullable=monad.nullable) + def to_str(monad): + return StringExprMonad(unicode, [ 'TO_STR', monad.getsql()[0] ], nullable=monad.nullable) + def to_real(monad): + return NumericExprMonad(float, [ 'TO_REAL', monad.getsql()[0] ], nullable=monad.nullable) + +def distinct_from_monad(distinct, default=None): + if distinct is None: + return default + if isinstance(distinct, NumericConstMonad) and isinstance(distinct.value, bool): + return distinct.value + throw(TypeError, '`distinct` value should be True or False. Got: %s' % ast2src(distinct.node)) class RawSQLMonad(Monad): - def __init__(monad, translator, rawtype, varkey): + def __init__(monad, rawtype, varkey, nullable=True): if rawtype.result_type is None: type = rawtype else: type = normalize_type(rawtype.result_type) - Monad.__init__(monad, translator, type) + Monad.__init__(monad, type, nullable=nullable) monad.rawtype = rawtype monad.varkey = varkey def contains(monad, item, not_in=False): @@ -1065,9 +1560,9 @@ def contains(monad, item, not_in=False): '%s database provider does not support tuples. 
Got: {EXPR} ' % translator.dialect) op = 'NOT_IN' if not_in else 'IN' sql = [ op, expr, monad.getsql() ] - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(sql, nullable=item.nullable) def nonzero(monad): return monad - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): provider = monad.translator.database.provider rawtype = monad.rawtype result = [] @@ -1121,8 +1616,8 @@ def raise_forgot_parentheses(monad): throw(TranslationError, 'You seems to forgot parentheses after %s' % ast2src(monad.node)) class MethodMonad(Monad): - def __init__(monad, translator, parent, attrname): - Monad.__init__(monad, translator, 'METHOD') + def __init__(monad, parent, attrname): + Monad.__init__(monad, 'METHOD', nullable=False) monad.parent = parent monad.attrname = attrname def getattr(monad, attrname): @@ -1135,7 +1630,7 @@ def __call__(monad, *args, **kwargs): def contains(monad, item, not_in=False): raise_forgot_parentheses(monad) def nonzero(monad): raise_forgot_parentheses(monad) def negate(monad): raise_forgot_parentheses(monad) - def aggregate(monad, func_name): raise_forgot_parentheses(monad) + def aggregate(monad, func_name, distinct=None, sep=None): raise_forgot_parentheses(monad) def __getitem__(monad, key): raise_forgot_parentheses(monad) def __add__(monad, monad2): raise_forgot_parentheses(monad) @@ -1149,8 +1644,9 @@ def __neg__(monad): raise_forgot_parentheses(monad) def abs(monad): raise_forgot_parentheses(monad) class EntityMonad(Monad): - def __init__(monad, translator, entity): - Monad.__init__(monad, translator, SetType(entity)) + def __init__(monad, entity): + Monad.__init__(monad, SetType(entity)) + translator = monad.translator if translator.database is None: translator.database = entity._database_ elif translator.database is not entity._database_: @@ -1159,12 +1655,13 @@ def __getitem__(monad, *args): throw(NotImplementedError) class ListMonad(Monad): - def __init__(monad, translator, items): - Monad.__init__(monad, translator, tuple(item.type for item in items)) + def __init__(monad, items): + Monad.__init__(monad, tuple(item.type for item in items)) monad.items = items def contains(monad, x, not_in=False): - translator = monad.translator - for item in monad.items: check_comparable(item, x) + if isinstance(x.type, SetType): throw(TypeError, + "Type of `%s` is '%s'. 
Expression `{EXPR}` is not supported" % (ast2src(x.node), type2str(x.type))) + for item in monad.items: check_comparable(x, item) left_sql = x.getsql() if len(left_sql) == 1: if not_in: sql = [ 'NOT_IN', left_sql[0], [ item.getsql()[0] for item in monad.items ] ] @@ -1173,8 +1670,8 @@ def contains(monad, x, not_in=False): sql = sqland([ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) else: sql = sqlor([ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) - return translator.BoolExprMonad(translator, sql) - def getsql(monad, subquery=None): + return BoolExprMonad(sql, nullable=x.nullable or any(item.nullable for item in monad.items)) + def getsql(monad, sqlquery=None): return [ [ 'ROW' ] + [ item.getsql()[0] for item in monad.items ] ] class BufferMixin(MonadMixin): @@ -1187,16 +1684,15 @@ class UuidMixin(MonadMixin): def make_numeric_binop(op, sqlop): def numeric_binop(monad, monad2): - translator = monad.translator - if isinstance(monad2, (translator.AttrSetMonad, translator.NumericSetExprMonad)): - return translator.NumericSetExprMonad(op, sqlop, monad, monad2) + if isinstance(monad2, (AttrSetMonad, NumericSetExprMonad)): + return NumericSetExprMonad(op, sqlop, monad, monad2) if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) result_type, monad, monad2 = coerce_monads(monad, monad2) if result_type is None: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql()[0] right_sql = monad2.getsql()[0] - return translator.NumericExprMonad(translator, result_type, [ sqlop, left_sql, right_sql ]) + return NumericExprMonad(result_type, [ sqlop, left_sql, right_sql ]) numeric_binop.__name__ = sqlop return numeric_binop @@ -1208,56 +1704,81 @@ def mixin_init(monad): __mul__ = make_numeric_binop('*', 'MUL') __truediv__ = make_numeric_binop('/', 'DIV') __floordiv__ = make_numeric_binop('//', 'FLOORDIV') + __mod__ = make_numeric_binop('%', 'MOD') def __pow__(monad, monad2): - translator = monad.translator - if not isinstance(monad2, translator.NumericMixin): + if not isinstance(monad2, NumericMixin): throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), '**')) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return translator.NumericExprMonad(translator, float, [ 'POW', left_sql[0], right_sql[0] ]) + return NumericExprMonad(float, [ 'POW', left_sql[0], right_sql[0] ], + nullable=monad.nullable or monad2.nullable) def __neg__(monad): sql = monad.getsql()[0] - translator = monad.translator - return translator.NumericExprMonad(translator, monad.type, [ 'NEG', sql ]) + return NumericExprMonad(monad.type, [ 'NEG', sql ], nullable=monad.nullable) def abs(monad): sql = monad.getsql()[0] - translator = monad.translator - return translator.NumericExprMonad(translator, monad.type, [ 'ABS', sql ]) + return NumericExprMonad(monad.type, [ 'ABS', sql ], nullable=monad.nullable) def nonzero(monad): translator = monad.translator - return translator.CmpMonad('!=', monad, translator.ConstMonad.new(translator, 0)) + sql = monad.getsql()[0] + if not (translator.dialect == 'PostgreSQL' and monad.type is bool): + sql = [ 'NE', sql, [ 'VALUE', 0 ] ] + return BoolExprMonad(sql, nullable=False) def negate(monad): + sql = monad.getsql()[0] translator = monad.translator - return translator.CmpMonad('==', monad, translator.ConstMonad.new(translator, 0)) + pg_bool = translator.dialect == 'PostgreSQL' and 
monad.type is bool + result_sql = [ 'NOT', sql ] if pg_bool else [ 'EQ', sql, [ 'VALUE', 0 ] ] + if monad.nullable: + if isinstance(monad, AttrMonad): + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + elif pg_bool: + result_sql = [ 'NOT', [ 'COALESCE', sql, [ 'VALUE', True ] ] ] + else: + result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', 0 ] ], [ 'VALUE', 0 ] ] + return BoolExprMonad(result_sql, nullable=False) def numeric_attr_factory(name): def attr_func(monad): sql = [ name, monad.getsql()[0] ] - translator = monad.translator - return translator.NumericExprMonad(translator, int, sql) + return NumericExprMonad(int, sql, nullable=monad.nullable) attr_func.__name__ = name.lower() return attr_func def make_datetime_binop(op, sqlop): def datetime_binop(monad, monad2): - translator = monad.translator if monad2.type != timedelta: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) - expr_monad_cls = translator.DateExprMonad if monad.type is date else translator.DatetimeExprMonad - delta = monad2.value if isinstance(monad2, TimedeltaConstMonad) else monad2.getsql()[0] - return expr_monad_cls(translator, monad.type, [ sqlop, monad.getsql()[0], delta ]) + expr_monad_cls = DateExprMonad if monad.type is date else DatetimeExprMonad + return expr_monad_cls(monad.type, [ sqlop, monad.getsql()[0], monad2.getsql()[0] ], + nullable=monad.nullable or monad2.nullable) datetime_binop.__name__ = sqlop return datetime_binop class DateMixin(MonadMixin): def mixin_init(monad): assert monad.type is date + attr_year = numeric_attr_factory('YEAR') attr_month = numeric_attr_factory('MONTH') attr_day = numeric_attr_factory('DAY') - __add__ = make_datetime_binop('+', 'DATE_ADD') - __sub__ = make_datetime_binop('-', 'DATE_SUB') + + def __add__(monad, other): + if other.type != timedelta: + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '+')) + return DateExprMonad(monad.type, [ 'DATE_ADD', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + + def __sub__(monad, other): + if other.type == timedelta: + return DateExprMonad(monad.type, [ 'DATE_SUB', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + elif other.type == date: + return TimedeltaExprMonad(timedelta, [ 'DATE_DIFF', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '-')) + class TimeMixin(MonadMixin): def mixin_init(monad): @@ -1273,26 +1794,40 @@ def mixin_init(monad): class DatetimeMixin(DateMixin): def mixin_init(monad): assert monad.type is datetime + def call_date(monad): - translator = monad.translator sql = [ 'DATE', monad.getsql()[0] ] - return translator.ExprMonad.new(translator, date, sql) + return ExprMonad.new(date, sql, nullable=monad.nullable) + attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') - __add__ = make_datetime_binop('+', 'DATETIME_ADD') - __sub__ = make_datetime_binop('-', 'DATETIME_SUB') + + def __add__(monad, other): + if other.type != timedelta: + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '+')) + return DatetimeExprMonad(monad.type, [ 'DATETIME_ADD', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + + def __sub__(monad, other): + if other.type == timedelta: + return DatetimeExprMonad(monad.type, [ 'DATETIME_SUB', 
monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + elif other.type == datetime: + return TimedeltaExprMonad(timedelta, [ 'DATETIME_DIFF', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '-')) def make_string_binop(op, sqlop): def string_binop(monad, monad2): - translator = monad.translator if not are_comparable_types(monad.type, monad2.type, sqlop): if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return translator.StringExprMonad(translator, monad.type, [ sqlop, left_sql[0], right_sql[0] ]) + return StringExprMonad(monad.type, [ sqlop, left_sql[0], right_sql[0] ], + nullable=monad.nullable or monad2.nullable) string_binop.__name__ = sqlop return string_binop @@ -1300,8 +1835,7 @@ def make_string_func(sqlop): def func(monad): sql = monad.getsql() assert len(sql) == 1 - translator = monad.translator - return translator.StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) + return StringExprMonad(monad.type, [ sqlop, sql[0] ], nullable=monad.nullable) func.__name__ = sqlop return func @@ -1310,20 +1844,37 @@ def mixin_init(monad): assert issubclass(monad.type, basestring), monad.type __add__ = make_string_binop('+', 'CONCAT') def __getitem__(monad, index): - translator = monad.translator - if isinstance(index, translator.ListMonad): throw(TypeError, "String index must be of 'int' type. Got 'tuple' in {EXPR}") + root_translator = monad.translator.root_translator + dialect = root_translator.database.provider.dialect + + def param_to_const(monad, is_start=True): + if isinstance(monad, ParamMonad): + key = monad.paramkey[0] + if key in root_translator.fixed_param_values: + index_value = root_translator.fixed_param_values[key] + else: + index_value = root_translator.vars[key] + if index_value is None: + index_value = 0 if is_start else -1 + root_translator.fixed_param_values[key] = index_value + return ConstMonad.new(index_value) + return monad + + if isinstance(index, ListMonad): throw(TypeError, "String index must be of 'int' type. 
Got 'tuple' in {EXPR}") elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') start, stop = index.start, index.stop - if isinstance(start, NoneMonad): start = None - if isinstance(stop, NoneMonad): stop = None - if start is None and stop is None: return monad - if isinstance(monad, translator.StringConstMonad) \ - and (start is None or isinstance(start, translator.NumericConstMonad)) \ - and (stop is None or isinstance(stop, translator.NumericConstMonad)): - if start is not None: start = start.value - if stop is not None: stop = stop.value - return translator.ConstMonad.new(translator, monad.value[start:stop]) + start = param_to_const(start, is_start=True) + stop = param_to_const(stop, is_start=False) + start_value = stop_value = None + if start is None: start_value = 0 + if stop_value is None: stop_value = -1 + if isinstance(start, ConstMonad): start_value = start.value + if isinstance(stop, ConstMonad): stop_value = stop.value + if start_value == 0 and stop_value == -1: + return monad + if isinstance(monad, StringConstMonad) and start_value is not None and stop_value is not None: + return ConstMonad.new(monad.value[start_value:stop_value]) if start is not None and start.type is not int: throw(TypeError, "Invalid type of start index (expected 'int', got %r) in string slice {EXPR}" % type2str(start.type)) @@ -1331,57 +1882,64 @@ def __getitem__(monad, index): throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in string slice {EXPR}" % type2str(stop.type)) expr_sql = monad.getsql()[0] - if start is None: start = translator.ConstMonad.new(translator, 0) + start_sql = None if start is None else start.getsql()[0] + stop_sql = None if stop is None else stop.getsql()[0] + sql = [ 'STRING_SLICE', expr_sql, start_sql, stop_sql ] + return StringExprMonad(monad.type, sql, nullable= + monad.nullable or start is not None and start.nullable or stop is not None and stop.nullable) - if isinstance(start, translator.NumericConstMonad): - if start.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') - start_sql = [ 'VALUE', start.value + 1 ] - else: - start_sql = start.getsql()[0] - start_sql = [ 'ADD', start_sql, [ 'VALUE', 1 ] ] - - if stop is None: - len_sql = None - elif isinstance(stop, translator.NumericConstMonad): - if stop.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') - if isinstance(start, translator.NumericConstMonad): - len_sql = [ 'VALUE', stop.value - start.value ] - else: - len_sql = [ 'SUB', [ 'VALUE', stop.value ], start.getsql()[0] ] - else: - stop_sql = stop.getsql()[0] - if isinstance(start, translator.NumericConstMonad): - len_sql = [ 'SUB', stop_sql, [ 'VALUE', start.value ] ] - else: - len_sql = [ 'SUB', stop_sql, start.getsql()[0] ] - - sql = [ 'SUBSTR', expr_sql, start_sql, len_sql ] - return translator.StringExprMonad(translator, monad.type, sql) - - if isinstance(monad, translator.StringConstMonad) and isinstance(index, translator.NumericConstMonad): - return translator.ConstMonad.new(translator, monad.value[index.value]) + index = param_to_const(index) + if isinstance(monad, StringConstMonad) and isinstance(index, NumericConstMonad): + return ConstMonad.new(monad.value[index.value]) if index.type is not int: throw(TypeError, 'String indices must be integers. 
Got %r in expression {EXPR}' % type2str(index.type)) expr_sql = monad.getsql()[0] - if isinstance(index, translator.NumericConstMonad): + + if isinstance(index, NumericConstMonad): value = index.value - if value >= 0: value += 1 - index_sql = [ 'VALUE', value ] + if dialect == 'PostgreSQL' and value < 0: + index_sql = [ 'LENGTH', expr_sql ] + if value < -1: + index_sql = [ 'SUB', index_sql, [ 'VALUE', -(value + 1) ] ] + else: + if value >= 0: value += 1 + index_sql = [ 'VALUE', value ] else: inner_sql = index.getsql()[0] - index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] + then = ['ADD', inner_sql, ['VALUE', 1]] + else_ = [ 'ADD', ['LENGTH', expr_sql], then ] if dialect == 'PostgreSQL' else inner_sql + index_sql = [ 'IF', [ 'GE', inner_sql, [ 'VALUE', 0 ] ], then, else_ ] + sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] - return translator.StringExprMonad(translator, monad.type, sql) + return StringExprMonad(monad.type, sql, nullable=monad.nullable) + def negate(monad): + sql = monad.getsql()[0] + translator = monad.translator + if translator.dialect == 'Oracle': + result_sql = [ 'IS_NULL', sql ] + else: + result_sql = [ 'EQ', sql, [ 'VALUE', '' ] ] + if monad.nullable: + if isinstance(monad, AttrMonad): + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + else: + result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', '' ] ], [ 'VALUE', '' ]] + result = BoolExprMonad(result_sql, nullable=False) + result.aggregated = monad.aggregated + return result def nonzero(monad): sql = monad.getsql()[0] translator = monad.translator - result = translator.BoolExprMonad(translator, [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) + if translator.dialect == 'Oracle': + result_sql = [ 'IS_NOT_NULL', sql ] + else: + result_sql = [ 'NE', sql, [ 'VALUE', '' ] ] + result = BoolExprMonad(result_sql, nullable=False) result.aggregated = monad.aggregated return result def len(monad): sql = monad.getsql()[0] - translator = monad.translator - return translator.NumericExprMonad(translator, int, [ 'LENGTH', sql ]) + return NumericExprMonad(int, [ 'LENGTH', sql ]) def contains(monad, item, not_in=False): check_comparable(item, monad, 'LIKE') return monad._like(item, before='%', after='%', not_like=not_in) @@ -1402,7 +1960,7 @@ def call_endswith(monad, arg): def _like(monad, item, before=None, after=None, not_like=False): escape = False translator = monad.translator - if isinstance(item, translator.StringConstMonad): + if isinstance(item, StringConstMonad): value = item.value if '%' in value or '_' in value: escape = True @@ -1419,11 +1977,15 @@ def _like(monad, item, before=None, after=None, not_like=False): if before and after: item_sql = [ 'CONCAT', [ 'VALUE', before ], item_sql, [ 'VALUE', after ] ] elif before: item_sql = [ 'CONCAT', [ 'VALUE', before ], item_sql ] elif after: item_sql = [ 'CONCAT', item_sql, [ 'VALUE', after ] ] - sql = [ 'NOT_LIKE' if not_like else 'LIKE', monad.getsql()[0], item_sql ] - if escape: sql.append([ 'VALUE', '!' ]) - return translator.BoolExprMonad(translator, sql) + sql = monad.getsql()[0] + if not_like and monad.nullable and not isinstance(monad, AttrMonad) and translator.dialect != 'Oracle': + sql = [ 'COALESCE', sql, [ 'VALUE', '' ] ] + result_sql = [ 'NOT_LIKE' if not_like else 'LIKE', sql, item_sql ] + if escape: result_sql.append([ 'VALUE', '!' 
]) + if not_like and monad.nullable and (isinstance(monad, AttrMonad) or translator.dialect == 'Oracle'): + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + return BoolExprMonad(result_sql, nullable=not_like) def strip(monad, chars, strip_type): - translator = monad.translator if chars is not None and not are_comparable_types(monad.type, chars.type, None): if chars.type == 'METHOD': raise_forgot_parentheses(chars) throw(TypeError, "'chars' argument must be of %r type in {EXPR}, got: %r" @@ -1431,7 +1993,7 @@ def strip(monad, chars, strip_type): parent_sql = monad.getsql()[0] sql = [ strip_type, parent_sql ] if chars is not None: sql.append(chars.getsql()[0]) - return translator.StringExprMonad(translator, monad.type, sql) + return StringExprMonad(monad.type, sql, nullable=monad.nullable) def call_strip(monad, chars=None): return monad.strip(chars, 'TRIM') def call_lstrip(monad, chars=None): @@ -1439,34 +2001,141 @@ def call_lstrip(monad, chars=None): def call_rstrip(monad, chars=None): return monad.strip(chars, 'RTRIM') +class JsonMixin(object): + disable_distinct = True # at least in Oracle we cannot use DISTINCT with JSON column + disable_ordering = True # at least in Oracle we cannot use ORDER BY with JSON column + + def mixin_init(monad): + assert monad.type is Json, monad.type + def get_path(monad): + return monad, [] + def __getitem__(monad, key): + return JsonItemMonad(monad, key) + def contains(monad, key, not_in=False): + translator = monad.translator + if isinstance(key, ParamMonad): + if translator.dialect == 'Oracle': throw(TypeError, + 'For `key in JSON` operation %s supports literal key values only, ' + 'parameters are not allowed: {EXPR}' % translator.dialect) + elif not isinstance(key, StringConstMonad): raise NotImplementedError + base_monad, path = monad.get_path() + base_sql = base_monad.getsql()[0] + key_sql = key.getsql()[0] + sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] + if not_in: sql = [ 'NOT', sql ] + return BoolExprMonad(sql) + def __or__(monad, other): + if not isinstance(other, JsonMixin): + raise TypeError('Should be JSON: %s' % ast2src(other.node)) + left_sql = monad.getsql()[0] + right_sql = other.getsql()[0] + sql = [ 'JSON_CONCAT', left_sql, right_sql ] + return JsonExprMonad(Json, sql) + def len(monad): + sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] + return NumericExprMonad(int, sql) + def cast_from_json(monad, type): + if type in (Json, NoneType): return monad + throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') + def nonzero(monad): + return BoolExprMonad([ 'JSON_NONZERO', monad.getsql()[0] ]) + +class ArrayMixin(MonadMixin): + def contains(monad, key, not_in=False): + if key.type is monad.type.item_type: + sql = 'ARRAY_CONTAINS', key.getsql()[0], not_in, monad.getsql()[0] + return BoolExprMonad(sql) + if isinstance(key, ListMonad): + if not key.items: + if not_in: + return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False) + else: + return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], nullable=False) + sql = [ 'MAKE_ARRAY' ] + sql.extend(item.getsql()[0] for item in key.items) + sql = 'ARRAY_SUBSET', sql, not_in, monad.getsql()[0] + return BoolExprMonad(sql) + elif isinstance(key, ArrayParamMonad): + sql = 'ARRAY_SUBSET', key.getsql()[0], not_in, monad.getsql()[0] + return BoolExprMonad(sql) + throw(TypeError, 'Cannot search for %s in %s: {EXPR}' % + (type2str(key.type), type2str(monad.type))) + + def len(monad): + sql = ['ARRAY_LENGTH', monad.getsql()[0]] + return 
NumericExprMonad(int, sql) + + def nonzero(monad): + return BoolExprMonad(['GT', ['ARRAY_LENGTH', monad.getsql()[0]], ['VALUE', 0]]) + + def _index(monad, index, from_one, plus_one): + if isinstance(index, NumericConstMonad): + expr_sql = monad.getsql()[0] + index_sql = index.getsql()[0] + value = index_sql[1] + if value >= 0: + index_sql = ['VALUE', value + int(from_one and plus_one)] + else: + index_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value + int(from_one and plus_one))]] + return index_sql + elif isinstance(index, NumericMixin): + expr_sql = monad.getsql()[0] + index0 = index.getsql()[0] + index1 = ['ADD', index0, ['VALUE', 1]] if from_one and plus_one else index0 + index_sql = ['CASE', None, [[['GE', index0, ['VALUE', 0]], index1]], + ['ADD', ['ARRAY_LENGTH', expr_sql], index1]] + return index_sql + + def __getitem__(monad, index): + dialect = monad.translator.database.provider.dialect + expr_sql = monad.getsql()[0] + from_one = dialect != 'SQLite' + if isinstance(index, NumericMixin): + index_sql = monad._index(index, from_one, plus_one=True) + sql = ['ARRAY_INDEX', expr_sql, index_sql] + return ExprMonad.new(monad.type.item_type, sql) + elif isinstance(index, slice): + if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') + start_sql = monad._index(index.start, from_one, plus_one=True) + stop_sql = monad._index(index.stop, from_one, plus_one=False) + sql = ['ARRAY_SLICE', expr_sql, start_sql, stop_sql] + return ExprMonad.new(monad.type, sql) + + class ObjectMixin(MonadMixin): def mixin_init(monad): assert isinstance(monad.type, EntityMeta) def negate(monad): - translator = monad.translator - return translator.CmpMonad('is', monad, translator.NoneMonad(translator)) + return CmpMonad('is', monad, NoneMonad()) def nonzero(monad): - translator = monad.translator - return translator.CmpMonad('is not', monad, translator.NoneMonad(translator)) - def getattr(monad, name): - translator = monad.translator + return CmpMonad('is not', monad, NoneMonad()) + def getattr(monad, attrname): entity = monad.type - attr = entity._adict_.get(name) or entity._subclass_adict_.get(name) - if attr is None: throw(AttributeError, - 'Entity %s does not have attribute %s: {EXPR}' % (entity.__name__, name)) + attr = entity._adict_.get(attrname) or entity._subclass_adict_.get(attrname) + if attr is None: + if hasattr(entity, attrname): + attr = getattr(entity, attrname, None) + if isinstance(attr, property): + new_monad = HybridMethodMonad(monad, attrname, attr.fget) + return new_monad() + if callable(attr): + func = getattr(attr, '__func__') if PY2 else attr + if func is not None: return HybridMethodMonad(monad, attrname, func) + throw(NotImplementedError, '{EXPR} cannot be translated to SQL') + throw(AttributeError, 'Entity %s does not have attribute %s: {EXPR}' % (entity.__name__, attrname)) if hasattr(monad, 'tableref'): monad.tableref.used_attrs.add(attr) if not attr.is_collection: - return translator.AttrMonad.new(monad, attr) + return AttrMonad.new(monad, attr) else: - return translator.AttrSetMonad(monad, attr) + return AttrSetMonad(monad, attr) def requires_distinct(monad, joined=False): return monad.attr.reverse.is_collection or monad.parent.requires_distinct(joined) # parent ??? 
class ObjectIterMonad(ObjectMixin, Monad): - def __init__(monad, translator, tableref, entity): - Monad.__init__(monad, translator, entity) + def __init__(monad, tableref, entity): + Monad.__init__(monad, entity) monad.tableref = tableref - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): entity = monad.type alias, pk_columns = monad.tableref.make_join(pk_only=True) return [ [ 'COLUMN', alias, column ] for column in pk_columns ] @@ -1476,30 +2145,31 @@ def requires_distinct(monad, joined=False): class AttrMonad(Monad): @staticmethod def new(parent, attr, *args, **kwargs): - translator = parent.translator - type = normalize_type(attr.py_type) - if type in numeric_types: cls = translator.NumericAttrMonad - elif type is unicode: cls = translator.StringAttrMonad - elif type is date: cls = translator.DateAttrMonad - elif type is time: cls = translator.TimeAttrMonad - elif type is timedelta: cls = translator.TimedeltaAttrMonad - elif type is datetime: cls = translator.DatetimeAttrMonad - elif type is buffer: cls = translator.BufferAttrMonad - elif type is UUID: cls = translator.UuidAttrMonad - elif isinstance(type, EntityMeta): cls = translator.ObjectAttrMonad - else: throw(NotImplementedError, type) # pragma: no cover + t = normalize_type(attr.py_type) + if t in numeric_types: cls = NumericAttrMonad + elif t is unicode: cls = StringAttrMonad + elif t is date: cls = DateAttrMonad + elif t is time: cls = TimeAttrMonad + elif t is timedelta: cls = TimedeltaAttrMonad + elif t is datetime: cls = DatetimeAttrMonad + elif t is buffer: cls = BufferAttrMonad + elif t is UUID: cls = UuidAttrMonad + elif t is Json: cls = JsonAttrMonad + elif isinstance(t, EntityMeta): cls = ObjectAttrMonad + elif isinstance(t, type) and issubclass(t, Array): cls = ArrayAttrMonad + else: throw(NotImplementedError, t) # pragma: no cover return cls(parent, attr, *args, **kwargs) def __new__(cls, *args): if cls is AttrMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, parent, attr): assert monad.__class__ is not AttrMonad - translator = parent.translator attr_type = normalize_type(attr.py_type) - Monad.__init__(monad, parent.translator, attr_type) + Monad.__init__(monad, attr_type) monad.parent = parent monad.attr = attr - def getsql(monad, subquery=None): + monad.nullable = attr.nullable + def getsql(monad, sqlquery=None): parent = monad.parent attr = monad.attr entity = attr.entity @@ -1512,9 +2182,9 @@ def getsql(monad, subquery=None): else: columns = parent_columns elif not attr.columns: assert isinstance(monad, ObjectAttrMonad) - subquery = monad.translator.subquery - monad.translator.left_join = subquery.left_join = True - subquery.from_ast[0] = 'LEFT_JOIN' + sqlquery = monad.translator.sqlquery + monad.translator.left_join = sqlquery.left_join = True + sqlquery.from_ast[0] = 'LEFT_JOIN' alias, columns = monad.tableref.make_join() else: columns = attr.columns return [ [ 'COLUMN', alias, column ] for column in columns ] @@ -1526,59 +2196,64 @@ def __init__(monad, parent, attr): parent_monad = monad.parent entity = monad.type name_path = '-'.join((parent_monad.tableref.name_path, attr.name)) - monad.tableref = translator.subquery.get_tableref(name_path) + monad.tableref = translator.sqlquery.get_tableref(name_path) if monad.tableref is None: - parent_subquery = parent_monad.tableref.subquery - monad.tableref = parent_subquery.add_tableref(name_path, parent_monad.tableref, attr) + parent_sqlquery = parent_monad.tableref.sqlquery + monad.tableref = 
parent_sqlquery.add_tableref(name_path, parent_monad.tableref, attr) -class NumericAttrMonad(NumericMixin, AttrMonad): pass class StringAttrMonad(StringMixin, AttrMonad): pass +class NumericAttrMonad(NumericMixin, AttrMonad): pass class DateAttrMonad(DateMixin, AttrMonad): pass class TimeAttrMonad(TimeMixin, AttrMonad): pass class TimedeltaAttrMonad(TimedeltaMixin, AttrMonad): pass class DatetimeAttrMonad(DatetimeMixin, AttrMonad): pass class BufferAttrMonad(BufferMixin, AttrMonad): pass class UuidAttrMonad(UuidMixin, AttrMonad): pass +class JsonAttrMonad(JsonMixin, AttrMonad): pass +class ArrayAttrMonad(ArrayMixin, AttrMonad): pass class ParamMonad(Monad): @staticmethod - def new(translator, type, paramkey): - type = normalize_type(type) - if type in numeric_types: cls = translator.NumericParamMonad - elif type is unicode: cls = translator.StringParamMonad - elif type is date: cls = translator.DateParamMonad - elif type is time: cls = translator.TimeParamMonad - elif type is timedelta: cls = translator.TimedeltaParamMonad - elif type is datetime: cls = translator.DatetimeParamMonad - elif type is buffer: cls = translator.BufferParamMonad - elif type is UUID: cls = translator.UuidParamMonad - elif isinstance(type, EntityMeta): cls = translator.ObjectParamMonad - else: throw(NotImplementedError, type) # pragma: no cover - result = cls(translator, type, paramkey) + def new(t, paramkey): + t = normalize_type(t) + if t in numeric_types: cls = NumericParamMonad + elif t is unicode: cls = StringParamMonad + elif t is date: cls = DateParamMonad + elif t is time: cls = TimeParamMonad + elif t is timedelta: cls = TimedeltaParamMonad + elif t is datetime: cls = DatetimeParamMonad + elif t is buffer: cls = BufferParamMonad + elif t is UUID: cls = UuidParamMonad + elif t is Json: cls = JsonParamMonad + elif isinstance(t, type) and issubclass(t, Array): cls = ArrayParamMonad + elif isinstance(t, EntityMeta): cls = ObjectParamMonad + else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (t,)) + result = cls(t, paramkey) result.aggregated = False return result - def __new__(cls, *args): + def __new__(cls, *args, **kwargs): if cls is ParamMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, type, paramkey): - type = normalize_type(type) - Monad.__init__(monad, translator, type) + def __init__(monad, t, paramkey): + t = normalize_type(t) + Monad.__init__(monad, t, nullable=False) monad.paramkey = paramkey - if not isinstance(type, EntityMeta): - provider = translator.database.provider - monad.converter = provider.get_converter_by_py_type(type) + if not isinstance(t, EntityMeta): + provider = monad.translator.database.provider + monad.converter = provider.get_converter_by_py_type(t) else: monad.converter = None - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'PARAM', monad.paramkey, monad.converter ] ] class ObjectParamMonad(ObjectMixin, ParamMonad): - def __init__(monad, translator, entity, paramkey): - assert translator.database is entity._database_ - ParamMonad.__init__(monad, translator, entity, paramkey) + def __init__(monad, entity, paramkey): + ParamMonad.__init__(monad, entity, paramkey) + if monad.translator.database is not entity._database_: + assert monad.translator.database is entity._database_, (paramkey, monad.translator.database, entity._database_) varkey, i, j = paramkey assert j is None monad.params = tuple((varkey, i, j) for j in xrange(len(entity._pk_converters_))) - 
def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): entity = monad.type assert len(monad.params) == len(entity._pk_converters_) return [ [ 'PARAM', param, converter ] for param, converter in izip(monad.params, entity._pk_converters_) ] @@ -1589,76 +2264,147 @@ class StringParamMonad(StringMixin, ParamMonad): pass class NumericParamMonad(NumericMixin, ParamMonad): pass class DateParamMonad(DateMixin, ParamMonad): pass class TimeParamMonad(TimeMixin, ParamMonad): pass -class TimedeltaParamMonad(TimeMixin, ParamMonad): pass +class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass +class ArrayParamMonad(ArrayMixin, ParamMonad): + def __init__(monad, t, paramkey, list_monad=None): + ParamMonad.__init__(monad, t, paramkey) + monad.list_monad = list_monad + def contains(monad, key, not_in=False): + if key.type is monad.type.item_type: + return monad.list_monad.contains(key, not_in) + return ArrayMixin.contains(monad, key, not_in) + +class JsonParamMonad(JsonMixin, ParamMonad): + def getsql(monad, sqlquery=None): + return [ [ 'JSON_PARAM', ParamMonad.getsql(monad)[0] ] ] + class ExprMonad(Monad): @staticmethod - def new(translator, type, sql): - if type in numeric_types: cls = translator.NumericExprMonad - elif type is unicode: cls = translator.StringExprMonad - elif type is date: cls = translator.DateExprMonad - elif type is time: cls = translator.TimeExprMonad - elif type is timedelta: cls = translator.TimedeltaExprMonad - elif type is datetime: cls = translator.DatetimeExprMonad - else: throw(NotImplementedError, type) # pragma: no cover - return cls(translator, type, sql) - def __new__(cls, *args): + def new(t, sql, nullable=True): + if t in numeric_types: cls = NumericExprMonad + elif t is unicode: cls = StringExprMonad + elif t is date: cls = DateExprMonad + elif t is time: cls = TimeExprMonad + elif t is timedelta: cls = TimedeltaExprMonad + elif t is datetime: cls = DatetimeExprMonad + elif t is Json: cls = JsonExprMonad + elif isinstance(t, EntityMeta): cls = ObjectExprMonad + elif isinstance(t, type) and issubclass(t, Array): cls = ArrayExprMonad + else: throw(NotImplementedError, t) # pragma: no cover + return cls(t, sql, nullable=nullable) + def __new__(cls, *args, **kwargs): if cls is ExprMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, type, sql): - Monad.__init__(monad, translator, type) + def __init__(monad, type, sql, nullable=True): + Monad.__init__(monad, type, nullable=nullable) monad.sql = sql - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ monad.sql ] +class ObjectExprMonad(ObjectMixin, ExprMonad): + def getsql(monad, sqlquery=None): + return monad.sql + class StringExprMonad(StringMixin, ExprMonad): pass class NumericExprMonad(NumericMixin, ExprMonad): pass class DateExprMonad(DateMixin, ExprMonad): pass class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass +class JsonExprMonad(JsonMixin, ExprMonad): pass +class ArrayExprMonad(ArrayMixin, ExprMonad): pass + +class JsonItemMonad(JsonMixin, Monad): + def __init__(monad, parent, key): + assert isinstance(parent, JsonMixin), parent + Monad.__init__(monad, Json) + monad.parent = parent + if isinstance(key, slice): + if key != slice(None, None, 
None): throw(NotImplementedError) + monad.key_ast = [ 'VALUE', key ] + elif isinstance(key, (ParamMonad, StringConstMonad, NumericConstMonad, EllipsisMonad)): + monad.key_ast = key.getsql()[0] + else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) + translator = monad.translator + if isinstance(key, (slice, EllipsisMonad)) and not translator.json_path_wildcard_syntax: + throw(TranslationError, '%s does not support wildcards in JSON path: {EXPR}' % translator.dialect) + def get_path(monad): + path = [] + while isinstance(monad, JsonItemMonad): + path.append(monad.key_ast) + monad = monad.parent + path.reverse() + return monad, path + def to_int(monad): + return monad.cast_from_json(int) + def to_str(monad): + return monad.cast_from_json(unicode) + def to_real(monad): + return monad.cast_from_json(float) + def cast_from_json(monad, type): + translator = monad.translator + if issubclass(type, Json): + if not translator.json_values_are_comparable: throw(TranslationError, + '%s does not support comparison of json structures: {EXPR}' % translator.dialect) + return monad + base_monad, path = monad.get_path() + sql = [ 'JSON_VALUE', base_monad.getsql()[0], path, type ] + return ExprMonad.new(Json if type is NoneType else type, sql) + def getsql(monad): + base_monad, path = monad.get_path() + base_sql = base_monad.getsql()[0] + translator = monad.translator + if translator.inside_order_by and translator.dialect == 'SQLite': + return [ [ 'JSON_VALUE', base_sql, path, None ] ] + return [ [ 'JSON_QUERY', base_sql, path ] ] class ConstMonad(Monad): @staticmethod - def new(translator, value): - value_type = get_normalized_type_of(value) - if value_type in numeric_types: cls = translator.NumericConstMonad - elif value_type is unicode: cls = translator.StringConstMonad - elif value_type is date: cls = translator.DateConstMonad - elif value_type is time: cls = translator.TimeConstMonad - elif value_type is timedelta: cls = translator.TimedeltaConstMonad - elif value_type is datetime: cls = translator.DatetimeConstMonad - elif value_type is NoneType: cls = translator.NoneMonad - elif value_type is buffer: cls = translator.BufferConstMonad + def new(value): + value_type, value = normalize(value) + if value_type in numeric_types: cls = NumericConstMonad + elif value_type is unicode: cls = StringConstMonad + elif value_type is date: cls = DateConstMonad + elif value_type is time: cls = TimeConstMonad + elif value_type is timedelta: cls = TimedeltaConstMonad + elif value_type is datetime: cls = DatetimeConstMonad + elif value_type is NoneType: cls = NoneMonad + elif value_type is buffer: cls = BufferConstMonad + elif value_type is Json: cls = JsonConstMonad + elif issubclass(value_type, type(Ellipsis)): cls = EllipsisMonad else: throw(NotImplementedError, value_type) # pragma: no cover - result = cls(translator, value) + result = cls(value) result.aggregated = False return result def __new__(cls, *args): if cls is ConstMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, value): - value_type = get_normalized_type_of(value) - Monad.__init__(monad, translator, value_type) + def __init__(monad, value): + value_type, value = normalize(value) + Monad.__init__(monad, value_type, nullable=value_type is NoneType) monad.value = value - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'VALUE', monad.value ] ] class NoneMonad(ConstMonad): type = NoneType - def __init__(monad, translator, value=None): + def 
__init__(monad, value=None): assert value is None - ConstMonad.__init__(monad, translator, value) + ConstMonad.__init__(monad, value) -class BufferConstMonad(BufferMixin, ConstMonad): pass +class EllipsisMonad(ConstMonad): + pass class StringConstMonad(StringMixin, ConstMonad): def len(monad): - return monad.translator.ConstMonad.new(monad.translator, len(monad.value)) + return ConstMonad.new(len(monad.value)) +class JsonConstMonad(JsonMixin, ConstMonad): pass +class BufferConstMonad(BufferMixin, ConstMonad): pass class NumericConstMonad(NumericMixin, ConstMonad): pass class DateConstMonad(DateMixin, ConstMonad): pass class TimeConstMonad(TimeMixin, ConstMonad): pass @@ -1666,22 +2412,21 @@ class TimedeltaConstMonad(TimedeltaMixin, ConstMonad): pass class DatetimeConstMonad(DatetimeMixin, ConstMonad): pass class BoolMonad(Monad): - def __init__(monad, translator): - monad.translator = translator - monad.type = bool + def __init__(monad, nullable=True): + Monad.__init__(monad, bool, nullable=nullable) + def nonzero(monad): + return monad sql_negation = { 'IN' : 'NOT_IN', 'EXISTS' : 'NOT_EXISTS', 'LIKE' : 'NOT_LIKE', 'BETWEEN' : 'NOT_BETWEEN', 'IS_NULL' : 'IS_NOT_NULL' } sql_negation.update((value, key) for key, value in items_list(sql_negation)) class BoolExprMonad(BoolMonad): - def __init__(monad, translator, sql): - monad.translator = translator - monad.type = bool + def __init__(monad, sql, nullable=True): + BoolMonad.__init__(monad, nullable=nullable) monad.sql = sql - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ monad.sql ] def negate(monad): - translator = monad.translator sql = monad.sql sqlop = sql[0] negated_op = sql_negation.get(sqlop) @@ -1690,8 +2435,8 @@ def negate(monad): elif negated_op == 'NOT': assert len(sql) == 2 negated_sql = sql[1] - else: return translator.NotMonad(translator, sql) - return translator.BoolExprMonad(translator, negated_sql) + else: return NotMonad(monad) + return BoolExprMonad(negated_sql, nullable=monad.nullable) cmp_ops = { '>=' : 'GE', '>' : 'GT', '<=' : 'LE', '<' : 'LT' } @@ -1699,8 +2444,9 @@ def negate(monad): cmp_negate.update((b, a) for a, b in items_list(cmp_negate)) class CmpMonad(BoolMonad): + EQ = 'EQ' + NE = 'NE' def __init__(monad, op, left, right): - translator = left.translator if op == '<>': op = '!=' if left.type is NoneType: assert right.type is not NoneType @@ -1711,17 +2457,22 @@ def __init__(monad, op, left, right): elif op == 'is': op = '==' elif op == 'is not': op = '!=' check_comparable(left, right, op) - result_type, left, right = coerce_monads(left, right) - BoolMonad.__init__(monad, translator) + result_type, left, right = coerce_monads(left, right, for_comparison=True) + BoolMonad.__init__(monad, nullable=left.nullable or right.nullable) monad.op = op + monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False) + + if isinstance(left, JsonMixin): + left = left.cast_from_json(right.type) + if isinstance(right, JsonMixin): + right = right.cast_from_json(left.type) + monad.left = left monad.right = right - monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False) def negate(monad): - return monad.translator.CmpMonad(cmp_negate[monad.op], monad.left, monad.right) - def getsql(monad, subquery=None): + return CmpMonad(cmp_negate[monad.op], monad.left, monad.right) + def getsql(monad, sqlquery=None): op = monad.op - sql = [] left_sql = monad.left.getsql() if op == 'is': return [ sqland([ [ 'IS_NULL', item ] for item in left_sql ]) 
] @@ -1740,30 +2491,30 @@ def getsql(monad, subquery=None): if monad.translator.row_value_syntax: return [ [ cmp_ops[op], [ 'ROW' ] + left_sql, [ 'ROW' ] + right_sql ] ] clauses = [] - for i in xrange(1, size): - clauses.append(sqland([ [ 'EQ', left_sql[j], right_sql[j] ] for j in xrange(1, i) ] - + [ [ cmp_ops[op[0] if i < size - 1 else op], left_sql[i], right_sql[i] ] ])) + for i in xrange(size): + clause = [ [ monad.EQ, left_sql[j], right_sql[j] ] for j in range(i) ] + clause.append([ cmp_ops[op], left_sql[i], right_sql[i] ]) + clauses.append(sqland(clause)) return [ sqlor(clauses) ] if op == '==': - return [ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, right_sql) ]) ] + return [ sqland([ [ monad.EQ, a, b ] for a, b in izip(left_sql, right_sql) ]) ] if op == '!=': - return [ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, right_sql) ]) ] + return [ sqlor([ [ monad.NE, a, b ] for a, b in izip(left_sql, right_sql) ]) ] assert False, op # pragma: no cover class LogicalBinOpMonad(BoolMonad): def __init__(monad, operands): assert len(operands) >= 2 items = [] - translator = operands[0].translator - monad.translator = translator for operand in operands: if operand.type is not bool: items.append(operand.nonzero()) - elif isinstance(operand, translator.LogicalBinOpMonad) and monad.binop == operand.binop: + elif isinstance(operand, LogicalBinOpMonad) and monad.binop == operand.binop: items.extend(operand.operands) else: items.append(operand) - BoolMonad.__init__(monad, items[0].translator) + nullable = any(item.nullable for item in items) + BoolMonad.__init__(monad, nullable=nullable) monad.operands = items - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): result = [ monad.binop ] for operand in monad.operands: operand_sql = operand.getsql() @@ -1780,17 +2531,71 @@ class OrMonad(LogicalBinOpMonad): class NotMonad(BoolMonad): def __init__(monad, operand): if operand.type is not bool: operand = operand.nonzero() - BoolMonad.__init__(monad, operand.translator) + BoolMonad.__init__(monad, nullable=operand.nullable) monad.operand = operand def negate(monad): return monad.operand - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'NOT', monad.operand.getsql()[0] ] ] -class ErrorSpecialFuncMonad(Monad): - def __init__(monad, translator, func): - Monad.__init__(monad, translator, func) - monad.func = func +class HybridFuncMonad(Monad): + def __init__(monad, func_type, func_name, *params): + Monad.__init__(monad, func_type) + monad.func = func_type.func + monad.func_name = func_name + monad.params = params + def __call__(monad, *args, **kwargs): + translator = monad.translator + name_mapping = inspect.getcallargs(monad.func, *(monad.params + args), **kwargs) + + func = monad.func + if PY2 and isinstance(func, types.UnboundMethodType): + func = func.im_func + func_id = id(func) + try: + func_ast, external_names, cells = decompile(func) + except DecompileError: + throw(TranslationError, '%s(...) 
is too complex to decompile' % ast2src(monad.node)) + + func_ast, func_extractors = create_extractors( + func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) + + root_translator = translator.root_translator + if func not in root_translator.func_extractors_map: + func_vars, func_vartypes = extract_vars(func_id, translator.filter_num, func_extractors, func.__globals__, {}, cells) + translator.database.provider.normalize_vars(func_vars, func_vartypes) + if func.__closure__: + translator.can_be_cached = False + if func_extractors: + root_translator.func_extractors_map[func] = func_extractors + root_translator.func_vartypes.update(func_vartypes) + root_translator.vartypes.update(func_vartypes) + root_translator.vars.update(func_vars) + + stack = translator.namespace_stack + stack.append(name_mapping) + func_ast = copy_ast(func_ast) + try: + prev_code_key = translator.code_key + translator.code_key = func_id + try: + translator.dispatch(func_ast) + finally: + translator.code_key = prev_code_key + except Exception as e: + if len(e.args) == 1 and isinstance(e.args[0], basestring): + msg = e.args[0] + ' (inside %s)' % (monad.func_name) + e.args = (msg,) + raise + stack.pop() + return func_ast.monad + +class HybridMethodMonad(HybridFuncMonad): + def __init__(monad, parent, attrname, func): + entity = parent.type + assert isinstance(entity, EntityMeta) + func_name = '%s.%s' % (entity.__name__, attrname) + HybridFuncMonad.__init__(monad, FuncType(func), func_name, parent) registered_functions = SQLTranslator.registered_functions = {} @@ -1806,122 +2611,205 @@ def __new__(meta, cls_name, bases, cls_dict): class FuncMonad(with_metaclass(FuncMonadMeta, Monad)): def __call__(monad, *args, **kwargs): - translator = monad.translator for arg in args: - assert isinstance(arg, translator.Monad) + assert isinstance(arg, Monad) for value in kwargs.values(): - assert isinstance(value, translator.Monad) + assert isinstance(value, Monad) try: return monad.call(*args, **kwargs) except TypeError as exc: reraise_improved_typeerror(exc, 'call', monad.type.__name__) +def get_classes(classinfo): + if isinstance(classinfo, EntityMonad): + yield classinfo.type.item_type + elif isinstance(classinfo, ListMonad): + for item in classinfo.items: + for type in get_classes(item): + yield type + else: throw(TypeError, ast2src(classinfo.node)) + +class FuncIsinstanceMonad(FuncMonad): + func = isinstance + def call(monad, obj, classinfo): + if not isinstance(obj, ObjectMixin): throw(ValueError, + 'Inside a query, isinstance first argument should be of entity type. 
Got: %s' % ast2src(obj.node)) + entity = obj.type + classes = list(get_classes(classinfo)) + subclasses = set() + for cls in classes: + if entity._root_ is cls._root_: + subclasses.add(cls) + subclasses.update(cls._subclasses_) + if entity in subclasses: + return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], nullable=False) + + subclasses.intersection_update(entity._subclasses_) + if not subclasses: + return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False) + + discr_attr = entity._discriminator_attr_ + assert discr_attr is not None + discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in subclasses ] + alias, pk_columns = obj.tableref.make_join(pk_only=True) + sql = [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] + return BoolExprMonad(sql, nullable=False) + + class FuncBufferMonad(FuncMonad): func = buffer def call(monad, source, encoding=None, errors=None): - translator = monad.translator - if not isinstance(source, translator.StringConstMonad): throw(TypeError) + if not isinstance(source, StringConstMonad): throw(TypeError) source = source.value if encoding is not None: - if not isinstance(encoding, translator.StringConstMonad): throw(TypeError) + if not isinstance(encoding, StringConstMonad): throw(TypeError) encoding = encoding.value if errors is not None: - if not isinstance(errors, translator.StringConstMonad): throw(TypeError) + if not isinstance(errors, StringConstMonad): throw(TypeError) errors = errors.value if PY2: if encoding and errors: source = source.encode(encoding, errors) elif encoding: source = source.encode(encoding) - return translator.ConstMonad.new(translator, buffer(source)) + return ConstMonad.new(buffer(source)) else: if encoding and errors: value = buffer(source, encoding, errors) elif encoding: value = buffer(source, encoding) else: value = buffer(source) - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(value) + +class FuncBoolMonad(FuncMonad): + func = bool + def call(monad, x): + return x.nonzero() + +class FuncIntMonad(FuncMonad): + func = int + def call(monad, x): + return x.to_int() + +class FuncStrMonad(FuncMonad): + func = str + def call(monad, x): + return x.to_str() + +class FuncFloatMonad(FuncMonad): + func = float + def call(monad, x): + return x.to_real() class FuncDecimalMonad(FuncMonad): func = Decimal def call(monad, x): - translator = monad.translator - if not isinstance(x, translator.StringConstMonad): throw(TypeError) - return translator.ConstMonad.new(translator, Decimal(x.value)) + if not isinstance(x, StringConstMonad): throw(TypeError) + return ConstMonad.new(Decimal(x.value)) class FuncDateMonad(FuncMonad): func = date def call(monad, year, month, day): - translator = monad.translator for arg, name in izip((year, month, day), ('year', 'month', 'day')): - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of date(year, month, day) function must be of 'int' type. 
" "Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) - return translator.ConstMonad.new(translator, date(year.value, month.value, day.value)) + return ConstMonad.new(date(year.value, month.value, day.value)) def call_today(monad): - translator = monad.translator - return translator.DateExprMonad(translator, date, [ 'TODAY' ]) + return DateExprMonad(date, [ 'TODAY' ], nullable=monad.nullable) class FuncTimeMonad(FuncMonad): func = time def call(monad, *args): - translator = monad.translator for arg, name in izip(args, ('hour', 'minute', 'second', 'microsecond')): - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of time(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) - return translator.ConstMonad.new(translator, time(*tuple(arg.value for arg in args))) + return ConstMonad.new(time(*tuple(arg.value for arg in args))) class FuncTimedeltaMonad(FuncMonad): func = timedelta def call(monad, days=None, seconds=None, microseconds=None, milliseconds=None, minutes=None, hours=None, weeks=None): - translator = monad.translator args = days, seconds, microseconds, milliseconds, minutes, hours, weeks for arg, name in izip(args, ('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks')): if arg is None: continue - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of timedelta(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = timedelta(*(arg.value if arg is not None else 0 for arg in args)) - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(value) class FuncDatetimeMonad(FuncDateMonad): func = datetime def call(monad, year, month, day, hour=None, minute=None, second=None, microsecond=None): args = year, month, day, hour, minute, second, microsecond - translator = monad.translator for arg, name in izip(args, ('year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond')): if arg is None: continue - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of datetime(...) function must be of 'int' type. 
Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = datetime(*(arg.value if arg is not None else 0 for arg in args)) - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(value) def call_now(monad): - translator = monad.translator - return translator.DatetimeExprMonad(translator, datetime, [ 'NOW' ]) + return DatetimeExprMonad(datetime, [ 'NOW' ], nullable=monad.nullable) + +class FuncBetweenMonad(FuncMonad): + func = between + def call(monad, x, a, b): + check_comparable(x, a, '<') + check_comparable(x, b, '<') + if isinstance(x.type, EntityMeta): throw(TypeError, + '%s instance cannot be argument of between() function: {EXPR}' % x.type.__name__) + sql = [ 'BETWEEN', x.getsql()[0], a.getsql()[0], b.getsql()[0] ] + return BoolExprMonad(sql, nullable=x.nullable or a.nullable or b.nullable) class FuncConcatMonad(FuncMonad): func = concat def call(monad, *args): if len(args) < 2: throw(TranslationError, 'concat() function requires at least two arguments') - translator = args[0].translator result_ast = [ 'CONCAT' ] + translator = monad.translator for arg in args: t = arg.type if isinstance(t, EntityMeta) or type(t) in (tuple, SetType): throw(TranslationError, 'Invalid argument of concat() function: %s' % ast2src(arg.node)) + if translator.database.provider_name == 'cockroach' and not isinstance(arg, StringMixin): + arg = arg.to_str() result_ast.extend(arg.getsql()) - return translator.ExprMonad.new(translator, unicode, result_ast) + return ExprMonad.new(unicode, result_ast, nullable=any(arg.nullable for arg in args)) class FuncLenMonad(FuncMonad): func = len def call(monad, x): return x.len() +class FuncGetattrMonad(FuncMonad): + func = getattr + def call(monad, obj_monad, name_monad): + if isinstance(name_monad, ConstMonad): + attrname = name_monad.value + elif isinstance(name_monad, ParamMonad): + translator = monad.translator.root_translator + key = name_monad.paramkey[0] + if key in translator.fixed_param_values: + attrname = translator.fixed_param_values[key] + else: + attrname = translator.vars[key] + translator.fixed_param_values[key] = attrname + else: throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' + 'because %s will be different for each row' % ast2src(name_monad.node)) + if not isinstance(attrname, basestring): + throw(TypeError, 'In `{EXPR}` second argument should be a string. 
Got: %r' % attrname) + return obj_monad.getattr(attrname) + +class FuncRawSQLMonad(FuncMonad): + func = raw_sql + def call(monad, *args): + throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' + 'because raw SQL fragment will be different for each row') + class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count - def call(monad, x=None): - translator = monad.translator - if isinstance(x, translator.StringConstMonad) and x.value == '*': x = None - if x is not None: return x.count() - result = translator.ExprMonad.new(translator, int, [ 'COUNT', 'ALL' ]) + def call(monad, x=None, distinct=None): + if isinstance(x, StringConstMonad) and x.value == '*': x = None + if x is not None: return x.count(distinct) + result = ExprMonad.new(int, [ 'COUNT', None ], nullable=False) result.aggregated = True return result @@ -1932,13 +2820,39 @@ def call(monad, x): class FuncSumMonad(FuncMonad): func = sum, core.sum - def call(monad, x): - return x.aggregate('SUM') + def call(monad, x, distinct=None): + return x.aggregate('SUM', distinct) class FuncAvgMonad(FuncMonad): func = utils.avg, core.avg - def call(monad, x): - return x.aggregate('AVG') + def call(monad, x, distinct=None): + return x.aggregate('AVG', distinct) + +class FuncGroupConcatMonad(FuncMonad): + func = utils.group_concat, core.group_concat + def call(monad, x, sep=None, distinct=None): + if sep is not None: + if distinct and monad.translator.database.provider.dialect == 'SQLite': + throw(TypeError, 'SQLite does not allow to specify distinct and separator in group_concat at the same time: {EXPR}') + if not(isinstance(sep, StringConstMonad) and isinstance(sep.value, basestring)): + throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % ast2src(sep.node)) + sep = sep.value + return x.aggregate('GROUP_CONCAT', distinct=distinct, sep=sep) + +class FuncCoalesceMonad(FuncMonad): + func = coalesce + def call(monad, *args): + if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments') + arg = args[0] + t = arg.type + result = [ [ sql ] for sql in arg.getsql() ] + for arg in args[1:]: + if arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type') + for i, sql in enumerate(arg.getsql()): + result[i].append(sql) + sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] + if not isinstance(t, EntityMeta): sql = sql[0] + return ExprMonad.new(t, sql, nullable=all(arg.nullable for arg in args)) class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct @@ -1982,22 +2896,21 @@ def minmax(monad, sqlop, *args): args = list(args) for i, arg in enumerate(args): if arg.type is bool: - args[i] = NumericExprMonad(translator, int, [ 'TO_INT', arg.getsql() ]) - sql = [ sqlop ] + [ arg.getsql()[0] for arg in args ] - return translator.ExprMonad.new(translator, t, sql) + args[i] = NumericExprMonad(int, [ 'TO_INT', arg.getsql()[0] ], nullable=arg.nullable) + sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ] + return ExprMonad.new(t, sql, nullable=any(arg.nullable for arg in args)) class FuncSelectMonad(FuncMonad): func = core.select def call(monad, queryset): - translator = monad.translator - if not isinstance(queryset, translator.QuerySetMonad): throw(TypeError, + if not isinstance(queryset, QuerySetMonad): throw(TypeError, "'select' function expects generator expression, got: {EXPR}") return queryset class FuncExistsMonad(FuncMonad): func = core.exists def call(monad, arg): - 
if not isinstance(arg, monad.translator.SetMixin): throw(TypeError, + if not isinstance(arg, SetMixin): throw(TypeError, "'exists' function expects generator expression or collection, got: {EXPR}") return arg.nonzero() @@ -2008,14 +2921,15 @@ def call(monad, expr): class DescMonad(Monad): def __init__(monad, expr): - Monad.__init__(monad, expr.translator, expr.type) + Monad.__init__(monad, expr.type, nullable=expr.nullable) monad.expr = expr def getsql(monad): return [ [ 'DESC', item ] for item in monad.expr.getsql() ] class JoinMonad(Monad): - def __init__(monad, translator, type): - Monad.__init__(monad, translator, type) + def __init__(monad, type): + Monad.__init__(monad, type) + translator = monad.translator monad.hint_join_prev = translator.hint_join translator.hint_join = True def __call__(monad, x): @@ -2025,11 +2939,11 @@ def __call__(monad, x): class FuncRandomMonad(FuncMonad): func = random - def __init__(monad, translator, type): - FuncMonad.__init__(monad, translator, type) - translator.query_result_is_cacheable = False + def __init__(monad, type): + FuncMonad.__init__(monad, type) + monad.translator.query_result_is_cacheable = False def __call__(monad): - return NumericExprMonad(monad.translator, float, [ 'RANDOM' ]) + return NumericExprMonad(float, [ 'RANDOM' ], nullable=False) class SetMixin(MonadMixin): forced_distinct = False @@ -2041,21 +2955,18 @@ def call_distinct(monad): def make_attrset_binop(op, sqlop): def attrset_binop(monad, monad2): - NumericSetExprMonad = monad.translator.NumericSetExprMonad return NumericSetExprMonad(op, sqlop, monad, monad2) return attrset_binop class AttrSetMonad(SetMixin, Monad): def __init__(monad, parent, attr): - translator = parent.translator item_type = normalize_type(attr.py_type) - Monad.__init__(monad, translator, SetType(item_type)) + Monad.__init__(monad, SetType(item_type)) monad.parent = parent monad.attr = attr - monad.subquery = None + monad.sqlquery = None monad.tableref = None def cmp(monad, op, monad2): - translator = monad.translator if type(monad2.type) is SetType \ and are_comparable_types(monad.type.item_type, monad2.type.item_type): pass elif monad.type != monad2.type: check_comparable(monad, monad2) @@ -2065,10 +2976,10 @@ def contains(monad, item, not_in=False): check_comparable(item, monad, 'in') if not translator.hint_join: sqlop = 'NOT_IN' if not_in else 'IN' - subquery = monad._subselect() - expr_list = subquery.expr_list - from_ast = subquery.from_ast - conditions = subquery.outer_conditions + subquery.conditions + sqlquery = monad._subselect() + expr_list = sqlquery.expr_list + from_ast = sqlquery.from_ast + conditions = sqlquery.outer_conditions + sqlquery.conditions if len(expr_list) == 1: subquery_ast = [ 'SELECT', [ 'ALL' ] + expr_list, from_ast, [ 'WHERE' ] + conditions ] sql_ast = [ sqlop, item.getsql()[0], subquery_ast ] @@ -2078,30 +2989,30 @@ def contains(monad, item, not_in=False): else: conditions += [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item.getsql(), expr_list) ] sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS', from_ast, [ 'WHERE' ] + conditions ] - result = translator.BoolExprMonad(translator, sql_ast) + result = BoolExprMonad(sql_ast, nullable=False) result.nogroup = True return result elif not not_in: translator.distinct = True - tableref = monad.make_tableref(translator.subquery) + tableref = monad.make_tableref(translator.sqlquery) expr_list = monad.make_expr_list() expr_ast = sqland([ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(expr_list, item.getsql()) ]) - return 
translator.BoolExprMonad(translator, expr_ast) + return BoolExprMonad(expr_ast, nullable=False) else: - subquery = Subquery(translator.subquery) - tableref = monad.make_tableref(subquery) + sqlquery = SqlQuery(translator, translator.sqlquery) + tableref = monad.make_tableref(sqlquery) attr = monad.attr alias, columns = tableref.make_join(pk_only=attr.reverse) expr_list = monad.make_expr_list() if not attr.reverse: columns = attr.columns - from_ast = translator.subquery.from_ast + from_ast = translator.sqlquery.from_ast from_ast[0] = 'LEFT_JOIN' - from_ast.extend(subquery.from_ast[1:]) + from_ast.extend(sqlquery.from_ast[1:]) conditions = [ [ 'EQ', [ 'COLUMN', alias, column ], expr ] for column, expr in izip(columns, item.getsql()) ] - conditions.extend(subquery.conditions) + conditions.extend(sqlquery.conditions) from_ast[-1][-1] = sqland([ from_ast[-1][-1] ] + conditions) expr_ast = sqland([ [ 'IS_NULL', expr ] for expr in expr_list ]) - return translator.BoolExprMonad(translator, expr_ast) + return BoolExprMonad(expr_ast, nullable=False) def getattr(monad, name): try: return Monad.getattr(monad, name) except AttributeError: pass @@ -2109,7 +3020,13 @@ def getattr(monad, name): if not isinstance(entity, EntityMeta): throw(AttributeError) attr = entity._adict_.get(name) if attr is None: throw(AttributeError) - return monad.translator.AttrSetMonad(monad, attr) + return AttrSetMonad(monad, attr) + def call_select(monad): + # calling with lambda argument processed in preCallFunc + return monad + call_filter = call_select + def call_exists(monad): + return monad def requires_distinct(monad, joined=False, for_count=False): if monad.parent.requires_distinct(joined): return True reverse = monad.attr.reverse @@ -2117,50 +3034,53 @@ def requires_distinct(monad, joined=False, for_count=False): if reverse.is_collection: translator = monad.translator if not for_count and not translator.hint_join: return True - if isinstance(monad.parent, monad.translator.AttrSetMonad): return True + if isinstance(monad.parent, AttrSetMonad): return True return False - def count(monad): + def count(monad, distinct=None): translator = monad.translator + distinct = distinct_from_monad(distinct, monad.requires_distinct(joined=translator.hint_join, for_count=True)) - subquery = monad._subselect() - expr_list = subquery.expr_list - from_ast = subquery.from_ast - inner_conditions = subquery.conditions - outer_conditions = subquery.outer_conditions + sqlquery = monad._subselect() + expr_list = sqlquery.expr_list + from_ast = sqlquery.from_ast + inner_conditions = sqlquery.conditions + outer_conditions = sqlquery.outer_conditions - distinct = monad.requires_distinct(joined=translator.hint_join, for_count=True) sql_ast = make_aggr = None extra_grouping = False if not distinct and monad.tableref.name_path != translator.optimize: - make_aggr = lambda expr_list: [ 'COUNT', 'ALL' ] + make_aggr = lambda expr_list: [ 'COUNT', None ] elif len(expr_list) == 1: - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT' ] + expr_list + make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list elif translator.dialect == 'Oracle': if monad.tableref.name_path == translator.optimize: alias, pk_columns = monad.tableref.make_join(pk_only=True) - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT' if distinct else 'ALL', [ 'COLUMN', alias, 'ROWID' ] ] + make_aggr = lambda expr_list: [ 'COUNT', distinct, [ 'COLUMN', alias, 'ROWID' ] ] else: extra_grouping = True - if translator.hint_join: make_aggr = lambda expr_list: [ 'COUNT', 'ALL' ] - else: 
make_aggr = lambda expr_list: [ 'COUNT', 'ALL', [ 'COUNT', 'ALL' ] ] + if translator.hint_join: make_aggr = lambda expr_list: [ 'COUNT', None ] + else: make_aggr = lambda expr_list: [ 'COUNT', None, [ 'COUNT', None ] ] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr_list - expr = [ 'CASE', None, [ [ [ 'IS_NULL', row ], [ 'VALUE', None ] ] ], row ] - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT', expr ] + cond = [ 'IS_NULL', row ] + if translator.database.provider_name == 'cockroach': + cond = [ 'OR' ] + [ [ 'IS_NULL', expr ] for expr in expr_list ] + expr = [ 'CASE', None, [ [ cond, [ 'VALUE', None ] ] ], row ] + make_aggr = lambda expr_list: [ 'COUNT', True, expr ] elif translator.row_value_syntax: - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT' ] + expr_list + make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list elif translator.dialect == 'SQLite': if not distinct: alias, pk_columns = monad.tableref.make_join(pk_only=True) - make_aggr = lambda expr_list: [ 'COUNT', 'ALL', [ 'COLUMN', alias, 'ROWID' ] ] + make_aggr = lambda expr_list: [ 'COUNT', None, [ 'COLUMN', alias, 'ROWID' ] ] elif translator.hint_join: # Same join as in Oracle extra_grouping = True - make_aggr = lambda expr_list: [ 'COUNT', 'ALL' ] + make_aggr = lambda expr_list: [ 'COUNT', None ] elif translator.sqlite_version < (3, 6, 21): alias, pk_columns = monad.tableref.make_join(pk_only=False) - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT', [ 'COLUMN', alias, 'ROWID' ] ] + make_aggr = lambda expr_list: [ 'COUNT', True, [ 'COLUMN', alias, 'ROWID' ] ] else: - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ 't', 'SELECT', [ [ 'DISTINCT' ] + expr_list, from_ast, [ 'WHERE' ] + outer_conditions + inner_conditions ] ] ] ] @@ -2171,12 +3091,13 @@ def count(monad): else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr, extra_grouping) translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = translator.ExprMonad.new(translator, int, sql_ast) + result = ExprMonad.new(int, sql_ast, nullable=False) if optimized: result.aggregated = True else: result.nogroup = True return result len = count - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None, sep=None): + distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator item_type = monad.type.item_type @@ -2188,47 +3109,57 @@ def aggregate(monad, func_name): if item_type not in comparable_types: throw(TypeError, "Function %s() expects query or items of comparable type, got %r in {EXPR}" % (func_name.lower(), type2str(item_type))) + elif func_name == 'GROUP_CONCAT': + if isinstance(item_type, EntityMeta) and item_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover - if monad.forced_distinct and func_name in ('SUM', 'AVG'): - make_aggr = lambda expr_list: [ func_name ] + expr_list + [ True ] - else: - make_aggr = lambda expr_list: [ func_name ] + expr_list + def make_aggr(expr_list): + result = [ func_name, distinct ] + expr_list + if sep is not None: + assert func_name == 'GROUP_CONCAT' + result.append(['VALUE', sep]) + return result + + # make_aggr = lambda expr_list: [ func_name, distinct ] + expr_list if translator.hint_join: sql_ast, optimized = monad._joined_subselect(make_aggr, coalesce_to_zero=(func_name=='SUM')) else: 
sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr) - result_type = float if func_name == 'AVG' else item_type + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = item_type translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = translator.ExprMonad.new(monad.translator, result_type, sql_ast) + result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') if optimized: result.aggregated = True else: result.nogroup = True return result def nonzero(monad): - subquery = monad._subselect() - sql_ast = [ 'EXISTS', subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] - translator = monad.translator - return translator.BoolExprMonad(translator, sql_ast) + sqlquery = monad._subselect() + sql_ast = [ 'EXISTS', sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] + return BoolExprMonad(sql_ast, nullable=False) def negate(monad): - subquery = monad._subselect() - sql_ast = [ 'NOT_EXISTS', subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] - translator = monad.translator - return translator.BoolExprMonad(translator, sql_ast) - def make_tableref(monad, subquery): + sqlquery = monad._subselect() + sql_ast = [ 'NOT_EXISTS', sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] + return BoolExprMonad(sql_ast, nullable=False) + call_is_empty = negate + def make_tableref(monad, sqlquery): parent = monad.parent attr = monad.attr - translator = monad.translator if isinstance(parent, ObjectMixin): parent_tableref = parent.tableref - elif isinstance(parent, translator.AttrSetMonad): parent_tableref = parent.make_tableref(subquery) + elif isinstance(parent, AttrSetMonad): parent_tableref = parent.make_tableref(sqlquery) else: assert False # pragma: no cover if attr.reverse: name_path = parent_tableref.name_path + '-' + attr.name - monad.tableref = subquery.get_tableref(name_path) \ - or subquery.add_tableref(name_path, parent_tableref, attr) + monad.tableref = sqlquery.get_tableref(name_path) \ + or sqlquery.add_tableref(name_path, parent_tableref, attr) else: monad.tableref = parent_tableref monad.tableref.can_affect_distinct = True return monad.tableref @@ -2244,35 +3175,36 @@ def make_expr_list(monad): return [ [ 'COLUMN', alias, column ] for column in columns ] def _aggregated_scalar_subselect(monad, make_aggr, extra_grouping=False): translator = monad.translator - subquery = monad._subselect() + sqlquery = monad._subselect() optimized = False if translator.optimize == monad.tableref.name_path: - sql_ast = make_aggr(subquery.expr_list) + sql_ast = make_aggr(sqlquery.expr_list) optimized = True if not translator.from_optimized: - from_ast = monad.subquery.from_ast[1:] - from_ast[0] = from_ast[0] + [ sqland(subquery.outer_conditions) ] - translator.subquery.from_ast.extend(from_ast) + from_ast = monad.sqlquery.from_ast[1:] + assert sqlquery.outer_conditions + from_ast[0] = from_ast[0] + [ sqland(sqlquery.outer_conditions) ] + translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True - else: sql_ast = [ 'SELECT', [ 'AGGREGATES', make_aggr(subquery.expr_list) ], - subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] + else: sql_ast = [ 'SELECT', [ 'AGGREGATES', make_aggr(sqlquery.expr_list) ], + sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] if extra_grouping: # This is 
for Oracle only, with COUNT(COUNT(*)) - sql_ast.append([ 'GROUP_BY' ] + subquery.expr_list) + sql_ast.append([ 'GROUP_BY' ] + sqlquery.expr_list) return sql_ast, optimized def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=False): translator = monad.translator - subquery = monad._subselect() - expr_list = subquery.expr_list - from_ast = subquery.from_ast - inner_conditions = subquery.conditions - outer_conditions = subquery.outer_conditions + sqlquery = monad._subselect() + expr_list = sqlquery.expr_list + from_ast = sqlquery.from_ast + inner_conditions = sqlquery.conditions + outer_conditions = sqlquery.outer_conditions groupby_columns = [ inner_column[:] for cond, outer_column, inner_column in outer_conditions ] - assert len(set(alias for _, alias, column in groupby_columns)) == 1 + assert len({alias for _, alias, column in groupby_columns}) == 1 if extra_grouping: - inner_alias = translator.subquery.get_short_alias(None, 't') + inner_alias = translator.sqlquery.make_alias('t') inner_columns = [ 'DISTINCT' ] col_mapping = {} col_names = set() @@ -2285,7 +3217,7 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F expr = [ 'AS', column_ast, cname ] new_name = cname else: - new_name = 'expr-%d' % next(translator.subquery.expr_counter) + new_name = 'expr-%d' % next(translator.sqlquery.expr_counter) col_mapping[tname, cname] = new_name expr = [ 'AS', column_ast, new_name ] inner_columns.append(expr) @@ -2301,41 +3233,42 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F new_name = col_mapping[tname, cname] outer_conditions[i] = [ cond, outer_column, [ 'COLUMN', inner_alias, new_name ] ] - subquery_columns = [ 'ALL' ] + subselect_columns = [ 'ALL' ] for column_ast in groupby_columns: assert column_ast[0] == 'COLUMN' - subquery_columns.append([ 'AS', column_ast, column_ast[2] ]) - expr_name = 'expr-%d' % next(translator.subquery.expr_counter) - subquery_columns.append([ 'AS', make_aggr(expr_list), expr_name ]) - subquery_ast = [ subquery_columns, from_ast ] + subselect_columns.append([ 'AS', column_ast, column_ast[2] ]) + expr_name = 'expr-%d' % next(translator.sqlquery.expr_counter) + subselect_columns.append([ 'AS', make_aggr(expr_list), expr_name ]) + subquery_ast = [ subselect_columns, from_ast ] if inner_conditions and not extra_grouping: subquery_ast.append([ 'WHERE' ] + inner_conditions) subquery_ast.append([ 'GROUP_BY' ] + groupby_columns) - alias = translator.subquery.get_short_alias(None, 't') + alias = translator.sqlquery.make_alias('t') for cond in outer_conditions: cond[2][1] = alias - translator.subquery.from_ast.append([ alias, 'SELECT', subquery_ast, sqland(outer_conditions) ]) + translator.sqlquery.from_ast.append([ alias, 'SELECT', subquery_ast, sqland(outer_conditions) ]) expr_ast = [ 'COLUMN', alias, expr_name ] if coalesce_to_zero: expr_ast = [ 'COALESCE', expr_ast, [ 'VALUE', 0 ] ] return expr_ast, False - def _subselect(monad): - if monad.subquery is not None: return monad.subquery + def _subselect(monad, sqlquery=None, extract_outer_conditions=True): + if monad.sqlquery is not None: return monad.sqlquery attr = monad.attr translator = monad.translator - subquery = Subquery(translator.subquery) - monad.make_tableref(subquery) - subquery.expr_list = monad.make_expr_list() + if sqlquery is None: + sqlquery = SqlQuery(translator, translator.sqlquery) + monad.make_tableref(sqlquery) + sqlquery.expr_list = monad.make_expr_list() if not attr.reverse and not attr.is_required: - 
subquery.conditions.extend([ 'IS_NOT_NULL', expr ] for expr in subquery.expr_list) - if subquery is not translator.subquery: - outer_cond = subquery.from_ast[1].pop() - if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] - else: subquery.outer_conditions = [ outer_cond ] - monad.subquery = subquery - return subquery - def getsql(monad, subquery=None): - if subquery is None: subquery = monad.translator.subquery - monad.make_tableref(subquery) + sqlquery.conditions.extend([ 'IS_NOT_NULL', expr ] for expr in sqlquery.expr_list) + if sqlquery is not translator.sqlquery and extract_outer_conditions: + outer_cond = sqlquery.from_ast[1].pop() + if outer_cond[0] == 'AND': sqlquery.outer_conditions = outer_cond[1:] + else: sqlquery.outer_conditions = [ outer_cond ] + monad.sqlquery = sqlquery + return sqlquery + def getsql(monad, sqlquery=None): + if sqlquery is None: sqlquery = monad.translator.sqlquery + monad.make_tableref(sqlquery) return monad.make_expr_list() __add__ = make_attrset_binop('+', 'ADD') __sub__ = make_attrset_binop('-', 'SUB') @@ -2345,7 +3278,6 @@ def getsql(monad, subquery=None): def make_numericset_binop(op, sqlop): def numericset_binop(monad, monad2): - NumericSetExprMonad = monad.translator.NumericSetExprMonad return NumericSetExprMonad(op, sqlop, monad, monad2) return numericset_binop @@ -2355,43 +3287,52 @@ def __init__(monad, op, sqlop, left, right): assert type(result_type) is SetType if result_type.item_type not in numeric_types: throw(TypeError, _binop_errmsg % (type2str(left.type), type2str(right.type), op)) - Monad.__init__(monad, left.translator, result_type) + Monad.__init__(monad, result_type) monad.op = op monad.sqlop = sqlop monad.left = left monad.right = right - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None, sep=None): + distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator - subquery = Subquery(translator.subquery) - expr = monad.getsql(subquery)[0] + sqlquery = SqlQuery(translator, translator.sqlquery) + expr = monad.getsql(sqlquery)[0] translator.aggregated_subquery_paths.add(monad.tableref.name_path) - outer_cond = subquery.from_ast[1].pop() - if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] - else: subquery.outer_conditions = [ outer_cond ] - result_type = float if func_name == 'AVG' else monad.type.item_type - aggr_ast = [ func_name, expr ] - if monad.forced_distinct and func_name in ('SUM', 'AVG'): aggr_ast.append(True) + outer_cond = sqlquery.from_ast[1].pop() + if outer_cond[0] == 'AND': sqlquery.outer_conditions = outer_cond[1:] + else: sqlquery.outer_conditions = [ outer_cond ] + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = monad.type.item_type + aggr_ast = [ func_name, distinct, expr ] + if func_name == 'GROUP_CONCAT': + if sep is not None: + aggr_ast.append(['VALUE', sep]) if translator.optimize != monad.tableref.name_path: sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], - subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] - result = translator.ExprMonad.new(translator, result_type, sql_ast) + sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] + result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') result.nogroup = True else: if not translator.from_optimized: - from_ast = subquery.from_ast[1:] - from_ast[0] = from_ast[0] + [ 
sqland(subquery.outer_conditions) ] - translator.subquery.from_ast.extend(from_ast) + from_ast = sqlquery.from_ast[1:] + assert sqlquery.outer_conditions + from_ast[0] = from_ast[0] + [ sqland(sqlquery.outer_conditions) ] + translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True sql_ast = aggr_ast - result = translator.ExprMonad.new(translator, result_type, sql_ast) + result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') result.aggregated = True return result - def getsql(monad, subquery=None): - if subquery is None: subquery = monad.translator.subquery + def getsql(monad, sqlquery=None): + if sqlquery is None: sqlquery = monad.translator.sqlquery left, right = monad.left, monad.right - left_expr = left.getsql(subquery)[0] - right_expr = right.getsql(subquery)[0] + left_expr = left.getsql(sqlquery)[0] + right_expr = right.getsql(sqlquery)[0] if isinstance(left, NumericMixin): left_path = '' else: left_path = left.tableref.name_path + '-' if isinstance(right, NumericMixin): right_path = '' @@ -2409,33 +3350,47 @@ def getsql(monad, subquery=None): class QuerySetMonad(SetMixin, Monad): nogroup = True - def __init__(monad, translator, subtranslator): - monad.translator = translator - monad.subtranslator = subtranslator + def __init__(monad, subtranslator): item_type = subtranslator.expr_type - monad.item_type = item_type monad_type = SetType(item_type) - Monad.__init__(monad, translator, monad_type) + Monad.__init__(monad, monad_type) + monad.subtranslator = subtranslator + monad.item_type = item_type + monad.limit = monad.offset = None + def requires_distinct(monad, joined=False): + assert False + def call_limit(monad, limit=None, offset=None): + if limit is not None and not isinstance(limit, int_types): + if not isinstance(limit, (NoneMonad, NumericConstMonad)): + throw(TypeError, '`limit` parameter should be of int type') + limit = limit.value + if offset is not None and not isinstance(offset, int_types): + if not isinstance(offset, (NoneMonad, NumericConstMonad)): + throw(TypeError, '`offset` parameter should be of int type') + offset = offset.value + monad.limit = limit + monad.offset = offset + return monad def contains(monad, item, not_in=False): translator = monad.translator check_comparable(item, monad, 'in') - if isinstance(item, translator.ListMonad): + if isinstance(item, ListMonad): item_columns = [] for subitem in item.items: item_columns.extend(subitem.getsql()) else: item_columns = item.getsql() sub = monad.subtranslator - if translator.hint_join and len(sub.subquery.from_ast[1]) == 3: - subquery_ast = sub.shallow_copy_of_subquery_ast() + if translator.hint_join and len(sub.sqlquery.from_ast[1]) == 3: + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False) select_ast, from_ast, where_ast = subquery_ast[1:4] - subquery = translator.subquery + sqlquery = translator.sqlquery if not not_in: translator.distinct = True - if subquery.from_ast[0] == 'FROM': - subquery.from_ast[0] = 'INNER_JOIN' + if sqlquery.from_ast[0] == 'FROM': + sqlquery.from_ast[0] = 'INNER_JOIN' else: - subquery.left_join = True - subquery.from_ast[0] = 'LEFT_JOIN' + sqlquery.left_join = True + sqlquery.from_ast[0] = 'LEFT_JOIN' col_names = set() new_names = [] exprs = [] @@ -2449,90 +3404,122 @@ def contains(monad, item, not_in=False): new_names.append(col_name) select_ast[i] = [ 'AS', column_ast, col_name ] continue - new_name = 'expr-%d' % next(subquery.expr_counter) + new_name = 'expr-%d' % next(sqlquery.expr_counter) 
new_names.append(new_name) select_ast[i] = [ 'AS', column_ast, new_name ] - alias = subquery.get_short_alias(None, 't') + alias = sqlquery.make_alias('t') outer_conditions = [ [ 'EQ', item_column, [ 'COLUMN', alias, new_name ] ] for item_column, new_name in izip(item_columns, new_names) ] - subquery.from_ast.append([ alias, 'SELECT', subquery_ast[1:], sqland(outer_conditions) ]) + sqlquery.from_ast.append([ alias, 'SELECT', subquery_ast[1:], sqland(outer_conditions) ]) if not_in: sql_ast = sqland([ [ 'IS_NULL', [ 'COLUMN', alias, new_name ] ] for new_name in new_names ]) else: sql_ast = [ 'EQ', [ 'VALUE', 1 ], [ 'VALUE', 1 ] ] else: if len(item_columns) == 1: - subquery_ast = sub.shallow_copy_of_subquery_ast(is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', item_columns[0], subquery_ast ] elif translator.row_value_syntax: - subquery_ast = sub.shallow_copy_of_subquery_ast(is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', [ 'ROW' ] + item_columns, subquery_ast ] else: - subquery_ast = sub.shallow_copy_of_subquery_ast() + ambiguous_names = set() + if sub.injected: + for name in translator.sqlquery.tablerefs: + if name in sub.sqlquery.tablerefs: + ambiguous_names.add(name) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False) + if ambiguous_names: + select_ast = subquery_ast[1] + expr_aliases = [] + for i, expr_ast in enumerate(select_ast): + if i > 0: + if expr_ast[0] == 'AS': + expr_ast = expr_ast[1] + expr_alias = 'expr-%d' % i + expr_aliases.append(expr_alias) + expr_ast = [ 'AS', expr_ast, expr_alias ] + select_ast[i] = expr_ast + + new_table_alias = translator.sqlquery.make_alias('t') + new_select_ast = [ 'ALL' ] + for expr_alias in expr_aliases: + new_select_ast.append([ 'COLUMN', new_table_alias, expr_alias ]) + new_from_ast = [ 'FROM', [ new_table_alias, 'SELECT', subquery_ast[1:] ] ] + new_where_ast = [ 'WHERE' ] + subquery_ast = [ 'SELECT', new_select_ast, new_from_ast, new_where_ast ] select_ast, from_ast, where_ast = subquery_ast[1:4] in_conditions = [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item_columns, select_ast[1:]) ] - if not sub.aggregated: where_ast += in_conditions - else: + if not ambiguous_names and sub.aggregated: having_ast = find_or_create_having_ast(subquery_ast) having_ast += in_conditions + else: where_ast += in_conditions sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] - return translator.BoolExprMonad(translator, sql_ast) + return BoolExprMonad(sql_ast, nullable=False) def nonzero(monad): - subquery_ast = monad.subtranslator.shallow_copy_of_subquery_ast() + subquery_ast = monad.subtranslator.construct_subquery_ast(distinct=False) + expr_monads = monad.subtranslator.expr_monads + if len(expr_monads) > 1: + throw(NotImplementedError) + expr_monad = expr_monads[0] + if not isinstance(expr_monad, ObjectIterMonad): + sql = expr_monad.nonzero().getsql() + assert subquery_ast[3][0] == 'WHERE' + subquery_ast[3].append(sql[0]) subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] - translator = monad.translator - return translator.BoolExprMonad(translator, subquery_ast) + return BoolExprMonad(subquery_ast, nullable=False) def negate(monad): sql = monad.nonzero().sql assert sql[0] == 'EXISTS' - translator = monad.translator - return 
translator.BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:]) - def count(monad): + return BoolExprMonad([ 'NOT_EXISTS' ] + sql[1:], nullable=False) + def count(monad, distinct=None): + distinct = distinct_from_monad(distinct) translator = monad.translator sub = monad.subtranslator + if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] sql_ast = None expr_type = sub.expr_type if isinstance(expr_type, (tuple, EntityMeta)): - if not sub.distinct: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ] + if not sub.distinct and not distinct: + select_ast = [ 'AGGREGATES', [ 'COUNT', None ] ] elif len(sub.expr_columns) == 1: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'DISTINCT' ] + sub.expr_columns ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct ] + sub.expr_columns ] elif translator.dialect == 'Oracle': - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL', [ 'COUNT', 'ALL' ] ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None, [ 'COUNT', None ] ] ], from_ast, where_ast, [ 'GROUP_BY' ] + sub.expr_columns ] elif translator.row_value_syntax: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'DISTINCT' ] + sub.expr_columns ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct ] + sub.expr_columns ] elif translator.dialect == 'SQLite': if translator.sqlite_version < (3, 6, 21): if sub.aggregated: throw(TranslationError) alias, pk_columns = sub.tableref.make_join(pk_only=False) - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] sql_ast = [ 'SELECT', - [ 'AGGREGATES', [ 'COUNT', 'DISTINCT', [ 'COLUMN', alias, 'ROWID' ] ] ], + [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, [ 'COLUMN', alias, 'ROWID' ] ] ], from_ast, where_ast ] else: - alias = translator.subquery.get_short_alias(None, 't') - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], - [ 'FROM', [ alias, 'SELECT', [ - [ 'DISTINCT' ] + sub.expr_columns, from_ast, where_ast ] ] ] ] + alias = translator.sqlquery.make_alias('t') + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], + [ 'FROM', [ alias, 'SELECT', [ [ 'DISTINCT' if distinct is not False else 'ALL' ] + + sub.expr_columns, from_ast, where_ast ] ] ] ] else: assert False # pragma: no cover elif len(sub.expr_columns) == 1: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'DISTINCT', sub.expr_columns[0] ] ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, sub.expr_columns[0] ] ] else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] - return translator.ExprMonad.new(translator, int, sql_ast) + return ExprMonad.new(int, sql_ast, nullable=False) len = count - def aggregate(monad, func_name): - translator = monad.translator + def aggregate(monad, func_name, distinct=None, sep=None): + distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] expr_type = sub.expr_type if func_name in ('SUM', 'AVG'): @@ -2543,52 +3530,50 
@@ def aggregate(monad, func_name): if expr_type not in comparable_types: throw(TypeError, "Function %s() cannot be applied to type %r in {EXPR}" % (func_name.lower(), type2str(expr_type))) + elif func_name == 'GROUP_CONCAT': + if isinstance(expr_type, EntityMeta) and expr_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover assert len(sub.expr_columns) == 1 - aggr_ast = [ func_name, sub.expr_columns[0] ] - if monad.forced_distinct and func_name in ('SUM', 'AVG'): aggr_ast.append(True) + aggr_ast = [ func_name, distinct, sub.expr_columns[0] ] + if func_name == 'GROUP_CONCAT': + if sep is not None: + aggr_ast.append(['VALUE', sep]) select_ast = [ 'AGGREGATES', aggr_ast ] sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] - result_type = float if func_name == 'AVG' else expr_type - return translator.ExprMonad.new(translator, result_type, sql_ast) - def call_count(monad): - return monad.count() - def call_sum(monad): - return monad.aggregate('SUM') + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = expr_type + return ExprMonad.new(result_type, sql_ast, func_name != 'SUM') + def call_count(monad, distinct=None): + return monad.count(distinct=distinct) + def call_sum(monad, distinct=None): + return monad.aggregate('SUM', distinct) def call_min(monad): return monad.aggregate('MIN') def call_max(monad): return monad.aggregate('MAX') - def call_avg(monad): - return monad.aggregate('AVG') + def call_avg(monad, distinct=None): + return monad.aggregate('AVG', distinct) + def call_group_concat(monad, sep=None, distinct=None): + if sep is not None: + if not isinstance(sep, basestring): + throw(TypeError, '`sep` option of `group_concat` should be type of str. 
Got: %s' % type(sep).__name__) + return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) + def getsql(monad): + return monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset) -def find_or_create_having_ast(subquery_ast): +def find_or_create_having_ast(sections): groupby_offset = None - for i, section in enumerate(subquery_ast): + for i, section in enumerate(sections): section_name = section[0] if section_name == 'GROUP_BY': groupby_offset = i elif section_name == 'HAVING': return section having_ast = [ 'HAVING' ] - subquery_ast.insert(groupby_offset + 1, having_ast) + sections.insert(groupby_offset + 1, having_ast) return having_ast - - - - -# class LimitOffsetMonad(Monad): -# def __init__(monad, translator, limit, offset): -# ExprMonad.__init__(monad, translator, type, sql): - -# def getsql(monad): -# return [sql] - - -######## -for name, value in items_list(globals()): - if name.endswith('Monad') or name.endswith('Mixin'): - setattr(SQLTranslator, name, value) -del name, value - - diff --git a/pony/orm/tests/__init__.py b/pony/orm/tests/__init__.py index 150399271..2f1caa71d 100644 --- a/pony/orm/tests/__init__.py +++ b/pony/orm/tests/__init__.py @@ -1,4 +1,72 @@ +import unittest +import os +import types import pony.orm.core, pony.options pony.options.CUT_TRACEBACK = False -pony.orm.core.sql_debug(False) \ No newline at end of file +pony.orm.core.sql_debug(False) + + +def _load_env(): + settings_filename = os.environ.get('pony_test_db') + if settings_filename is None: + print('use default sqlite provider') + return dict(provider='sqlite', filename=':memory:') + with open(settings_filename, 'r') as f: + content = f.read() + + config = {} + exec(content, config) + settings = config.get('settings') + if settings is None or not isinstance(settings, dict): + raise ValueError('Incorrect settings pony test db file contents') + provider = settings.get('provider') + if provider is None: + raise ValueError('Incorrect settings pony test db file contents: provider was not specified') + print('use provider %s' % provider) + return settings + + +db_params = _load_env() + + +def setup_database(db): + if db.provider is None: + db.bind(**db_params) + if db.schema is None: + db.generate_mapping(check_tables=False) + db.drop_all_tables(with_all_data=True) + db.create_tables() + + +def teardown_database(db): + if db.schema: + db.drop_all_tables(with_all_data=True) + db.disconnect() + + +def only_for(providers): + if not isinstance(providers, (list, tuple)): + providers = [providers] + def decorator(x): + if isinstance(x, type) and issubclass(x, unittest.TestCase): + @classmethod + def setUpClass(cls): + raise unittest.SkipTest('%s tests implemented only for %s provider%s' % ( + cls.__name__, ', '.join(providers), '' if len(providers) < 2 else 's' + )) + if db_params['provider'] not in providers: + x.setUpClass = setUpClass + result = x + elif isinstance(x, types.FunctionType): + def new_test_func(self): + if db_params['provider'] not in providers: + raise unittest.SkipTest('%s test implemented only for %s provider%s' % ( + x.__name__, ', '.join(providers), '' if len(providers) < 2 else 's' + )) + return x(self) + result = new_test_func + else: + raise TypeError + return result + return decorator diff --git a/pony/orm/tests/fixtures.py b/pony/orm/tests/fixtures.py new file mode 100644 index 000000000..b1af03438 --- /dev/null +++ b/pony/orm/tests/fixtures.py @@ -0,0 +1,533 @@ +import sys +import os +import logging + +from pony.py23compat import PY2 +from ponytest import with_cli_args, 
pony_fixtures, provider_validators, provider, Fixture, \ + ValidationError + +from functools import wraps, partial +import click +from contextlib import contextmanager, closing + +from pony.utils import cached_property, class_property + +if not PY2: + from contextlib import contextmanager, ContextDecorator +else: + from contextlib2 import contextmanager, ContextDecorator + +import unittest + +from pony.orm import db_session, Database, rollback, delete + +if not PY2: + from io import StringIO +else: + from StringIO import StringIO + +from multiprocessing import Process + +import threading + +class DBContext(ContextDecorator): + + fixture = 'db' + enabled = False + + def __init__(self, Test): + if not isinstance(Test, type): + # FIXME ? + TestCls = type(Test) + NewClass = type(TestCls.__name__, (TestCls,), {}) + NewClass.__module__ = TestCls.__module__ + NewClass.db = property(lambda t: self.db) + Test.__class__ = NewClass + else: + Test.db = class_property(lambda cls: self.db) + self.Test = Test + + @class_property + def fixture_name(cls): + return cls.db_provider + + @class_property + def db_provider(cls): + # is used in tests + return cls.provider_key + + def init_db(self): + raise NotImplementedError + + @cached_property + def db(self): + raise NotImplementedError + + def __enter__(self): + self.init_db() + try: + self.Test.make_entities() + except (AttributeError, TypeError): + # No method make_entities with due signature + pass + else: + self.db.generate_mapping(check_tables=True, create_tables=True) + return self.db + + def __exit__(self, *exc_info): + self.db.provider.disconnect() + + @classmethod + def validate_fixtures(cls, fixtures, config): + return any(f.fixture_key == 'db' for f in fixtures) + + db_name = 'testdb' + + +@provider() +class GenerateMapping(ContextDecorator): + + weight = 200 + fixture = 'generate_mapping' + + def __init__(self, Test): + self.Test = Test + + def __enter__(self): + db = getattr(self.Test, 'db', None) + if not db or not db.entities: + return + for entity in db.entities.values(): + if entity._database_.schema is None: + db.generate_mapping(check_tables=True, create_tables=True) + break + + def __exit__(self, *exc_info): + pass + +@provider() +class MySqlContext(DBContext): + provider_key = 'mysql' + + def drop_db(self, cursor): + cursor.execute('use sys') + cursor.execute('drop database %s' % self.db_name) + + + def init_db(self): + from pony.orm.dbproviders.mysql import mysql_module + with closing(mysql_module.connect(**self.CONN).cursor()) as c: + try: + self.drop_db(c) + except mysql_module.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + c.execute('create database %s' % self.db_name) + c.execute('use %s' % self.db_name) + + CONN = { + 'host': "localhost", + 'user': "ponytest", + 'passwd': "ponytest", + } + + @cached_property + def db(self): + CONN = dict(self.CONN, db=self.db_name) + return Database('mysql', **CONN) + +@provider() +class SqlServerContext(DBContext): + + provider_key = 'sqlserver' + + def get_conn_string(self, db=None): + s = ( + 'DSN=MSSQLdb;' + 'SERVER=mssql;' + 'UID=sa;' + 'PWD=pass;' + ) + if db: + s += 'DATABASE=%s' % db + return s + + @cached_property + def db(self): + CONN = self.get_conn_string(self.db_name) + return Database('mssqlserver', CONN) + + def init_db(self): + import pyodbc + cursor = pyodbc.connect(self.get_conn_string(), autocommit=True).cursor() + with closing(cursor) as c: + try: + self.drop_db(c) + except pyodbc.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + c.execute('''CREATE 
DATABASE %s DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci''' % self.db_name ) + c.execute('use %s' % self.db_name) + + def drop_db(self, cursor): + cursor.execute('use master') + cursor.execute('drop database %s' % self.db_name) + + +class SqliteMixin(DBContext): + + def init_db(self): + try: + os.remove(self.db_path) + except OSError as exc: + print('Failed to drop db: %s' % exc) + + @cached_property + def db_path(self): + p = os.path.dirname(__file__) + p = os.path.join(p, '%s.sqlite' % self.db_name) + return os.path.abspath(p) + + @cached_property + def db(self): + return Database('sqlite', self.db_path, create_db=True) + + +@provider() +class SqliteNoJson1(SqliteMixin): + provider_key = 'sqlite_no_json1' + enabled = True + + def __init__(self, cls): + self.Test = cls + cls.no_json1 = True + return super(SqliteNoJson1, self).__init__(cls) + + def __enter__(self): + resource = super(SqliteNoJson1, self).__enter__() + self.json1_available = self.Test.db.provider.json1_available + self.Test.db.provider.json1_available = False + return resource + + def __exit__(self, *exc_info): + self.Test.db.provider.json1_available = self.json1_available + return super(SqliteNoJson1, self).__exit__(*exc_info) + + +@provider() +class SqliteJson1(SqliteMixin): + provider_key = 'sqlite_json1' + + def __enter__(self): + result = super(SqliteJson1, self).__enter__() + if not self.db.provider.json1_available: + raise unittest.SkipTest + return result + + +@provider() +class PostgresContext(DBContext): + provider_key = 'postgresql' + + def get_conn_dict(self, no_db=False): + d = dict( + user='ponytest', password='ponytest', + host='localhost', database='postgres', + ) + if not no_db: + d.update(database=self.db_name) + return d + + def init_db(self): + import psycopg2 + conn = psycopg2.connect( + **self.get_conn_dict(no_db=True) + ) + conn.set_isolation_level(0) + with closing(conn.cursor()) as cursor: + try: + self.drop_db(cursor) + except psycopg2.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + cursor.execute('create database %s' % self.db_name) + + def drop_db(self, cursor): + cursor.execute('drop database %s' % self.db_name) + + + @cached_property + def db(self): + return Database('postgres', **self.get_conn_dict()) + + +@provider() +class OracleContext(DBContext): + provider_key = 'oracle' + + def __enter__(self): + os.environ.update(dict( + ORACLE_BASE='/u01/app/oracle', + ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1', + ORACLE_OWNR='oracle', + ORACLE_SID='orcl', + )) + return super(OracleContext, self).__enter__() + + def init_db(self): + + import cx_Oracle + with closing(self.connect_sys()) as conn: + with closing(conn.cursor()) as cursor: + try: + self._destroy_test_user(cursor) + except cx_Oracle.DatabaseError as exc: + print('Failed to drop user: %s' % exc) + try: + self._drop_tablespace(cursor) + except cx_Oracle.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + cursor.execute( + """CREATE TABLESPACE %(tblspace)s + DATAFILE '%(datafile)s' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s + """ % self.parameters) + cursor.execute( + """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s + TEMPFILE '%(datafile_tmp)s' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s + """ % self.parameters) + self._create_test_user(cursor) + + + def _drop_tablespace(self, cursor): + cursor.execute( + 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' + % self.parameters) + cursor.execute( + 'DROP TABLESPACE 
%(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' + % self.parameters) + + + parameters = { + 'tblspace': 'test_tblspace', + 'tblspace_temp': 'test_tblspace_temp', + 'datafile': 'test_datafile.dbf', + 'datafile_tmp': 'test_datafile_tmp.dbf', + 'user': 'ponytest', + 'password': 'ponytest', + 'maxsize': '100M', + 'maxsize_tmp': '100M', + } + + def connect_sys(self): + import cx_Oracle + return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA) + + def connect_test(self): + import cx_Oracle + return cx_Oracle.connect('ponytest/ponytest@localhost/ORCL') + + + @cached_property + def db(self): + return Database('oracle', 'ponytest/ponytest@localhost/ORCL') + + def _create_test_user(self, cursor): + cursor.execute( + """CREATE USER %(user)s + IDENTIFIED BY %(password)s + DEFAULT TABLESPACE %(tblspace)s + TEMPORARY TABLESPACE %(tblspace_temp)s + QUOTA UNLIMITED ON %(tblspace)s + """ % self.parameters + ) + cursor.execute( + """GRANT CREATE SESSION, + CREATE TABLE, + CREATE SEQUENCE, + CREATE PROCEDURE, + CREATE TRIGGER + TO %(user)s + """ % self.parameters + ) + + def _destroy_test_user(self, cursor): + cursor.execute(''' + DROP USER %(user)s CASCADE + ''' % self.parameters) + + +@provider(fixture='log', weight=100, enabled=False) +@contextmanager +def logging_context(test): + level = logging.getLogger().level + from pony.orm.core import debug, sql_debug + logging.getLogger().setLevel(logging.INFO) + sql_debug(True) + yield + logging.getLogger().setLevel(level) + sql_debug(debug) + +@provider(fixture='log_all', weight=-100, enabled=False) +def log_all(Test): + return logging_context(Test) + + +# @with_cli_args +# @click.option('--log', 'scope', flag_value='test') +# @click.option('--log-all', 'scope', flag_value='all') +# def use_logging(scope): +# if scope == 'test': +# yield logging_context +# elif scope =='all': +# yield log_all + + +# @provider(enabled=False) +# class DBSessionProvider(object): +# +# fixture= 'db_session' +# +# weight = 30 +# +# def __new__(cls, test): +# return db_session + + +@provider(fixture='rollback', weight=40) +@contextmanager +def do_rollback(test): + try: + yield + finally: + rollback() + + +@provider() +class SeparateProcess(object): + + # TODO read failures from sep process better + + fixture = 'separate_process' + enabled = False + + def __init__(self, Test): + self.Test = Test + + def __call__(self, func): + def wrapper(Test): + rnr = unittest.runner.TextTestRunner() + TestCls = Test if isinstance(Test, type) else type(Test) + def runTest(self): + try: + func(Test) + finally: + rnr.stream = unittest.runner._WritelnDecorator(StringIO()) + name = getattr(func, '__name__', 'runTest') + Case = type(TestCls.__name__, (TestCls,), {name: runTest}) + Case.__module__ = TestCls.__module__ + case = Case(name) + suite = unittest.suite.TestSuite([case]) + def run(): + result = rnr.run(suite) + if not result.wasSuccessful(): + sys.exit(1) + p = Process(target=run, args=()) + p.start() + p.join() + case.assertEqual(p.exitcode, 0) + return wrapper + + @classmethod + def validate_chain(cls, fixtures, klass): + for f in fixtures: + if f.KEY in ('ipdb', 'ipdb_all'): + return False + for f in fixtures: + if f.KEY == 'db' and f.provider_key in ('sqlserver', 'oracle'): + return True + +@provider() +class ClearTables(ContextDecorator): + + fixture = 'clear_tables' + + def __init__(self, test): + self.test = test + + def __enter__(self): + pass + + @db_session + def __exit__(self, *exc_info): + db = self.test.db + for entity in db.entities.values(): 
+ if entity._database_.schema is None: + break + delete(i for i in entity) + + +import signal + +@provider() +class Timeout(object): + + fixture = 'timeout' + + @with_cli_args + @click.option('--timeout', type=int) + def __init__(self, Test, timeout): + self.Test = Test + self.timeout = timeout if timeout else Test.TIMEOUT + + enabled = False + + class Exception(Exception): + pass + + class FailedInSubprocess(Exception): + pass + + def __call__(self, func): + def wrapper(test): + p = Process(target=func, args=(test,)) + p.start() + + def on_expired(): + p.terminate() + + t = threading.Timer(self.timeout, on_expired) + t.start() + p.join() + t.cancel() + if p.exitcode == -signal.SIGTERM: + raise self.Exception + elif p.exitcode: + raise self.FailedInSubprocess + + return wrapper + + @classmethod + @with_cli_args + @click.option('--timeout', type=int) + def validate_chain(cls, fixtures, klass, timeout): + if not getattr(klass, 'TIMEOUT', None) and not timeout: + return False + for f in fixtures: + if f.KEY in ('ipdb', 'ipdb_all'): + return False + for f in fixtures: + if f.KEY == 'db' and f.provider_key in ('sqlserver', 'oracle'): + return True + + +pony_fixtures['test'].extend([ + 'log', + 'clear_tables', +]) + +pony_fixtures['class'].extend([ + 'separate_process', + 'timeout', + 'db', + 'log_all', + 'generate_mapping', +]) diff --git a/pony/orm/tests/model1.py b/pony/orm/tests/model1.py index 8858b3d4c..96a01c208 100644 --- a/pony/orm/tests/model1.py +++ b/pony/orm/tests/model1.py @@ -1,8 +1,9 @@ from __future__ import absolute_import, print_function, division from pony.orm.core import * +from pony.orm.tests import db_params -db = Database('sqlite', ':memory:') +db = Database(**db_params) class Student(db.Entity): _table_ = "Students" @@ -33,7 +34,8 @@ class Mark(db.Entity): PrimaryKey(student, subject) -db.generate_mapping(create_tables=True) +db.generate_mapping(check_tables=False) + @db_session def populate_db(): @@ -56,4 +58,3 @@ def populate_db(): Mark(student=s102, subject=Chemistry, value=5) Mark(student=s103, subject=Physics, value=2) Mark(student=s103, subject=Chemistry, value=4) -populate_db() diff --git a/pony/orm/tests/py36_test_f_strings.py b/pony/orm/tests/py36_test_f_strings.py new file mode 100644 index 000000000..7da921c7f --- /dev/null +++ b/pony/orm/tests/py36_test_f_strings.py @@ -0,0 +1,75 @@ +import unittest +from pony.orm.core import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + + +db = Database() + +class Person(db.Entity): + first_name = Required(str) + last_name = Required(str) + age = Optional(int) + value = Required(float) + + +class TestFString(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Person(id=1, first_name='Alexander', last_name='Tischenko', age=23, value=1.4) + Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2) + Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33) + Person(id=4, first_name='Okita', last_name='Souji', age=15, value=2.1) + Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9) + Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + rollback() + db_session.__enter__() + def tearDown(self): + rollback() + db_session.__exit__() + + def test_1(self): + x = 'Alexander' + y = 'Tischenko' + q = select(p.id for p in Person if p.first_name + ' ' + 
p.last_name == f'{x} {y}') + self.assertEqual(set(q), {1}) + + def test_2(self): + q = select(p.id for p in Person if f'{p.first_name} {p.last_name}' == 'Alexander Tischenko') + self.assertEqual(set(q), {1}) + + def test_3(self): + x = 'Great' + q = select(f'{p.first_name} the {x}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander the Great'}) + + def test_4(self): + q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander 23'}) + + def test_5(self): + q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander 23'}) + + @raises_exception(NotImplementedError, 'You cannot set width and precision markers in query') + def test_6(self): + width = 3 + precision = 4 + q = select(p.id for p in Person if f'{p.value:{width}.{precision}}')[:] + self.assertEqual({2,}, set(q)) + + def test_7(self): + x = 'Tischenko' + q = select(p.first_name + f"{' ' + x}" for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander Tischenko'}) + + def test_8(self): + q = select(p for p in Person if not p.age)[:] diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index fb51bbd58..1b754d782 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -106,9 +106,9 @@ WHERE "s"."group" = 101 >>> avg(s.gpa for s in Student if s.group.dept.number == 44) SELECT AVG("s"."gpa") -FROM "Student" "s", "Group" "group-1" -WHERE "group-1"."dept" = 44 - AND "s"."group" = "group-1"."number" +FROM "Student" "s", "Group" "group" +WHERE "group"."dept" = 44 + AND "s"."group" = "group"."number" >>> select(s for s in Student if s.group.number == 101 and s.dob == max(s.dob for s in Student if s.group.number == 101)) @@ -116,9 +116,9 @@ SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE "s"."group" = 101 AND "s"."dob" = ( - SELECT MAX("s"."dob") - FROM "Student" "s" - WHERE "s"."group" = 101 + SELECT MAX("s-2"."dob") + FROM "Student" "s-2" + WHERE "s-2"."group" = 101 ) >>> select(g for g in Group if avg(s.gpa for s in g.students) > 4.5) @@ -135,10 +135,10 @@ WHERE ( SELECT "g"."number" FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" GROUP BY "g"."number" -HAVING AVG("student-1"."gpa") > 4.5 +HAVING AVG("student"."gpa") > 4.5 >>> select((s.group, min(s.gpa), max(s.gpa)) for s in Student) @@ -146,6 +146,22 @@ SELECT "s"."group", MIN("s"."gpa"), MAX("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" +>>> select((g, min(g.students.gpa), max(g.students.gpa)) for g in Group) + +SELECT "g"."number", MIN("student"."gpa"), MAX("student"."gpa") +FROM "Group" "g" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" +GROUP BY "g"."number" + +>>> select((g, g.students.name, min(g.students.gpa), max(g.students.gpa)) for g in Group) + +SELECT "g"."number", "student"."name", MIN("student"."gpa"), MAX("student"."gpa") +FROM "Group" "g" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" +GROUP BY "g"."number", "student"."name" + >>> count(s for s in Student if s.group.number == 101) SELECT COUNT(*) @@ -154,19 +170,19 @@ WHERE "s"."group" = 101 >>> select((g, count(g.students)) for g in Group if g.dept.number == 44) -SELECT "g"."number", COUNT(DISTINCT "student-1"."id") +SELECT "g"."number", COUNT(DISTINCT "student"."id") FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + 
LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" WHERE "g"."dept" = 44 GROUP BY "g"."number" >>> select((s.group, count(s)) for s in Student if s.group.dept.number == 44) SELECT "s"."group", COUNT(DISTINCT "s"."id") -FROM "Student" "s", "Group" "group-1" -WHERE "group-1"."dept" = 44 - AND "s"."group" = "group-1"."number" +FROM "Student" "s", "Group" "group" +WHERE "group"."dept" = 44 + AND "s"."group" = "group"."number" GROUP BY "s"."group" >>> select((g, count(s for s in g.students if s.gpa <= 3), count(s for s in g.students if s.gpa > 3 and s.gpa <= 4), count(s for s in g.students if s.gpa > 4)) for g in Group) @@ -222,35 +238,35 @@ GROUP BY "item"."order" >>> select((order, sum(order.items.price * order.items.quantity)) for order in Order if order.id == 123) -SELECT "order"."id", coalesce(SUM(("orderitem-1"."price" * "orderitem-1"."quantity")), 0) +SELECT "order"."id", coalesce(SUM(("orderitem"."price" * "orderitem"."quantity")), 0) FROM "Order" "order" - LEFT JOIN "OrderItem" "orderitem-1" - ON "order"."id" = "orderitem-1"."order" + LEFT JOIN "OrderItem" "orderitem" + ON "order"."id" = "orderitem"."order" WHERE "order"."id" = 123 GROUP BY "order"."id" >>> select((item.order, item.order.total_price, sum(item.price * item.quantity)) for item in OrderItem if item.order.total_price < sum(item.price * item.quantity)) -SELECT "item"."order", "order-1"."total_price", coalesce(SUM(("item"."price" * "item"."quantity")), 0) -FROM "OrderItem" "item", "Order" "order-1" -WHERE "item"."order" = "order-1"."id" -GROUP BY "item"."order", "order-1"."total_price" -HAVING "order-1"."total_price" < coalesce(SUM(("item"."price" * "item"."quantity")), 0) +SELECT "item"."order", "order"."total_price", coalesce(SUM(("item"."price" * "item"."quantity")), 0) +FROM "OrderItem" "item", "Order" "order" +WHERE "item"."order" = "order"."id" +GROUP BY "item"."order", "order"."total_price" +HAVING "order"."total_price" < coalesce(SUM(("item"."price" * "item"."quantity")), 0) >>> select(c for c in Customer for p in c.orders.items.product if 'Tablets' in p.categories.name and count(p) > 1) -SELECT DISTINCT "c"."id" -FROM "Customer" "c", "Order" "order-1", "OrderItem" "orderitem-1" +SELECT "c"."id" +FROM "Customer" "c", "Order" "order", "OrderItem" "orderitem" WHERE 'Tablets' IN ( - SELECT "category-1"."name" - FROM "Category_Product" "t-1", "Category" "category-1" - WHERE "orderitem-1"."product" = "t-1"."product" - AND "t-1"."category" = "category-1"."id" + SELECT "category"."name" + FROM "Category_Product" "t-1", "Category" "category" + WHERE "orderitem"."product" = "t-1"."product" + AND "t-1"."category" = "category"."id" ) - AND "c"."id" = "order-1"."customer" - AND "order-1"."id" = "orderitem-1"."order" + AND "c"."id" = "order"."customer" + AND "order"."id" = "orderitem"."order" GROUP BY "c"."id" -HAVING COUNT(DISTINCT "orderitem-1"."product") > 1 +HAVING COUNT(DISTINCT "orderitem"."product") > 1 Schema: pony.orm.examples.university1 @@ -258,9 +274,9 @@ pony.orm.examples.university1 >>> select((s.group, count(s)) for s in Student if s.group.dept.number == 44 and avg(s.gpa) > 4) SELECT "s"."group", COUNT(DISTINCT "s"."id") -FROM "Student" "s", "Group" "group-1" -WHERE "group-1"."dept" = 44 - AND "s"."group" = "group-1"."number" +FROM "Student" "s", "Group" "group" +WHERE "group"."dept" = 44 + AND "s"."group" = "group"."number" GROUP BY "s"."group" HAVING AVG("s"."gpa") > 4 @@ -268,19 +284,19 @@ HAVING AVG("s"."gpa") > 4 SELECT "g"."number" FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON 
"g"."number" = "student-1"."group" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" GROUP BY "g"."number" -HAVING MAX("student-1"."gpa") < 4 +HAVING MAX("student"."gpa") < 4 >>> select(g for g in Group if JOIN(max(g.students.gpa) < 4)) SELECT "g"."number" FROM "Group" "g" LEFT JOIN ( - SELECT "student-1"."group" AS "group", MAX("student-1"."gpa") AS "expr-1" - FROM "Student" "student-1" - GROUP BY "student-1"."group" + SELECT "student"."group" AS "group", MAX("student"."gpa") AS "expr-1" + FROM "Student" "student" + GROUP BY "student"."group" ) "t-1" ON "g"."number" = "t-1"."group" WHERE "t-1"."expr-1" < 4 @@ -500,7 +516,7 @@ GROUP BY "g"."number" >>> select((s, count(c)) for s in Student for c in s.courses) -SELECT DISTINCT "s"."id", COUNT(DISTINCT "c"."ROWID") +SELECT "s"."id", COUNT(DISTINCT "c"."ROWID") FROM "Student" "s", "Course_Student" "t-1", "Course" "c" WHERE "s"."id" = "t-1"."student" AND "t-1"."course_name" = "c"."name" @@ -509,7 +525,7 @@ GROUP BY "s"."id" Oracle: -SELECT DISTINCT "s"."ID", COUNT(DISTINCT "c"."ROWID") +SELECT "s"."ID", COUNT(DISTINCT "c"."ROWID") FROM "STUDENT" "s", "COURSE_STUDENT" "t-1", "COURSE" "c" WHERE "s"."ID" = "t-1"."STUDENT" AND "t-1"."COURSE_NAME" = "c"."NAME" @@ -518,7 +534,7 @@ GROUP BY "s"."ID" PostgreSQL: -SELECT DISTINCT "s"."id", COUNT(DISTINCT case when ("t-1"."course_name", "t-1"."course_semester") IS NULL then null else ("t-1"."course_name", "t-1"."course_semester") end) +SELECT "s"."id", COUNT(DISTINCT case when ("t-1"."course_name", "t-1"."course_semester") IS NULL then null else ("t-1"."course_name", "t-1"."course_semester") end) FROM "student" "s", "course_student" "t-1" WHERE "s"."id" = "t-1"."student" GROUP BY "s"."id" @@ -559,9 +575,9 @@ WHERE "s"."TEL" IS NULL SELECT DISTINCT "s"."name" FROM "Student" "s" WHERE "s"."name" IN ( - SELECT "s"."name" - FROM "Student" "s" - GROUP BY "s"."name" + SELECT "s-2"."name" + FROM "Student" "s-2" + GROUP BY "s-2"."name" HAVING COUNT(*) > 1 ) @@ -675,11 +691,11 @@ SELECT COUNT(*) FROM ( SELECT "g"."number" FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" WHERE "g"."number" > 101 GROUP BY "g"."number" - HAVING COUNT(DISTINCT "student-1"."id") > 0 + HAVING COUNT(DISTINCT "student"."id") > 0 ) "t" >>> count(g for g in Group if count(s for s in g.students) > 0 and g.number > 101) @@ -731,26 +747,26 @@ WHERE "gpa" > 3 DELETE FROM "Student" WHERE "id" IN ( SELECT "s"."id" - FROM "Student" "s", "Group" "group-1" - WHERE "group-1"."dept" = 1 - AND "s"."group" = "group-1"."number" + FROM "Student" "s", "Group" "group" + WHERE "group"."dept" = 1 + AND "s"."group" = "group"."number" ) MySQL: DELETE s FROM `student` `s` - INNER JOIN `group` `group-1` - ON `s`.`group` = `group-1`.`number` -WHERE `group-1`.`dept` = 1 + INNER JOIN `group` `group` + ON `s`.`group` = `group`.`number` +WHERE `group`.`dept` = 1 PostgreSQL: DELETE FROM "student" WHERE "id" IN ( SELECT "s"."id" - FROM "student" "s", "group" "group-1" - WHERE "group-1"."dept" = 1 - AND "s"."group" = "group-1"."number" + FROM "student" "s", "group" "group" + WHERE "group"."dept" = 1 + AND "s"."group" = "group"."number" ) Oracle: @@ -758,9 +774,9 @@ Oracle: DELETE FROM "STUDENT" WHERE "ID" IN ( SELECT "s"."ID" - FROM "STUDENT" "s", "GROUP" "group-1" - WHERE "group-1"."DEPT" = 1 - AND "s"."GROUP" = "group-1"."NUMBER" + FROM "STUDENT" "s", "GROUP" "group" + WHERE "group"."DEPT" = 1 + AND "s"."GROUP" = "group"."NUMBER" ) >>> 
select(c for c in Course if c.dept.name.startswith('D')).delete(bulk=True) @@ -768,26 +784,26 @@ WHERE "ID" IN ( DELETE FROM "Course" WHERE "ROWID" IN ( SELECT "c"."ROWID" - FROM "Course" "c", "Department" "department-1" - WHERE "department-1"."name" LIKE 'D%' - AND "c"."dept" = "department-1"."number" + FROM "Course" "c", "Department" "department" + WHERE "department"."name" LIKE 'D%' + AND "c"."dept" = "department"."number" ) MySQL: DELETE c FROM `course` `c` - INNER JOIN `department` `department-1` - ON `c`.`dept` = `department-1`.`number` -WHERE `department-1`.`name` LIKE 'D%%' + INNER JOIN `department` `department` + ON `c`.`dept` = `department`.`number` +WHERE `department`.`name` LIKE 'D%%' PostgreSQL: DELETE FROM "course" WHERE ("name", "semester") IN ( SELECT "c"."name", "c"."semester" - FROM "course" "c", "department" "department-1" - WHERE "department-1"."name" LIKE 'D%%' - AND "c"."dept" = "department-1"."number" + FROM "course" "c", "department" "department" + WHERE "department"."name" LIKE 'D%%' + AND "c"."dept" = "department"."number" ) Oracle: @@ -795,9 +811,9 @@ Oracle: DELETE FROM "COURSE" WHERE "ROWID" IN ( SELECT "c"."ROWID" - FROM "COURSE" "c", "DEPARTMENT" "department-1" - WHERE "department-1"."NAME" LIKE 'D%' - AND "c"."DEPT" = "department-1"."NUMBER" + FROM "COURSE" "c", "DEPARTMENT" "department" + WHERE "department"."NAME" LIKE 'D%' + AND "c"."DEPT" = "department"."NUMBER" ) >>> select(s for s in Student if s.gpa > 3 and s not in (s2 for s2 in Student if s2.group.dept.name.startswith('A'))).delete(bulk=True) @@ -806,10 +822,10 @@ DELETE FROM "Student" WHERE "gpa" > 3 AND "id" NOT IN ( SELECT "s2"."id" - FROM "Student" "s2", "Group" "group-1", "Department" "department-1" - WHERE "department-1"."name" LIKE 'A%' - AND "s2"."group" = "group-1"."number" - AND "group-1"."dept" = "department-1"."number" + FROM "Student" "s2", "Group" "group", "Department" "department" + WHERE "department"."name" LIKE 'A%' + AND "s2"."group" = "group"."number" + AND "group"."dept" = "department"."number" ) # MySQL does not support such queries @@ -820,10 +836,10 @@ DELETE FROM "student" WHERE "gpa" > 3 AND "id" NOT IN ( SELECT "s2"."id" - FROM "student" "s2", "group" "group-1", "department" "department-1" - WHERE "department-1"."name" LIKE 'A%%' - AND "s2"."group" = "group-1"."number" - AND "group-1"."dept" = "department-1"."number" + FROM "student" "s2", "group" "group", "department" "department" + WHERE "department"."name" LIKE 'A%%' + AND "s2"."group" = "group"."number" + AND "group"."dept" = "department"."number" ) Oracle: @@ -832,10 +848,10 @@ DELETE FROM "STUDENT" WHERE "GPA" > 3 AND "ID" NOT IN ( SELECT "s2"."ID" - FROM "STUDENT" "s2", "GROUP" "group-1", "DEPARTMENT" "department-1" - WHERE "department-1"."NAME" LIKE 'A%' - AND "s2"."GROUP" = "group-1"."NUMBER" - AND "group-1"."DEPT" = "department-1"."NUMBER" + FROM "STUDENT" "s2", "GROUP" "group", "DEPARTMENT" "department" + WHERE "department"."NAME" LIKE 'A%' + AND "s2"."GROUP" = "group"."NUMBER" + AND "group"."DEPT" = "department"."NUMBER" ) >>> select(s for s in Student if exists(s2 for s2 in Student if s.gpa > s2.gpa)).delete(bulk=True) @@ -877,6 +893,30 @@ WHERE "ID" IN ( ) ) +>>> select(s for s in Student if count(g for g in s.group.dept.groups) > 2).delete(bulk=True) + +DELETE FROM "Student" +WHERE "id" IN ( + SELECT "s"."id" + FROM "Student" "s" + WHERE ( + SELECT COUNT(DISTINCT "g"."number") + FROM "Group" "group", "Group" "g" + WHERE "s"."group" = "group"."number" + AND "group"."dept" = "g"."dept" + ) > 2 + ) + +>>> 
Student.select(lambda s: count(s.group.students) == 2).delete(bulk=True) + +DELETE FROM "Student" +WHERE "id" IN ( + SELECT "s"."id" + FROM "Student" "s" + LEFT JOIN "Student" "student" + ON "s"."group" = "student"."group" + ) + # Test UPPER/LOWER functions: >>> select(s.name.upper() for s in Student) @@ -888,3 +928,258 @@ PostgreSQL: SELECT DISTINCT upper("s"."name") FROM "student" "s" + +# Test modulo division operator + +>>> select(s for s in Student if s.id % 2 == 0) + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE ("s"."id" % 2) = 0 + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."id" %% 2) = 0 + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE (`s`.`id` %% 2) = 0 + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE MOD("s"."ID", 2) = 0 + +# Test group_concat: + +>>> select((g, group_concat(s.name, '+')) for g in Group for s in g.students) + +SELECT "g"."number", GROUP_CONCAT("s"."name", '+') +FROM "Group" "g", "Student" "s" +WHERE "g"."number" = "s"."group" +GROUP BY "g"."number" + +PostgreSQL: + +SELECT "g"."number", string_agg("s"."name"::text, '+') +FROM "group" "g", "student" "s" +WHERE "g"."number" = "s"."group" +GROUP BY "g"."number" + +MySQL: + +SELECT `g`.`number`, GROUP_CONCAT(`s`.`name` SEPARATOR '+') +FROM `group` `g`, `student` `s` +WHERE `g`.`number` = `s`.`group` +GROUP BY `g`.`number` + +Oracle: + +SELECT "g"."NUMBER", LISTAGG("s"."NAME", '+') WITHIN GROUP(ORDER BY 1) +FROM "GROUP" "g", "STUDENT" "s" +WHERE "g"."NUMBER" = "s"."GROUP" +GROUP BY "g"."NUMBER" + +# Test offset without limit + +>>> select(s for s in Student)[3:] + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +LIMIT -1 OFFSET 3 + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +LIMIT null OFFSET 3 + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +LIMIT 18446744073709551615 OFFSET 3 + +Oracle: + +SELECT t.* FROM ( + SELECT t.*, ROWNUM "row-num" FROM ( + SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" + FROM "STUDENT" "s" + ) t +) t WHERE "row-num" > 3 + +# Test row comparison: + +>>> select((s1.id, s2.id) for s1 in Student for s2 in Student if (s1.name, s1.gpa, s1.tel) < (s2.name, s2.gpa, s2.tel)) + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "Student" "s1", "Student" "s2" +WHERE ("s1"."name" < "s2"."name" OR "s1"."name" = "s2"."name" AND "s1"."gpa" < "s2"."gpa" OR "s1"."name" = "s2"."name" AND "s1"."gpa" = "s2"."gpa" AND "s1"."tel" < "s2"."tel") + +PostgreSQL: + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "student" "s1", "student" "s2" +WHERE ("s1"."name", "s1"."gpa", "s1"."tel") < ("s2"."name", "s2"."gpa", "s2"."tel") + +MySQL: + +SELECT DISTINCT `s1`.`id`, `s2`.`id` +FROM `student` `s1`, `student` `s2` +WHERE (`s1`.`name`, `s1`.`gpa`, `s1`.`tel`) < (`s2`.`name`, `s2`.`gpa`, `s2`.`tel`) + +Oracle: + +SELECT DISTINCT "s1"."ID", "s2"."ID" +FROM "STUDENT" "s1", "STUDENT" "s2" +WHERE ("s1"."NAME", "s1"."GPA", "s1"."TEL") < ("s2"."NAME", "s2"."GPA", "s2"."TEL") + +>>> select((s1.id, s2.id) for s1 in Student for s2 in Student if (s1.name, s1.gpa, s1.tel) == (s2.name, s2.gpa, s2.tel)) + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "Student" "s1", "Student" "s2" +WHERE 
"s1"."name" = "s2"."name" + AND "s1"."gpa" = "s2"."gpa" + AND "s1"."tel" = "s2"."tel" + +PostgreSQL: + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "student" "s1", "student" "s2" +WHERE "s1"."name" = "s2"."name" + AND "s1"."gpa" = "s2"."gpa" + AND "s1"."tel" = "s2"."tel" + +MySQL: + +SELECT DISTINCT `s1`.`id`, `s2`.`id` +FROM `student` `s1`, `student` `s2` +WHERE `s1`.`name` = `s2`.`name` + AND `s1`.`gpa` = `s2`.`gpa` + AND `s1`.`tel` = `s2`.`tel` + +Oracle: + +SELECT DISTINCT "s1"."ID", "s2"."ID" +FROM "STUDENT" "s1", "STUDENT" "s2" +WHERE "s1"."NAME" = "s2"."NAME" + AND "s1"."GPA" = "s2"."GPA" + AND "s1"."TEL" = "s2"."TEL" + + +# Test date operations: + + +>>> select(s for s in Student if s.dob + timedelta(days=100) < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE date("s"."dob", '+100 days') < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" + INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" + INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE ADDDATE(`s`.`dob`, INTERVAL '2400:0:0' HOUR_SECOND) < DATE '2010-01-01' + + +>>> td = timedelta(days=100) +>>> select(s for s in Student if s.dob + td < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE datetime(julianday("s"."dob") + ?) < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" + %(p1)s) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" + :p1) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE ADDDATE(`s`.`dob`, %s) < DATE '2010-01-01' + + +>>> select(s for s in Student if s.dob - timedelta(days=100) < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE date("s"."dob", '-100 days') < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" - INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" - INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE SUBDATE(`s`.`dob`, INTERVAL '2400:0:0' HOUR_SECOND) < DATE '2010-01-01' + +>>> td = timedelta(days=100) +>>> select(s for s in Student if s.dob - td < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE datetime(julianday("s"."dob") - ?) 
< '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" - %(p1)s) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" - :p1) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE SUBDATE(`s`.`dob`, %s) < DATE '2010-01-01' diff --git a/pony/orm/tests/sql_tests.py b/pony/orm/tests/sql_tests.py index 668293ece..ec68a82d5 100644 --- a/pony/orm/tests/sql_tests.py +++ b/pony/orm/tests/sql_tests.py @@ -2,6 +2,7 @@ from pony.py23compat import PY2 import re, os, os.path, sys +from datetime import datetime, timedelta from pony import orm from pony.orm import core @@ -55,6 +56,7 @@ def do_test(provider_name, raw_server_version): return module = sys.modules[module_name] globals = vars(module).copy() + globals.update(datetime=datetime, timedelta=timedelta) with orm.db_session: for statement in statements[:-1]: code = compile(statement, '', 'exec') diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py new file mode 100644 index 000000000..ef647e055 --- /dev/null +++ b/pony/orm/tests/test_array.py @@ -0,0 +1,268 @@ +from pony.py23compat import PY2 + +import unittest +from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, setup_database, teardown_database + +from pony.orm import * + +db = Database() + + +class Foo(db.Entity): + id = PrimaryKey(int) + a = Required(int) + b = Required(int) + c = Required(int) + array1 = Required(IntArray, index=True) + array2 = Required(FloatArray) + array3 = Required(StrArray) + array4 = Optional(IntArray) + array5 = Optional(IntArray, nullable=True) + + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + if db_params['provider'] not in ('sqlite', 'postgres'): + raise unittest.SkipTest('Arrays are only available for SQLite and PostgreSQL') + + setup_database(db) + with db_session: + Foo(id=1, a=1, b=3, c=-2, array1=[10, 20, 30, 40, 50], array2=[1.1, 2.2, 3.3, 4.4, 5.5], + array3=['foo', 'bar']) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + @db_session + def test_1(self): + foo = select(f for f in Foo if 10 in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_2(self): + foo = select(f for f in Foo if [10, 20, 50] in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_2a(self): + foo = select(f for f in Foo if [] in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_3(self): + x = [10, 20, 50] + foo = select(f for f in Foo if x in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_4(self): + foo = select(f for f in Foo if 1.1 in f.array2)[:] + self.assertEqual([Foo[1]], foo) + + err_msg = "Cannot store 'int' item in array of " + ("'unicode'" if PY2 else "'str'") + + @raises_exception(TypeError, err_msg) + @db_session + def test_5(self): + foo = Foo.select().first() + foo.array3.append(123) + + @raises_exception(TypeError, err_msg) + @db_session + def test_6(self): + foo = Foo.select().first() + foo.array3[0] = 123 + + @raises_exception(TypeError, err_msg) + @db_session + def test_7(self): + foo = Foo.select().first() + foo.array3.extend(['str', 123, 'str']) + + @db_session + def test_8(self): + foo = Foo.select().first() + foo.array3.extend(['str1', 'str2']) + + @db_session + def test_9(self): + foos = select(f.array2[0] for f in Foo)[:] + 
self.assertEqual([1.1], foos) + + @db_session + def test_10(self): + foos = select(f.array1[1:-1] for f in Foo)[:] + self.assertEqual([20, 30, 40], foos[0]) + + @db_session + def test_11(self): + foo = Foo.select().first() + foo.array4.append(1) + self.assertEqual([1], foo.array4) + + @raises_exception(AttributeError, "'NoneType' object has no attribute 'append'") + @db_session + def test_12(self): + foo = Foo.select().first() + foo.array5.append(1) + + @db_session + def test_13(self): + x = [10, 20, 30, 40, 50] + ids = select(f.id for f in Foo if x == f.array1)[:] + self.assertEqual(ids, [1]) + + @db_session + def test_14(self): + val = select(f.array1[0] for f in Foo).first() + self.assertEqual(val, 10) + + @db_session + def test_15(self): + val = select(f.array1[2] for f in Foo).first() + self.assertEqual(val, 30) + + @db_session + def test_16(self): + val = select(f.array1[-1] for f in Foo).first() + self.assertEqual(val, 50) + + @db_session + def test_17(self): + val = select(f.array1[-2] for f in Foo).first() + self.assertEqual(val, 40) + + @db_session + def test_18(self): + x = 2 + val = select(f.array1[x] for f in Foo).first() + self.assertEqual(val, 30) + + @db_session + def test_19(self): + val = select(f.array1[f.a] for f in Foo).first() + self.assertEqual(val, 20) + + @db_session + def test_20(self): + val = select(f.array1[f.c] for f in Foo).first() + self.assertEqual(val, 40) + + @db_session + def test_21(self): + array = select(f.array1[2:4] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_22(self): + array = select(f.array1[1:-2] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_23(self): + array = select(f.array1[10:-10] for f in Foo).first() + self.assertEqual(array, []) + + @db_session + def test_24(self): + x = 2 + array = select(f.array1[x:4] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_25(self): + y = 4 + array = select(f.array1[2:y] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_26(self): + x, y = 2, 4 + array = select(f.array1[x:y] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_27(self): + x, y = 1, -2 + array = select(f.array1[x:y] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_28(self): + x = 1 + array = select(f.array1[x:f.b] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_29(self): + array = select(f.array1[f.a:f.c] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_30(self): + array = select(f.array1[:3] for f in Foo).first() + self.assertEqual(array, [10, 20, 30]) + + @db_session + def test_31(self): + array = select(f.array1[2:] for f in Foo).first() + self.assertEqual(array, [30, 40, 50]) + + @db_session + def test_32(self): + array = select(f.array1[:f.b] for f in Foo).first() + self.assertEqual(array, [10, 20, 30]) + + @db_session + def test_33(self): + array = select(f.array1[:f.c] for f in Foo).first() + self.assertEqual(array, [10, 20, 30]) + + @db_session + def test_34(self): + array = select(f.array1[f.c:] for f in Foo).first() + self.assertEqual(array, [40, 50]) + + @db_session + def test_35(self): + foo = Foo.select().first() + self.assertTrue(10 in foo.array1) + self.assertTrue(1000 not in foo.array1) + self.assertTrue([10, 20] in foo.array1) + self.assertTrue([20, 10] in foo.array1) + self.assertTrue([10, 1000] not in foo.array1) + self.assertTrue([] in 
foo.array1) + self.assertTrue('bar' in foo.array3) + self.assertTrue('baz' not in foo.array3) + self.assertTrue(['foo', 'bar'] in foo.array3) + self.assertTrue(['bar', 'foo'] in foo.array3) + self.assertTrue(['baz', 'bar'] not in foo.array3) + self.assertTrue([] in foo.array3) + + @db_session + def test_36(self): + items = [] + result = select(foo for foo in Foo if foo in items)[:] + self.assertEqual(result, []) + + @db_session + def test_37(self): + f1 = Foo[1] + items = [f1] + result = select(foo for foo in Foo if foo in items)[:] + self.assertEqual(result, [f1]) + + @db_session + def test_38(self): + items = [] + result = select(foo for foo in Foo if foo.id in items)[:] + self.assertEqual(result, []) + + @db_session + def test_39(self): + items = [1] + result = select(foo.id for foo in Foo if foo.id in items)[:] + self.assertEqual(result, [1]) diff --git a/pony/orm/tests/test_attribute_options.py b/pony/orm/tests/test_attribute_options.py new file mode 100644 index 000000000..6ba55ca46 --- /dev/null +++ b/pony/orm/tests/test_attribute_options.py @@ -0,0 +1,112 @@ +import unittest +from decimal import Decimal +from datetime import datetime, time +from random import randint + +from pony import orm +from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database +from pony.orm.tests.testutils import raises_exception + +db = Database() + +class Person(db.Entity): + id = PrimaryKey(int) + name = orm.Required(str, 40) + lastName = orm.Required(str, max_len=40, unique=True) + age = orm.Optional(int, max=60, min=10) + nickName = orm.Optional(str, autostrip=False) + middleName = orm.Optional(str, nullable=True) + rate = orm.Optional(Decimal, precision=11) + salaryRate = orm.Optional(Decimal, precision=13, scale=8) + timeStmp = orm.Optional(datetime, precision=6) + gpa = orm.Optional(float, py_check=lambda val: val >= 0 and val <= 5) + vehicle = orm.Optional(str, column='car') + +class TestAttributeOptions(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with orm.db_session: + p1 = Person(id=1, name='Andrew', lastName='Bodroue', age=40, rate=0.980000000001, salaryRate=0.98000001) + p2 = Person(id=2, name='Vladimir', lastName='Andrew ', nickName='vlad ') + p3 = Person(id=3, name='Nick', lastName='Craig', middleName=None, timeStmp='2010-12-10 14:12:09.019473', + vehicle='dodge') + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_optionalStringEmpty(self): + queryResult = select(p.id for p in Person if p.nickName==None).first() + self.assertIsNone(queryResult) + + def test_optionalStringNone(self): + queryResult = select(p.id for p in Person if p.middleName==None).first() + self.assertIsNotNone(queryResult) + + def test_stringAutoStrip(self): + self.assertEqual(Person[2].lastName, 'Andrew') + + def test_stringAutoStripFalse(self): + self.assertEqual(Person[2].nickName, 'vlad ') + + def test_intNone(self): + queryResult = select(p.id for p in Person if p.age==None).first() + self.assertIsNotNone(queryResult) + + def test_columnName(self): + self.assertEqual(getattr(Person.vehicle, 'column'), 'car') + + def test_decimalPrecisionTwo(self): + queryResult = select(p.rate for p in Person if p.age==40).first() + self.assertAlmostEqual(float(queryResult), 0.98, 12) + + def test_decimalPrecisionEight(self): + queryResult = select(p.salaryRate for p in Person if p.age==40).first() + 
self.assertAlmostEqual(float(queryResult), 0.98000001, 8) + + def test_fractionalSeconds(self): + queryResult = select(p.timeStmp for p in Person if p.name=='Nick').first() + self.assertEqual(queryResult.microsecond, 19473) + + def test_intMax(self): + p4 = Person(id=4, name='Denis', lastName='Blanc', age=60) + + def test_intMin(self): + p4 = Person(id=4, name='Denis', lastName='Blanc', age=10) + + @raises_exception(ValueError, "Value 61 of attr Person.age is greater than the maximum allowed value 60") + def test_intMaxException(self): + p4 = Person(id=4, name='Denis', lastName='Blanc', age=61) + + @raises_exception(ValueError, "Value 9 of attr Person.age is less than the minimum allowed value 10") + def test_intMinException(self): + p4 = Person(id=4, name='Denis', lastName='Blanc', age=9) + + def test_py_check(self): + p4 = Person(id=4, name='Denis', lastName='Blanc', gpa=5) + p5 = Person(id=5, name='Mario', lastName='Gon', gpa=1) + flush() + + @raises_exception(ValueError, "Check for attribute Person.gpa failed. Value: 6.0") + def test_py_checkMoreException(self): + p6 = Person(id=6, name='Daniel', lastName='Craig', gpa=6) + + @raises_exception(ValueError, "Check for attribute Person.gpa failed. Value: -1.0") + def test_py_checkLessException(self): + p6 = Person(id=6, name='Daniel', lastName='Craig', gpa=-1) + + @raises_exception(TransactionIntegrityError, + 'Object Person[...] cannot be stored in the database. IntegrityError: ...') + def test_unique(self): + p6 = Person(id=6, name='Boris', lastName='Bodroue') + flush() diff --git a/pony/orm/tests/test_autostrip.py b/pony/orm/tests/test_autostrip.py index 18865536b..b4656256d 100644 --- a/pony/orm/tests/test_autostrip.py +++ b/pony/orm/tests/test_autostrip.py @@ -2,16 +2,23 @@ from pony.orm import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Person(db.Entity): name = Required(str) tel = Optional(str) -db.generate_mapping(create_tables=True) class TestAutostrip(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) @db_session def test_1(self): diff --git a/pony/orm/tests/test_buffer.py b/pony/orm/tests/test_buffer.py new file mode 100644 index 000000000..002317f9f --- /dev/null +++ b/pony/orm/tests/test_buffer.py @@ -0,0 +1,49 @@ +import unittest + +from pony import orm +from pony.orm.tests import setup_database, teardown_database + +db = orm.Database() + + +class Foo(db.Entity): + id = orm.PrimaryKey(int) + b = orm.Optional(orm.buffer) + + +class Bar(db.Entity): + b = orm.PrimaryKey(orm.buffer) + + +class Baz(db.Entity): + id = orm.PrimaryKey(int) + b = orm.Optional(orm.buffer, unique=True) + + +buf = orm.buffer(b'123') + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with orm.db_session: + Foo(id=1, b=buf) + Bar(b=buf) + Baz(id=1, b=buf) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): # Bug #355 + with orm.db_session: + Bar[buf] + + def test_2(self): # Regression after #355 fix + with orm.db_session: + result = orm.select(bar.b for bar in Foo)[:] + self.assertEqual(result, [buf]) + + def test_3(self): # Bug #390 + with orm.db_session: + Baz.get(b=buf) diff --git a/pony/orm/tests/test_bug_170.py b/pony/orm/tests/test_bug_170.py new file mode 100644 index 000000000..8b38deab1 --- /dev/null +++ b/pony/orm/tests/test_bug_170.py @@ 
-0,0 +1,31 @@ +import unittest + +from pony import orm +from pony.orm.tests import setup_database, teardown_database + +db = orm.Database() + + +class Person(db.Entity): + id = orm.PrimaryKey(int, auto=True) + name = orm.Required(str) + orm.composite_key(id, name) + + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): + table = db.schema.tables[Person._table_] + pk_column = table.column_dict[Person.id.column] + self.assertTrue(pk_column.is_pk) + + with orm.db_session: + p1 = Person(name='John') + p2 = Person(name='Mike') diff --git a/pony/orm/tests/test_bug_182.py b/pony/orm/tests/test_bug_182.py index f96bdbbe4..7581f9c67 100644 --- a/pony/orm/tests/test_bug_182.py +++ b/pony/orm/tests/test_bug_182.py @@ -1,44 +1,46 @@ - import unittest from pony.orm import * from pony import orm +from pony.orm.tests import setup_database, teardown_database -import os +db = Database() -class Test(unittest.TestCase): +class User(db.Entity): + name = Required(str) + servers = Set("Server") - def setUp(self): - db = self.db = Database('sqlite', ':memory:') - class User(db.Entity): - name = Required(str) - servers = Set("Server") +class Worker(User): + pass - class Worker(db.User): - pass - class Admin(db.Worker): - pass +class Admin(Worker): + pass - # And M:1 relationship with another entity - class Server(db.Entity): - name = Required(str) - user = Optional(User) +# And M:1 relationship with another entity +class Server(db.Entity): + name = Required(str) + user = Optional(User) - db.generate_mapping(check_tables=True, create_tables=True) +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) with orm.db_session: Server(name='s1.example.com', user=User(name="Alex")) Server(name='s2.example.com', user=Worker(name="John")) Server(name='free.example.com', user=None) + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test(self): - qu = left_join( - (s.name, s.user.name) for s in self.db.Server - )[:] + qu = left_join((s.name, s.user.name) for s in db.Server)[:] for server, user in qu: if user is None: break diff --git a/pony/orm/tests/test_bug_331.py b/pony/orm/tests/test_bug_331.py new file mode 100644 index 000000000..3a6af0469 --- /dev/null +++ b/pony/orm/tests/test_bug_331.py @@ -0,0 +1,53 @@ +import unittest + +from pony.orm.tests import setup_database, teardown_database +from pony.orm import * + +db = Database() + + +class Person(db.Entity): + name = Required(str) + group = Optional(lambda: Group) + + +class Group(db.Entity): + title = PrimaryKey(str) + persons = Set(Person) + + def __len__(self): + return len(self.persons) + + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): + with db_session: + p1 = Person(name="Alex") + p2 = Person(name="Brad") + p3 = Person(name="Chad") + p4 = Person(name="Dylan") + p5 = Person(name="Ethan") + + g1 = Group(title="Foxes") + g2 = Group(title="Gorillas") + + g1.persons.add(p1) + g1.persons.add(p2) + g1.persons.add(p3) + g2.persons.add(p4) + g2.persons.add(p5) + commit() + + foxes = Group['Foxes'] + gorillas = Group['Gorillas'] + + self.assertEqual(len(foxes), 3) + self.assertEqual(len(gorillas), 2) diff --git a/pony/orm/tests/test_bug_386.py b/pony/orm/tests/test_bug_386.py new file mode 100644 index 000000000..31f62574b --- /dev/null +++ 
b/pony/orm/tests/test_bug_386.py @@ -0,0 +1,26 @@ +import unittest + +from pony import orm +from pony.orm.tests import setup_database, teardown_database + +db = orm.Database() + + +class Person(db.Entity): + name = orm.Required(str) + + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): + with orm.db_session: + a = Person(name='John') + a.delete() + Person.exists(name='Mike') diff --git a/pony/orm/tests/test_cascade.py b/pony/orm/tests/test_cascade.py new file mode 100644 index 000000000..d2f52c21f --- /dev/null +++ b/pony/orm/tests/test_cascade.py @@ -0,0 +1,110 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + + +class TestCascade(unittest.TestCase): + providers = ['sqlite'] # Implement for other providers + + def tearDown(self): + if self.db.schema is not None: + teardown_database(self.db) + + def assert_on_delete(self, table_name, value): + db = self.db + if not (db.provider.dialect == 'SQLite' and pony.__version__ < '0.9'): + table_name = table_name.lower() + table = db.schema.tables[table_name] + fkeys = table.foreign_keys + self.assertEqual(1, len(fkeys)) + if pony.__version__ >= '0.9': + self.assertEqual(fkeys[0].on_delete, value) + elif db.provider.dialect == 'SQLite': + self.assertIn('ON DELETE %s' % value, table.get_create_command()) + else: + self.assertIn('ON DELETE %s' % value, list(fkeys.values())[0].get_create_command()) + + + def test_1(self): + db = self.db = Database() + + class Person(self.db.Entity): + name = Required(str) + group = Required('Group') + + class Group(self.db.Entity): + persons = Set(Person) + + setup_database(db) + self.assert_on_delete('Person', 'CASCADE') + + def test_2(self): + db = self.db = Database() + + class Person(self.db.Entity): + name = Required(str) + group = Required('Group') + + class Group(self.db.Entity): + persons = Set(Person, cascade_delete=True) + + setup_database(db) + self.assert_on_delete('Person', 'CASCADE') + + def test_3(self): + db = self.db = Database() + + class Person(self.db.Entity): + name = Required(str) + group = Optional('Group') + + class Group(self.db.Entity): + persons = Set(Person, cascade_delete=True) + + setup_database(db) + self.assert_on_delete('Person', 'CASCADE') + + @raises_exception(TypeError, "'cascade_delete' option cannot be set for attribute Group.persons, because reverse attribute Person.group is collection") + def test_4(self): + db = self.db = Database() + + class Person(self.db.Entity): + name = Required(str) + group = Set('Group') + + class Group(self.db.Entity): + persons = Set(Person, cascade_delete=True) + + setup_database(db) + + @raises_exception(TypeError, "'cascade_delete' option cannot be set for both sides of relationship (Person.group and Group.persons) simultaneously") + def test_5(self): + db = self.db = Database() + + class Person(self.db.Entity): + name = Required(str) + group = Set('Group', cascade_delete=True) + + class Group(self.db.Entity): + persons = Required(Person, cascade_delete=True) + + setup_database(db) + + def test_6(self): + db = self.db = Database() + + class Person(self.db.Entity): + name = Required(str) + group = Set('Group') + + class Group(self.db.Entity): + persons = Optional(Person) + + setup_database(db) + self.assert_on_delete('Group', 'SET NULL') + + +if __name__ == '__main__': + unittest.main() diff --git 
a/pony/orm/tests/test_cascade_delete.py b/pony/orm/tests/test_cascade_delete.py new file mode 100644 index 000000000..b2d38f9dc --- /dev/null +++ b/pony/orm/tests/test_cascade_delete.py @@ -0,0 +1,72 @@ +import unittest + +from pony.orm import * +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() + +class X(db.Entity): + id = PrimaryKey(int) + parent = Optional('X', reverse='children') + children = Set('X', reverse='parent', cascade_delete=True) + + +class Y(db.Entity): + parent = Optional('Y', reverse='children') + children = Set('Y', reverse='parent', cascade_delete=True, lazy=True) + + +class TestCascade(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + x1 = X(id=1) + x2 = X(id=2, parent=x1) + x3 = X(id=3, parent=x1) + x4 = X(id=4, parent=x3) + x5 = X(id=5, parent=x3) + x6 = X(id=6, parent=x5) + x7 = X(id=7, parent=x3) + x8 = X(id=8, parent=x7) + x9 = X(id=9, parent=x7) + x10 = X(id=10) + x11 = X(id=11, parent=x10) + x12 = X(id=12, parent=x10) + + y1 = Y(id=1) + y2 = Y(id=2, parent=y1) + y3 = Y(id=3, parent=y1) + y4 = Y(id=4, parent=y3) + y5 = Y(id=5, parent=y3) + y6 = Y(id=6, parent=y5) + y7 = Y(id=7, parent=y3) + y8 = Y(id=8, parent=y7) + y9 = Y(id=9, parent=y7) + y10 = Y(id=10) + y11 = Y(id=11, parent=y10) + y12 = Y(id=12, parent=y10) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_1(self): + db.merge_local_stats() + X[1].delete() + stats = db.local_stats[None] + self.assertEqual(5, stats.db_count) + + def test_2(self): + db.merge_local_stats() + Y[1].delete() + stats = db.local_stats[None] + self.assertEqual(10, stats.db_count) diff --git a/pony/orm/tests/test_collections.py b/pony/orm/tests/test_collections.py index 57673edec..a6a54e1bb 100644 --- a/pony/orm/tests/test_collections.py +++ b/pony/orm/tests/test_collections.py @@ -1,11 +1,22 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import PY2 import unittest from pony.orm.tests.testutils import raises_exception from pony.orm.tests.model1 import * +from pony.orm.tests import setup_database, teardown_database + class TestCollections(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + populate_db() + + @classmethod + def tearDownClass(cls): + db.drop_all_tables(with_all_data=True) @db_session def test_setwrapper_len(self): @@ -23,6 +34,50 @@ def test_setwrapper_nonzero(self): def test_get_by_collection_error(self): Group.get(students=[]) + @db_session + def test_collection_create_one2many_1(self): + g = Group['3132'] + g.students.create(record=106, name='Mike', scholarship=200) + flush() + self.assertEqual(len(g.students), 3) + rollback() + + @raises_exception(TypeError, "When using Group.students.create(), " + "'group' attribute should not be passed explicitly") + @db_session + def test_collection_create_one2many_2(self): + g = Group['3132'] + g.students.create(record=106, name='Mike', scholarship=200, group=g) + + @raises_exception(TransactionIntegrityError, "Object Student[105] cannot be stored in the database...") + @db_session + def test_collection_create_one2many_3(self): + g = Group['3132'] + g.students.create(record=105, name='Mike', scholarship=200) + + @db_session + def test_collection_create_many2many_1(self): + g = Group['3132'] + g.subjects.create(name='Biology') + flush() + self.assertEqual(len(g.subjects), 3) 
+ rollback() + + @raises_exception(TypeError, "When using Group.subjects.create(), " + "'groups' attribute should not be passed explicitly") + @db_session + def test_collection_create_many2many_2(self): + g = Group['3132'] + g.subjects.create(name='Biology', groups=[g]) + + @raises_exception(TransactionIntegrityError, + "Object Subject[u'Math'] cannot be stored in the database..." if PY2 else + "Object Subject['Math'] cannot be stored in the database...") + @db_session + def test_collection_create_many2many_3(self): + g = Group['3132'] + g.subjects.create(name='Math') + # replace collection items when the old ones are not fully loaded ##>>> from pony.examples.orm.students01.model import * ##>>> s1 = Student[101] diff --git a/pony/orm/tests/test_core_find_in_cache.py b/pony/orm/tests/test_core_find_in_cache.py index d0ec1b29d..d803809b8 100644 --- a/pony/orm/tests/test_core_find_in_cache.py +++ b/pony/orm/tests/test_core_find_in_cache.py @@ -3,8 +3,9 @@ import unittest from pony.orm.tests.testutils import raises_exception from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class AbstractUser(db.Entity): username = PrimaryKey(unicode) @@ -32,31 +33,40 @@ class Diagram(db.Entity): name = Required(unicode) owner = Required(User) -db.generate_mapping(create_tables=True) - -with db_session: - u1 = User(username='user1') - u2 = SubUser1(username='subuser1', attr1='some attr') - u3 = SubUser2(username='subuser2', attr2='some attr') - o1 = Organization(username='org1') - o2 = SubOrg1(username='suborg1', attr3='some attr') - o3 = SubOrg2(username='suborg2', attr4='some attr') - au = AbstractUser(username='abstractUser') - Diagram(name='diagram1', owner=u1) - Diagram(name='diagram2', owner=u2) - Diagram(name='diagram3', owner=u3) def is_seed(entity, pk): cache = entity._database_._get_cache() return pk in [ obj._pk_ for obj in cache.seeds[entity._pk_attrs_] ] + class TestFindInCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + u1 = User(username='user1') + u2 = SubUser1(username='subuser1', attr1='some attr') + u3 = SubUser2(username='subuser2', attr2='some attr') + o1 = Organization(username='org1') + o2 = SubOrg1(username='suborg1', attr3='some attr') + o3 = SubOrg2(username='suborg2', attr4='some attr') + au = AbstractUser(username='abstractUser') + Diagram(name='diagram1', owner=u1) + Diagram(name='diagram2', owner=u2) + Diagram(name='diagram3', owner=u3) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() + def tearDown(self): rollback() db_session.__exit__() + def test1(self): u = User.get(username='org1') org = Organization.get(username='org1') @@ -72,6 +82,7 @@ def test_user_1(self): u = AbstractUser['user1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, User) + def test_user_2(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql @@ -79,6 +90,7 @@ def test_user_2(self): u = User['user1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, User) + @raises_exception(ObjectNotFound) def test_user_3(self): Diagram.get(lambda d: d.name == 'diagram1') @@ -88,6 +100,7 @@ def test_user_3(self): SubUser1['user1'] finally: self.assertNotEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_user_4(self): Diagram.get(lambda d: d.name == 'diagram1') @@ -97,6 +110,7 @@ def test_user_4(self): Organization['user1'] 
finally: self.assertEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_user_5(self): Diagram.get(lambda d: d.name == 'diagram1') @@ -107,7 +121,6 @@ def test_user_5(self): finally: self.assertEqual(last_sql, db.last_sql) - def test_subuser_1(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql @@ -115,6 +128,7 @@ def test_subuser_1(self): u = AbstractUser['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) + def test_subuser_2(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql @@ -122,6 +136,7 @@ def test_subuser_2(self): u = User['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) + def test_subuser_3(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql @@ -129,6 +144,7 @@ def test_subuser_3(self): u = SubUser1['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) + @raises_exception(ObjectNotFound) def test_subuser_4(self): Diagram.get(lambda d: d.name == 'diagram2') @@ -138,6 +154,7 @@ def test_subuser_4(self): Organization['subuser1'] finally: self.assertEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_subuser_5(self): Diagram.get(lambda d: d.name == 'diagram2') @@ -147,6 +164,7 @@ def test_subuser_5(self): SubUser2['subuser1'] finally: self.assertNotEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_subuser_6(self): Diagram.get(lambda d: d.name == 'diagram2') @@ -163,6 +181,7 @@ def test_user_6(self): u2 = SubUser1['subuser1'] self.assertEqual(last_sql, db.last_sql) self.assertEqual(u1, u2) + def test_user_7(self): u1 = SubUser1['subuser1'] u1.delete() @@ -170,6 +189,7 @@ def test_user_7(self): u2 = SubUser1.get(username='subuser1') self.assertEqual(last_sql, db.last_sql) self.assertEqual(u2, None) + def test_user_8(self): u1 = SubUser1['subuser1'] last_sql = db.last_sql diff --git a/pony/orm/tests/test_core_multiset.py b/pony/orm/tests/test_core_multiset.py index 61b557c54..65d5153dd 100644 --- a/pony/orm/tests/test_core_multiset.py +++ b/pony/orm/tests/test_core_multiset.py @@ -4,8 +4,9 @@ from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): number = PrimaryKey(int) @@ -27,33 +28,40 @@ class Course(db.Entity): department = Required(Department) students = Set('Student') -db.generate_mapping(create_tables=True) -with db_session: - d1 = Department(number=1) - d2 = Department(number=2) - d3 = Department(number=3) +class TestMultiset(unittest.TestCase): + + @classmethod + def setUpClass(cls): + setup_database(db) - g1 = Group(number=101, department=d1) - g2 = Group(number=102, department=d1) - g3 = Group(number=201, department=d2) + with db_session: + d1 = Department(number=1) + d2 = Department(number=2) + d3 = Department(number=3) - c1 = Course(name='C1', department=d1) - c2 = Course(name='C2', department=d1) - c3 = Course(name='C3', department=d2) - c4 = Course(name='C4', department=d2) - c5 = Course(name='C5', department=d3) + g1 = Group(number=101, department=d1) + g2 = Group(number=102, department=d1) + g3 = Group(number=201, department=d2) - s1 = Student(name='S1', group=g1, courses=[c1, c2]) - s2 = Student(name='S2', group=g1, courses=[c1, c3]) - s3 = Student(name='S3', group=g1, courses=[c2, c3]) + c1 = Course(name='C1', department=d1) + c2 = Course(name='C2', department=d1) + c3 = Course(name='C3', 
department=d2) + c4 = Course(name='C4', department=d2) + c5 = Course(name='C5', department=d3) - s4 = Student(name='S4', group=g2, courses=[c1, c2]) - s5 = Student(name='S5', group=g2, courses=[c1, c2]) + s1 = Student(name='S1', group=g1, courses=[c1, c2]) + s2 = Student(name='S2', group=g1, courses=[c1, c3]) + s3 = Student(name='S3', group=g1, courses=[c2, c3]) - s6 = Student(name='A', group=g3, courses=[c5]) + s4 = Student(name='S4', group=g2, courses=[c1, c2]) + s5 = Student(name='S5', group=g2, courses=[c1, c2]) -class TestMultiset(unittest.TestCase): + s6 = Student(name='A', group=g3, courses=[c5]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) @db_session def test_multiset_repr_1(self): @@ -77,7 +85,7 @@ def test_multiset_repr_4(self): with db_session: g = Group[101] multiset = g.students.courses - self.assertEqual(multiset._obj_._session_cache_.is_alive, False) + self.assertIsNone(multiset._obj_._session_cache_) self.assertEqual(repr(multiset), "") @db_session @@ -119,7 +127,7 @@ def test_multiset_ne(self): d = Department[1] multiset = d.groups.students.courses self.assertFalse(multiset != multiset) - + @db_session def test_multiset_contains(self): d = Department[1] @@ -138,5 +146,6 @@ def test_multiset_reduce(self): multiset_1 = pickle.loads(s) self.assertEqual(multiset_1, multiset_2) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_crud.py b/pony/orm/tests/test_crud.py index c84ea50db..540a92bcf 100644 --- a/pony/orm/tests/test_crud.py +++ b/pony/orm/tests/test_crud.py @@ -6,8 +6,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Group(db.Entity): id = PrimaryKey(int) @@ -16,6 +17,7 @@ class Group(db.Entity): class Student(db.Entity): name = Required(unicode) + age = Optional(int) scholarship = Required(Decimal, default=0) picture = Optional(buffer, lazy=True) email = Required(unicode, unique=True) @@ -24,25 +26,31 @@ class Student(db.Entity): group = Optional('Group') class Course(db.Entity): + id = PrimaryKey(int) name = Required(unicode) semester = Required(int) students = Set(Student) composite_key(name, semester) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(id=1, major='Math') - g2 = Group(id=2, major='Physics') - s1 = Student(id=1, name='S1', email='s1@example.com', group=g1) - s2 = Student(id=2, name='S2', email='s2@example.com', group=g1) - s3 = Student(id=3, name='S3', email='s3@example.com', group=g2) - c1 = Course(name='Math', semester=1) - c2 = Course(name='Math', semester=2) - c3 = Course(name='Physics', semester=1) - class TestCRUD(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1, major='Math') + g2 = Group(id=2, major='Physics') + s1 = Student(id=1, name='S1', age=19, email='s1@example.com', group=g1) + s2 = Student(id=2, name='S2', age=21, email='s2@example.com', group=g1) + s3 = Student(id=3, name='S3', email='s3@example.com', group=g2) + c1 = Course(id=1, name='Math', semester=1) + c2 = Course(id=2, name='Math', semester=2) + c3 = Course(id=3, name='Physics', semester=1) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -72,6 +80,20 @@ def test_exists_3(self): x = Student.exists(group=g1) self.assertEqual(x, True) + def test_numeric_nonzero(self): + result = select(s.id for s in Student if s.age)[:] + 
self.assertEqual(set(result), {1, 2}) + + def test_numeric_negate_1(self): + result = select(s.id for s in Student if not s.age)[:] + self.assertEqual(set(result), {3}) + self.assertTrue('is null' in db.last_sql.lower()) + + def test_numeric_negate_2(self): + result = select(c.id for c in Course if not c.semester)[:] + self.assertEqual(result, []) + self.assertTrue('is null' not in db.last_sql.lower()) + def test_set1(self): s1 = Student[1] s1.set(name='New name', scholarship=100) @@ -106,6 +128,7 @@ def test_validate_3(self): @raises_exception(ValueError, "Value type for attribute Group.id must be int. Got string 'not a number'") def test_validate_5(self): s4 = Student(id=3, name='S4', email='s4@example.com', group='not a number') + @raises_exception(TypeError, "Attribute Student.group must be of Group type. Got: datetime.date(2011, 1, 1)") def test_validate_6(self): s4 = Student(id=3, name='S4', email='s4@example.com', group=date(2011, 1, 1)) diff --git a/pony/orm/tests/test_crud_raw_sql.py b/pony/orm/tests/test_crud_raw_sql.py index 4083a15f2..03dd23348 100644 --- a/pony/orm/tests/test_crud_raw_sql.py +++ b/pony/orm/tests/test_crud_raw_sql.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database, only_for -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -25,9 +26,17 @@ class Bio(db.Entity): desc = Required(unicode) Student = Required(Student) -db.generate_mapping(create_tables=True) +@only_for('sqlite') class TestCrudRawSQL(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('delete from Student') @@ -45,20 +54,37 @@ def tearDown(self): def test1(self): students = set(Student.select_by_sql("select id, name, age, group_dept, group_grad_year from Student order by age")) - self.assertEqual(students, set([Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[3], Student[2], Student[1]}) def test2(self): students = set(Student.select_by_sql("select id, age, group_dept from Student order by age")) - self.assertEqual(students, set([Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[3], Student[2], Student[1]}) @raises_exception(NameError, "Column x does not belong to entity Student") def test3(self): students = set(Student.select_by_sql("select id, age, age*2 as x from Student order by age")) - self.assertEqual(students, set([Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[3], Student[2], Student[1]}) @raises_exception(TypeError, 'The first positional argument must be lambda function or its text source. 
Got: 123') def test4(self): students = Student.select(123) + def test5(self): + x = 1 + y = 30 + cursor = db.execute("select name from Student where id = $x and age = $y") + self.assertEqual(cursor.fetchone()[0], 'A') + + def test6(self): + x = 1 + y = 30 + cursor = db.execute("select name, 'abc$$def%' from Student where id = $x and age = $y") + self.assertEqual(cursor.fetchone(), ('A', 'abc$def%')) + + def test7(self): + cursor = db.execute("select name, 'abc$$def%' from Student where id = 1") + self.assertEqual(cursor.fetchone(), ('A', 'abc$def%')) + + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_datetime.py b/pony/orm/tests/test_datetime.py new file mode 100644 index 000000000..7bb60accb --- /dev/null +++ b/pony/orm/tests/test_datetime.py @@ -0,0 +1,142 @@ +from __future__ import absolute_import, print_function, division +from pony.py23compat import PY2 + +import unittest +from datetime import date, datetime, timedelta + +from pony.orm.core import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() + + +class Entity1(db.Entity): + id = PrimaryKey(int) + d = Required(date) + dt = Required(datetime) + + +class TestDate(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Entity1(id=1, d=date(2009, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) + Entity1(id=2, d=date(2010, 10, 21), dt=datetime(2010, 10, 21, 10, 21, 31)) + Entity1(id=3, d=date(2011, 11, 22), dt=datetime(2011, 11, 22, 10, 20, 32)) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_create(self): + e1 = Entity1(id=4, d=date(2011, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) + + def test_date_year(self): + result = select(e for e in Entity1 if e.d.year > 2009) + self.assertEqual(len(result), 2) + + def test_date_month(self): + result = select(e for e in Entity1 if e.d.month == 10) + self.assertEqual(len(result), 2) + + def test_date_day(self): + result = select(e for e in Entity1 if e.d.day == 22) + self.assertEqual(len(result), 1) + + def test_datetime_year(self): + result = select(e for e in Entity1 if e.dt.year > 2009) + self.assertEqual(len(result), 2) + + def test_datetime_month(self): + result = select(e for e in Entity1 if e.dt.month == 10) + self.assertEqual(len(result), 2) + + def test_datetime_day(self): + result = select(e for e in Entity1 if e.dt.day == 22) + self.assertEqual(len(result), 1) + + def test_datetime_hour(self): + result = select(e for e in Entity1 if e.dt.hour == 10) + self.assertEqual(len(result), 3) + + def test_datetime_minute(self): + result = select(e for e in Entity1 if e.dt.minute == 20) + self.assertEqual(len(result), 2) + + def test_datetime_second(self): + result = select(e for e in Entity1 if e.dt.second == 30) + self.assertEqual(len(result), 1) + + def test_date_sub_date(self): + dt = date(2012, 1, 1) + result = select(e.id for e in Entity1 if dt - e.d > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + def test_datetime_sub_datetime(self): + dt = datetime(2012, 1, 1, 10, 20, 30) + result = select(e.id for e in Entity1 if dt - e.dt > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + def test_date_sub_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.d - td < date(2009, 1, 1)) + self.assertEqual(set(result), {1}) + + def 
test_date_sub_const_timedelta(self): + result = select(e.id for e in Entity1 if e.d - timedelta(days=500) < date(2009, 1, 1)) + self.assertEqual(set(result), {1}) + + def test_datetime_sub_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.dt - td < datetime(2009, 1, 1, 10, 20, 30)) + self.assertEqual(set(result), {1}) + + def test_datetime_sub_const_timedelta(self): + result = select(e.id for e in Entity1 if e.dt - timedelta(days=500) < datetime(2009, 1, 1, 10, 20, 30)) + self.assertEqual(set(result), {1}) + + def test_date_add_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.d + td > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + def test_date_add_const_timedelta(self): + result = select(e.id for e in Entity1 if e.d + timedelta(days=500) > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + def test_datetime_add_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.dt + td > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + def test_datetime_add_const_timedelta(self): + result = select(e.id for e in Entity1 if e.dt + timedelta(days=500) > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + @raises_exception(TypeError, "Unsupported operand types 'date' and '%s' " + "for operation '-' in expression: e.d - s" % ('unicode' if PY2 else 'str')) + def test_date_sub_error(self): + s = 'hello' + result = select(e.id for e in Entity1 if e.d - s > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + @raises_exception(TypeError, "Unsupported operand types 'datetime' and '%s' " + "for operation '-' in expression: e.dt - s" % ('unicode' if PY2 else 'str')) + def test_datetime_sub_error(self): + s = 'hello' + result = select(e.id for e in Entity1 if e.dt - s > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 741cf5277..09cd08ed0 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -1,25 +1,31 @@ from __future__ import absolute_import, print_function, division -import unittest +import unittest, warnings from datetime import date from decimal import Decimal from itertools import count from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + class TestDBSession(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + self.db = Database() class X(self.db.Entity): - a = Required(int) + a = PrimaryKey(int) b = Optional(int) self.X = X - self.db.generate_mapping(create_tables=True) + setup_database(self.db) with db_session: x1 = X(a=1, b=1) x2 = X(a=2, b=2) + def tearDown(self): + if self.db.provider.dialect != 'SQLite': + teardown_database(self.db) + @raises_exception(TypeError, "Pass only keyword arguments to db_session or use db_session as decorator") def test_db_session_1(self): db_session(1, 2, 3) @@ -32,6 +38,7 @@ def test_db_session_3(self): self.assertTrue(db_session is db_session()) def test_db_session_4(self): + # Nested db_sessions are ignored with db_session: with db_session: self.X(a=3, b=3) @@ -39,6 +46,7 @@ def test_db_session_4(self): self.assertEqual(count(x for x in self.X), 3) def test_db_session_decorator_1(self): + # Should commit changes on exit from db_session @db_session def test(): self.X(a=3, b=3) @@ -47,6 +55,7 @@ def test(): 
self.assertEqual(count(x for x in self.X), 3) def test_db_session_decorator_2(self): + # Should rollback changes if an exception occurs @db_session def test(): self.X(a=3, b=3) @@ -60,6 +69,7 @@ def test(): self.fail() def test_db_session_decorator_3(self): + # Should rollback changes if the exception is not in the list of allowed exceptions @db_session(allowed_exceptions=[TypeError]) def test(): self.X(a=3, b=3) @@ -73,6 +83,7 @@ def test(): self.fail() def test_db_session_decorator_4(self): + # Should commit changes if the exception is in the list of allowed exceptions @db_session(allowed_exceptions=[ZeroDivisionError]) def test(): self.X(a=3, b=3) @@ -85,19 +96,48 @@ def test(): else: self.fail() + def test_allowed_exceptions_1(self): + # allowed_exceptions may be callable, should commit if nonzero + @db_session(allowed_exceptions=lambda e: isinstance(e, ZeroDivisionError)) + def test(): + self.X(a=3, b=3) + 1/0 + try: + test() + except ZeroDivisionError: + with db_session: + self.assertEqual(count(x for x in self.X), 3) + else: + self.fail() + + def test_allowed_exceptions_2(self): + # allowed_exceptions may be callable, should rollback otherwise + @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) + def test(): + self.X(a=3, b=3) + 1/0 + try: + test() + except ZeroDivisionError: + with db_session: + self.assertEqual(count(x for x in self.X), 2) + else: + self.fail() + @raises_exception(TypeError, "'retry' parameter of db_session must be of integer type. Got: %r" % str) - def test_db_session_decorator_5(self): + def test_retry_1(self): @db_session(retry='foobar') def test(): pass @raises_exception(TypeError, "'retry' parameter of db_session must not be negative. Got: -1") - def test_db_session_decorator_6(self): + def test_retry_2(self): @db_session(retry=-1) def test(): pass - def test_db_session_decorator_7(self): + def test_retry_3(self): + # Should not retry unless the retry count is specified counter = count() @db_session(retry_exceptions=[ZeroDivisionError]) def test(): @@ -113,7 +153,8 @@ def test(): else: self.fail() - def test_db_session_decorator_8(self): + def test_retry_4(self): + # Should rollback & retry 1 time if retry=1 counter = count() @db_session(retry=1, retry_exceptions=[ZeroDivisionError]) def test(): @@ -129,7 +170,8 @@ def test(): else: self.fail() - def test_db_session_decorator_9(self): + def test_retry_5(self): + # Should rollback & retry N times if retry=N counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) def test(): @@ -145,7 +187,8 @@ def test(): else: self.fail() - def test_db_session_decorator_10(self): + def test_retry_6(self): + # Should not retry if the exception is not in the list of retry_exceptions counter = count() @db_session(retry=3, retry_exceptions=[TypeError]) def test(): @@ -161,7 +204,8 @@ def test(): else: self.fail() - def test_db_session_decorator_11(self): + def test_retry_7(self): + # Should commit after successful retrying counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) def test(): @@ -179,53 +223,53 @@ def test(): @raises_exception(TypeError, "The same exception ZeroDivisionError cannot be specified " "in both allowed and retry exception lists simultaneously") - def test_db_session_decorator_12(self): + def test_retry_8(self): @db_session(retry=3, retry_exceptions=[ZeroDivisionError], allowed_exceptions=[ZeroDivisionError]) def test(): pass - def test_db_session_decorator_13(self): - @db_session(allowed_exceptions=lambda e: isinstance(e, 
ZeroDivisionError)) + def test_retry_9(self): + # retry_exceptions may be callable, should retry if nonzero + counter = count() + @db_session(retry=3, retry_exceptions=lambda e: isinstance(e, ZeroDivisionError)) def test(): + i = next(counter) self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: + self.assertEqual(next(counter), 4) with db_session: - self.assertEqual(count(x for x in self.X), 3) + self.assertEqual(count(x for x in self.X), 2) else: self.fail() - - def test_db_session_decorator_14(self): - @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) + + def test_retry_10(self): + # Issue 313: retry on exception raised during db_session.__exit__ + retries = count() + @db_session(retry=3) def test(): - self.X(a=3, b=3) - 1/0 + next(retries) + self.X(a=1, b=1) try: test() - except ZeroDivisionError: - with db_session: - self.assertEqual(count(x for x in self.X), 2) + except TransactionIntegrityError: + self.assertEqual(next(retries), 4) else: self.fail() - def test_db_session_decorator_15(self): - counter = count() - @db_session(retry=3, retry_exceptions=lambda e: isinstance(e, ZeroDivisionError)) + @raises_exception(PonyRuntimeWarning, '@db_session decorator with `retry=3` option is ignored for test() function ' + 'because it is called inside another db_session') + def test_retry_11(self): + @db_session(retry=3) def test(): - i = next(counter) - self.X(a=3, b=3) - 1/0 - try: - test() - except ZeroDivisionError: - self.assertEqual(next(counter), 4) + pass + with warnings.catch_warnings(): + warnings.simplefilter('error', PonyRuntimeWarning) with db_session: - self.assertEqual(count(x for x in self.X), 2) - else: - self.fail() + test() def test_db_session_manager_1(self): with db_session: @@ -240,6 +284,7 @@ def test_db_session_manager_2(self): self.X(a=3, b=3) def test_db_session_manager_3(self): + # Should rollback if the exception is not in the list of allowed_exceptions try: with db_session(allowed_exceptions=[TypeError]): self.X(a=3, b=3) @@ -251,6 +296,7 @@ def test_db_session_manager_3(self): self.fail() def test_db_session_manager_4(self): + # Should commit if the exception is in the list of allowed_exceptions try: with db_session(allowed_exceptions=[ZeroDivisionError]): self.X(a=3, b=3) @@ -261,13 +307,31 @@ def test_db_session_manager_4(self): else: self.fail() - @raises_exception(TypeError, "@db_session can accept 'ddl' parameter " - "only when used as decorator and not as context manager") + # restriction removed in 0.7.3: + # @raises_exception(TypeError, "@db_session can accept 'ddl' parameter " + # "only when used as decorator and not as context manager") def test_db_session_ddl_1(self): with db_session(ddl=True): pass - @raises_exception(TransactionError, "test() cannot be called inside of db_session") + def test_db_session_ddl_1a(self): + with db_session(ddl=True): + with db_session(ddl=True): + pass + + def test_db_session_ddl_1b(self): + with db_session(ddl=True): + with db_session: + pass + + @raises_exception(TransactionError, 'Cannot start ddl transaction inside non-ddl transaction') + def test_db_session_ddl_1c(self): + with db_session: + with db_session(ddl=True): + pass + + @raises_exception(TransactionError, "@db_session-decorated test() function with `ddl` option " + "cannot be called inside of another db_session") def test_db_session_ddl_2(self): @db_session(ddl=True) def test(): @@ -281,8 +345,44 @@ def test(): pass test() + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_1(self): + def before_insert(self): + 1/0 + 
self.X.before_insert = before_insert + with db_session: + self.X(a=3, b=3) + # Should raise ZeroDivisionError and not CommitException + + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_2(self): + def before_insert(self): + 1 / 0 + self.X.before_insert = before_insert + with db_session: + self.X(a=3, b=3) + commit() + # Should raise ZeroDivisionError and not CommitException + + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_3(self): + def before_insert(self): + 1 / 0 + self.X.before_insert = before_insert + with db_session: + self.X(a=3, b=3) + db.commit() + # Should raise ZeroDivisionError and not CommitException + + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_4(self): + with db_session: + connection = self.db.get_connection() + connection.close() + 1/0 -db = Database('sqlite', ':memory:') + +db = Database() class Group(db.Entity): id = PrimaryKey(int) @@ -294,83 +394,110 @@ class Student(db.Entity): picture = Optional(buffer, lazy=True) group = Required('Group') -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(id=1, major='Math') - g2 = Group(id=2, major='Physics') - s1 = Student(id=1, name='S1', group=g1) - s2 = Student(id=2, name='S2', group=g1) - s3 = Student(id=3, name='S3', group=g2) +class TestDBSessionScope(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1, major='Math') + g2 = Group(id=2, major='Physics') + s1 = Student(id=1, name='S1', group=g1) + s2 = Student(id=2, name='S2', group=g1) + s3 = Student(id=3, name='S3', group=g2) + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestDBSessionScope(unittest.TestCase): def setUp(self): rollback() + def tearDown(self): rollback() + def test1(self): with db_session: s1 = Student[1] name = s1.name + @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Student[1].picture: the database session is over') def test2(self): with db_session: s1 = Student[1] picture = s1.picture + @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over') def test3(self): with db_session: s1 = Student[1] group_id = s1.group.id major = s1.group.major - @raises_exception(DatabaseSessionIsOver, 'Cannot assign new value to attribute Student[1].name: the database session is over') + + @raises_exception(DatabaseSessionIsOver, 'Cannot assign new value to Student[1].name: the database session is over') def test4(self): with db_session: s1 = Student[1] s1.name = 'New name' + def test5(self): with db_session: g1 = Group[1] self.assertEqual(str(g1.students), 'StudentSet([...])') + @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Group[1].students: the database session is over') def test6(self): with db_session: g1 = Group[1] l = len(g1.students) - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test7(self): with db_session: s1 = Student[1] g1 = Group[1] g1.students.remove(s1) - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test8(self): with db_session: g2_students = Group[2].students g1 = Group[1] g1.students = 
g2_students - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test9(self): with db_session: s3 = Student[3] g1 = Group[1] g1.students.add(s3) - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test10(self): with db_session: g1 = Group[1] g1.students.clear() + @raises_exception(DatabaseSessionIsOver, 'Cannot delete object Student[1]: the database session is over') def test11(self): with db_session: s1 = Student[1] s1.delete() + @raises_exception(DatabaseSessionIsOver, 'Cannot change object Student[1]: the database session is over') def test12(self): with db_session: s1 = Student[1] s1.set(name='New name') + def test_db_session_strict_1(self): + with db_session(strict=True): + s1 = Student[1] + + @raises_exception(DatabaseSessionIsOver, 'Cannot read value of Student[1].name: the database session is over') + def test_db_session_strict_2(self): + with db_session(strict=True): + s1 = Student[1] + name = s1.name + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_declarative_attr_set_monad.py b/pony/orm/tests/test_declarative_attr_set_monad.py index 4f7f855a1..d14557342 100644 --- a/pony/orm/tests/test_declarative_attr_set_monad.py +++ b/pony/orm/tests/test_declarative_attr_set_monad.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -30,42 +31,48 @@ class Mark(db.Entity): subject = Required(Subject) PrimaryKey(student, subject) -db.generate_mapping(create_tables=True) - -with db_session: - g41 = Group(number=41, department=101) - g42 = Group(number=42, department=102) - g43 = Group(number=43, department=102) - g44 = Group(number=44, department=102) - - s1 = Student(id=1, name="Joe", scholarship=None, group=g41) - s2 = Student(id=2, name="Bob", scholarship=100, group=g41) - s3 = Student(id=3, name="Beth", scholarship=500, group=g41) - s4 = Student(id=4, name="Jon", scholarship=500, group=g42) - s5 = Student(id=5, name="Pete", scholarship=700, group=g42) - s6 = Student(id=6, name="Mary", scholarship=300, group=g44) - - Math = Subject(name="Math") - Physics = Subject(name="Physics") - History = Subject(name="History") - - g41.subjects = [ Math, Physics, History ] - g42.subjects = [ Math, Physics ] - g43.subjects = [ Physics ] - - Mark(value=5, student=s1, subject=Math) - Mark(value=4, student=s2, subject=Physics) - Mark(value=3, student=s2, subject=Math) - Mark(value=2, student=s2, subject=History) - Mark(value=1, student=s3, subject=History) - Mark(value=2, student=s3, subject=Math) - Mark(value=2, student=s4, subject=Math) - + class TestAttrSetMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g41 = Group(number=41, department=101) + g42 = Group(number=42, department=102) + g43 = Group(number=43, department=102) + g44 = Group(number=44, department=102) + + s1 = Student(id=1, name="Joe", scholarship=None, group=g41) + s2 = Student(id=2, name="Bob", scholarship=100, group=g41) + s3 = Student(id=3, name="Beth", 
scholarship=500, group=g41) + s4 = Student(id=4, name="Jon", scholarship=500, group=g42) + s5 = Student(id=5, name="Pete", scholarship=700, group=g42) + s6 = Student(id=6, name="Mary", scholarship=300, group=g44) + + Math = Subject(name="Math") + Physics = Subject(name="Physics") + History = Subject(name="History") + + g41.subjects = [Math, Physics, History] + g42.subjects = [Math, Physics] + g43.subjects = [Physics] + + Mark(value=5, student=s1, subject=Math) + Mark(value=4, student=s2, subject=Physics) + Mark(value=3, student=s2, subject=Math) + Mark(value=2, student=s2, subject=History) + Mark(value=1, student=s3, subject=History) + Mark(value=2, student=s3, subject=Math) + Mark(value=2, student=s4, subject=Math) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() - + def tearDown(self): rollback() db_session.__exit__() @@ -75,13 +82,13 @@ def test1(self): self.assertEqual(groups, [Group[41]]) def test2(self): groups = set(select(g for g in Group if len(g.students.name) >= 2)) - self.assertEqual(groups, set([Group[41], Group[42]])) + self.assertEqual(groups, {Group[41], Group[42]}) def test3(self): groups = select(g for g in Group if len(g.students.marks) > 2)[:] self.assertEqual(groups, [Group[41]]) def test3a(self): groups = select(g for g in Group if len(g.students.marks) < 2)[:] - self.assertEqual(groups, [Group[42], Group[43], Group[44]]) + self.assertEqual(set(groups), {Group[42], Group[43], Group[44]}) def test4(self): groups = select(g for g in Group if max(g.students.marks.value) <= 2)[:] self.assertEqual(groups, [Group[42]]) @@ -90,81 +97,112 @@ def test5(self): self.assertEqual(students, []) def test6(self): students = set(select(s for s in Student if len(s.marks.subject) >= 2)) - self.assertEqual(students, set([Student[2], Student[3]])) + self.assertEqual(students, {Student[2], Student[3]}) def test8(self): students = set(select(s for s in Student if s.group in (g for g in Group if g.department == 101))) - self.assertEqual(students, set([Student[1], Student[2], Student[3]])) + self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test9(self): students = set(select(s for s in Student if s.group not in (g for g in Group if g.department == 101))) - self.assertEqual(students, set([Student[4], Student[5], Student[6]])) + self.assertEqual(students, {Student[4], Student[5], Student[6]}) def test10(self): students = set(select(s for s in Student if s.group in (g for g in Group if g.department == 101))) - self.assertEqual(students, set([Student[1], Student[2], Student[3]])) + self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test11(self): students = set(select(g for g in Group if len(g.subjects.groups.subjects) > 1)) - self.assertEqual(students, set([Group[41], Group[42], Group[43]])) + self.assertEqual(students, {Group[41], Group[42], Group[43]}) def test12(self): groups = set(select(g for g in Group if len(g.subjects) >= 2)) - self.assertEqual(groups, set([Group[41], Group[42]])) + self.assertEqual(groups, {Group[41], Group[42]}) def test13(self): groups = set(select(g for g in Group if g.students)) - self.assertEqual(groups, set([Group[41], Group[42], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test14(self): groups = set(select(g for g in Group if not g.students)) - self.assertEqual(groups, set([Group[43]])) + self.assertEqual(groups, {Group[43]}) def test15(self): groups = set(select(g for g in Group if exists(g.students))) - 
self.assertEqual(groups, set([Group[41], Group[42], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test15a(self): groups = set(select(g for g in Group if not not exists(g.students))) - self.assertEqual(groups, set([Group[41], Group[42], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test16(self): groups = select(g for g in Group if not exists(g.students))[:] self.assertEqual(groups, [Group[43]]) def test17(self): groups = set(select(g for g in Group if 100 in g.students.scholarship)) - self.assertEqual(groups, set([Group[41]])) + self.assertEqual(groups, {Group[41]}) def test18(self): groups = set(select(g for g in Group if 100 not in g.students.scholarship)) - self.assertEqual(groups, set([Group[42], Group[43], Group[44]])) + self.assertEqual(groups, {Group[42], Group[43], Group[44]}) def test19(self): groups = set(select(g for g in Group if not not not 100 not in g.students.scholarship)) - self.assertEqual(groups, set([Group[41]])) + self.assertEqual(groups, {Group[41]}) def test20(self): groups = set(select(g for g in Group if exists(s for s in Student if s.group == g and s.scholarship == 500))) - self.assertEqual(groups, set([Group[41], Group[42]])) + self.assertEqual(groups, {Group[41], Group[42]}) def test21(self): groups = set(select(g for g in Group if g.department is not None)) - self.assertEqual(groups, set([Group[41], Group[42], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[43], Group[44]}) def test21a(self): groups = set(select(g for g in Group if not g.department is not None)) - self.assertEqual(groups, set([])) + self.assertEqual(groups, set()) def test21b(self): groups = set(select(g for g in Group if not not not g.department is None)) - self.assertEqual(groups, set([Group[41], Group[42], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[43], Group[44]}) def test22(self): groups = set(select(g for g in Group if 700 in (s.scholarship for s in Student if s.group == g))) - self.assertEqual(groups, set([Group[42]])) + self.assertEqual(groups, {Group[42]}) def test23a(self): groups = set(select(g for g in Group if 700 not in g.students.scholarship)) - self.assertEqual(groups, set([Group[41], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[43], Group[44]}) def test23b(self): groups = set(select(g for g in Group if 700 not in (s.scholarship for s in Student if s.group == g))) - self.assertEqual(groups, set([Group[41], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[43], Group[44]}) @raises_exception(NotImplementedError) def test24(self): groups = set(select(g for g in Group for g2 in Group if g.students == g2.students)) def test25(self): m1 = Mark[Student[1], Subject["Math"]] students = set(select(s for s in Student if m1 in s.marks)) - self.assertEqual(students, set([Student[1]])) + self.assertEqual(students, {Student[1]}) def test26(self): s1 = Student[1] groups = set(select(g for g in Group if s1 in g.students)) - self.assertEqual(groups, set([Group[41]])) + self.assertEqual(groups, {Group[41]}) @raises_exception(AttributeError, 'g.students.name.foo') def test27(self): select(g for g in Group if g.students.name.foo == 1) + def test28(self): + groups = set(select(g for g in Group if not g.students.is_empty())) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) + @raises_exception(NotImplementedError) + def test29(self): + students = select(g.students.select(lambda s: s.scholarship > 0) for g in Group 
if g.department == 101)[:] + def test30a(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.select(lambda s: s.scholarship > 0))[:] + self.assertEqual(set(groups), {Group[41]}) + def test30b(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.filter(lambda s: s.scholarship > 0))[:] + self.assertEqual(set(groups), {Group[41]}) + def test30c(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.select())[:] + self.assertEqual(set(groups), {Group[41]}) + def test30d(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.filter())[:] + self.assertEqual(set(groups), {Group[41]}) + def test31(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 and g.students.exists(lambda s: s.scholarship > 0))[:] + self.assertEqual(set(groups), {Group[41]}) + if __name__ == "__main__": unittest.main() diff --git a/pony/orm/tests/test_declarative_date.py b/pony/orm/tests/test_declarative_date.py deleted file mode 100644 index b7a89541f..000000000 --- a/pony/orm/tests/test_declarative_date.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import unittest -from datetime import date, datetime - -from pony.orm.core import * -from pony.orm.tests.testutils import * - -db = Database('sqlite', ':memory:') - -class Entity1(db.Entity): - a = PrimaryKey(int) - b = Required(date) - c = Required(datetime) - -db.generate_mapping(create_tables=True) - -with db_session: - Entity1(a=1, b=date(2009, 10, 20), c=datetime(2009, 10, 20, 10, 20, 30)) - Entity1(a=2, b=date(2010, 10, 21), c=datetime(2010, 10, 21, 10, 21, 31)) - Entity1(a=3, b=date(2011, 11, 22), c=datetime(2011, 11, 22, 10, 20, 32)) - -class TestDate(unittest.TestCase): - def setUp(self): - rollback() - db_session.__enter__() - def tearDown(self): - rollback() - db_session.__exit__() - def test_create(self): - e1 = Entity1(a=4, b=date(2011, 10, 20), c=datetime(2009, 10, 20, 10, 20, 30)) - def test_date_year(self): - result = select(e for e in Entity1 if e.b.year > 2009) - self.assertEqual(len(result), 2) - def test_date_month(self): - result = select(e for e in Entity1 if e.b.month == 10) - self.assertEqual(len(result), 2) - def test_date_day(self): - result = select(e for e in Entity1 if e.b.day == 22) - self.assertEqual(len(result), 1) - def test_datetime_year(self): - result = select(e for e in Entity1 if e.c.year > 2009) - self.assertEqual(len(result), 2) - def test_datetime_month(self): - result = select(e for e in Entity1 if e.c.month == 10) - self.assertEqual(len(result), 2) - def test_datetime_day(self): - result = select(e for e in Entity1 if e.c.day == 22) - self.assertEqual(len(result), 1) - def test_datetime_hour(self): - result = select(e for e in Entity1 if e.c.hour == 10) - self.assertEqual(len(result), 3) - def test_datetime_minute(self): - result = select(e for e in Entity1 if e.c.minute == 20) - self.assertEqual(len(result), 2) - def test_datetime_second(self): - result = select(e for e in Entity1 if e.c.second == 30) - self.assertEqual(len(result), 1) - -if __name__ == '__main__': - unittest.main() diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 0d66417a0..6154086ff 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -1,4 +1,5 @@ from __future__ 
import absolute_import, print_function, division +from pony.py23compat import PYPY, PYPY2 import sys, unittest from datetime import date @@ -7,8 +8,9 @@ from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -33,16 +35,20 @@ class Course(db.Entity): PrimaryKey(name, semester) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - d1 = Department(number=44) - g1 = Group(number=101, dept=d1) - Student(name='S1', group=g1) - Student(name='S2', group=g1) - Student(name='S3', group=g1) class TestSQLTranslatorExceptions(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=44) + g1 = Group(number=101, dept=d1) + Student(name='S1', group=g1) + Student(name='S2', group=g1) + Student(name='S3', group=g1) + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -53,11 +59,11 @@ def tearDown(self): def test1(self): x = 10 select(s for s in Student for x in s.name) - @raises_exception(TranslationError, "Inside declarative query, iterator must be entity. Got: for i in x") + @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. Got: for i in x") def test2(self): x = [1, 2, 3] select(s for s in Student for i in x) - @raises_exception(TranslationError, "Inside declarative query, iterator must be entity. Got: for s2 in g.students") + @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. 
Got: for s2 in g.students") def test3(self): g = Group[101] select(s for s in Student for s2 in g.students) @@ -67,29 +73,30 @@ def test4(self): select(s for s in Student if s.name.upper(*args)) if sys.version_info[:2] < (3, 5): # TODO - @raises_exception(TypeError, "Expression `{'a':'b', 'c':'d'}` has unsupported type 'dict'") + @raises_exception(NotImplementedError) # "**{'a': 'b', 'c': 'd'} is not supported def test5(self): select(s for s in Student if s.name.upper(**{'a':'b', 'c':'d'})) - @raises_exception(ExprEvalError, "1 in 2 raises TypeError: argument of type 'int' is not iterable") + @raises_exception(ExprEvalError, "`1 in 2` raises TypeError: argument of type 'int' is not iterable" if not PYPY else + "`1 in 2` raises TypeError: 'int' object is not iterable") def test6(self): select(s for s in Student if 1 in 2) @raises_exception(NotImplementedError, 'Group[s.group.number]') def test7(self): select(s for s in Student if Group[s.group.number].dept.number == 44) - @raises_exception(ExprEvalError, "Group[123, 456].dept.number == 44 raises TypeError: Invalid count of attrs in Group primary key (2 instead of 1)") + @raises_exception(ExprEvalError, "`Group[123, 456].dept.number == 44` raises TypeError: Invalid count of attrs in Group primary key (2 instead of 1)") def test8(self): select(s for s in Student if Group[123, 456].dept.number == 44) - @raises_exception(ExprEvalError, "Course[123] raises TypeError: Invalid count of attrs in Course primary key (1 instead of 2)") + @raises_exception(ExprEvalError, "`Course[123]` raises TypeError: Invalid count of attrs in Course primary key (1 instead of 2)") def test9(self): select(s for s in Student if Course[123] in s.courses) @raises_exception(TypeError, "Incomparable types '%s' and 'float' in expression: s.name < s.gpa" % unicode.__name__) def test10(self): select(s for s in Student if s.name < s.gpa) - @raises_exception(ExprEvalError, "Group(101) raises TypeError: Group constructor accept only keyword arguments. Got: 1 positional argument") + @raises_exception(ExprEvalError, "`Group(101)` raises TypeError: Group constructor accept only keyword arguments. Got: 1 positional argument") def test11(self): select(s for s in Student if s.group == Group(101)) - @raises_exception(ExprEvalError, "Group[date(2011, 1, 2)] raises TypeError: Value type for attribute Group.number must be int. Got: %r" % date) + @raises_exception(ExprEvalError, "`Group[date(2011, 1, 2)]` raises TypeError: Value type for attribute Group.number must be int. Got: %r" % date) def test12(self): select(s for s in Student if s.group == Group[date(2011, 1, 2)]) @raises_exception(TypeError, "Unsupported operand types 'int' and '%s' for operation '+' in expression: s.group.number + s.name" % unicode.__name__) @@ -116,9 +123,6 @@ def test18(self): % unicode.__name__) def test19(self): select(s for s in Student if s.name[1:'a'] == 'A') - @raises_exception(NotImplementedError, "Negative indices are not supported in string slice s.name[-1:1]") - def test20(self): - select(s for s in Student if s.name[-1:1] == 'A') @raises_exception(TypeError, "String indices must be integers. 
Got '%s' in expression s.name['a']" % unicode.__name__) def test21(self): select(s.name for s in Student if s.name['a'] == 'h') @@ -134,13 +138,13 @@ def test24(self): @raises_exception(TypeError, "'chars' argument must be of '%s' type in s.name.strip(1), got: 'int'" % unicode.__name__) def test25(self): select(s.name for s in Student if s.name.strip(1)) - @raises_exception(AttributeError, "'%s' object has no attribute 'unknown'" % unicode.__name__) + @raises_exception(AttributeError, "'%s' object has no attribute 'unknown': s.name.unknown" % unicode.__name__) def test26(self): result = set(select(s for s in Student if s.name.unknown() == "joe")) @raises_exception(AttributeError, "Entity Group does not have attribute foo: s.group.foo") def test27(self): select(s.name for s in Student if s.group.foo.bar == 10) - @raises_exception(ExprEvalError, "g.dept.foo.bar raises AttributeError: 'Department' object has no attribute 'foo'") + @raises_exception(ExprEvalError, "`g.dept.foo.bar` raises AttributeError: 'Department' object has no attribute 'foo'") def test28(self): g = Group[101] select(s for s in Student if s.name == g.dept.foo.bar) @@ -151,7 +155,9 @@ def test29(self): @raises_exception(NotImplementedError, "date(s.id, 1, 1)") def test30(self): select(s for s in Student if s.dob < date(s.id, 1, 1)) - @raises_exception(ExprEvalError, "max() raises TypeError: max expected 1 arguments, got 0") + @raises_exception(ExprEvalError, "`max()` raises TypeError: max() expects at least one argument" if PYPY else + "`max()` raises TypeError: max expected 1 arguments, got 0" if sys.version_info[:2] < (3, 8) else + "`max()` raises TypeError: max expected 1 argument, got 0") def test31(self): select(s for s in Student if s.id < max()) @raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses") @@ -178,7 +184,10 @@ def test38(self): @raises_exception(TypeError, "strip() takes at most 1 argument (3 given)") def test39(self): select(s for s in Student if s.name.strip(1, 2, 3)) - @raises_exception(ExprEvalError, "len(1, 2) == 3 raises TypeError: len() takes exactly one argument (2 given)") + @raises_exception(ExprEvalError, + "`len(1, 2) == 3` raises TypeError: len() takes exactly 1 argument (2 given)" if PYPY2 else + "`len(1, 2) == 3` raises TypeError: len() takes 1 positional argument but 2 were given" if PYPY else + "`len(1, 2) == 3` raises TypeError: len() takes exactly one argument (2 given)") def test40(self): select(s for s in Student if len(1, 2) == 3) @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got 'Student' in sum(s for s in Student if s.group == g)") @@ -208,12 +217,11 @@ def test48(self): @raises_exception(TypeError, "'sum' is valid for numeric attributes only") def test49(self): sum(s.name for s in Student) - - if sys.version_info[:2] < (3, 5): # TODO - @raises_exception(TypeError, "Expression `{'a':'b'}` has unsupported type 'dict'") - def test50(self): - select(s for s in Student if s.name == {'a' : 'b'}) - + @raises_exception(TypeError, "Cannot compare whole JSON value, you need to select specific sub-item: s.name == {'a':'b'}") + def test50(self): + # cannot compare JSON value to dynamic string, + # because a database does not provide json.dumps(s.name) functionality + select(s for s in Student if s.name == {'a': 'b'}) @raises_exception(IncomparableTypesError, "Incomparable types '%s' and 'int' in expression: s.name > a & 2" % unicode.__name__) def test51(self): a = 1 diff --git 
a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index 09cba95a2..707ed5e11 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -1,15 +1,16 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2 +from pony.py23compat import PY2, PYPY, PYPY2 -import unittest +import sys, unittest from datetime import date, datetime from decimal import Decimal from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): id = PrimaryKey(int) @@ -25,29 +26,31 @@ class Group(db.Entity): students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - g2 = Group(number=2) - - Student(id=1, name="AA", dob=date(1981, 1, 1), last_visit=datetime(2011, 1, 1, 11, 11, 11), - scholarship=Decimal("0"), phd=True, group=g1) - - Student(id=2, name="BB", dob=date(1982, 2, 2), last_visit=datetime(2011, 2, 2, 12, 12, 12), - scholarship=Decimal("202.2"), phd=True, group=g1) - - Student(id=3, name="CC", dob=date(1983, 3, 3), last_visit=datetime(2011, 3, 3, 13, 13, 13), - scholarship=Decimal("303.3"), phd=False, group=g1) - - Student(id=4, name="DD", dob=date(1984, 4, 4), last_visit=datetime(2011, 4, 4, 14, 14, 14), - scholarship=Decimal("404.4"), phd=False, group=g2) - - Student(id=5, name="EE", dob=date(1985, 5, 5), last_visit=datetime(2011, 5, 5, 15, 15, 15), - scholarship=Decimal("505.5"), phd=False, group=g2) - - class TestFuncMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + + Student(id=1, name="AA", dob=date(1981, 1, 1), last_visit=datetime(2011, 1, 1, 11, 11, 11), + scholarship=Decimal("0"), phd=True, group=g1) + + Student(id=2, name="BB", dob=date(1982, 2, 2), last_visit=datetime(2011, 2, 2, 12, 12, 12), + scholarship=Decimal("202.2"), phd=True, group=g1) + + Student(id=3, name="CC", dob=date(1983, 3, 3), last_visit=datetime(2011, 3, 3, 13, 13, 13), + scholarship=Decimal("303.3"), phd=False, group=g1) + + Student(id=4, name="DD", dob=date(1984, 4, 4), last_visit=datetime(2011, 4, 4, 14, 14, 14), + scholarship=Decimal("404.4"), phd=False, group=g2) + + Student(id=5, name="EE", dob=date(1985, 5, 5), last_visit=datetime(2011, 5, 5, 15, 15, 15), + scholarship=Decimal("505.5"), phd=False, group=g2) + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -56,16 +59,16 @@ def tearDown(self): db_session.__exit__() def test_minmax1(self): result = set(select(s for s in Student if max(s.id, 3) == 3 )) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_minmax2(self): result = set(select(s for s in Student if min(s.id, 3) == 3 )) - self.assertEqual(result, set([Student[4], Student[5], Student[3]])) + self.assertEqual(result, {Student[4], Student[5], Student[3]}) def test_minmax3(self): result = set(select(s for s in Student if max(s.name, "CC") == "CC" )) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_minmax4(self): result = set(select(s for s in Student if min(s.name, "CC") == "CC" 
)) - self.assertEqual(result, set([Student[4], Student[5], Student[3]])) + self.assertEqual(result, {Student[4], Student[5], Student[3]}) def test_minmax5(self): x = chr(128) try: result = set(select(s for s in Student if min(s.name, x) == "CC" )) @@ -82,7 +85,7 @@ def test_minmax7(self): result = set(select(s for s in Student if min(s.phd, 2) == 2 )) def test_date_func1(self): result = set(select(s for s in Student if s.dob >= date(1983, 3, 3))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) # @raises_exception(ExprEvalError, "date(1983, 'three', 3) raises TypeError: an integer is required") @raises_exception(TypeError, "'month' argument of date(year, month, day) function must be of 'int' type. " "Got: '%s'" % unicode.__name__) @@ -94,13 +97,13 @@ def test_date_func2(self): # result = set(select(s for s in Student if s.dob >= date(1983, d, 3))) def test_datetime_func1(self): result = set(select(s for s in Student if s.last_visit >= date(2011, 3, 3))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_datetime_func2(self): result = set(select(s for s in Student if s.last_visit >= datetime(2011, 3, 3))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_datetime_func3(self): result = set(select(s for s in Student if s.last_visit >= datetime(2011, 3, 3, 13, 13, 13))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) # @raises_exception(ExprEvalError, "datetime(1983, 'three', 3) raises TypeError: an integer is required") @raises_exception(TypeError, "'month' argument of datetime(...) function must be of 'int' type. 
" "Got: '%s'" % unicode.__name__) @@ -112,21 +115,28 @@ def test_datetime_func4(self): # result = set(select(s for s in Student if s.last_visit >= date(1983, d, 3))) def test_datetime_now1(self): result = set(select(s for s in Student if s.dob < date.today())) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) - @raises_exception(ExprEvalError, "1 < datetime.now() raises TypeError: " + - ("can't compare datetime.datetime to int" if PY2 else - "unorderable types: int() < datetime.datetime()")) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) + @raises_exception(ExprEvalError, "`1 < datetime.now()` raises TypeError: " + ( + "can't compare 'datetime' to 'int'" if PYPY2 else + "'<' not supported between instances of 'int' and 'datetime'" if PYPY and sys.version_info >= (3, 6) else + "unorderable types: int < datetime" if PYPY else + "can't compare datetime.datetime to int" if PY2 else + "unorderable types: int() < datetime.datetime()" if sys.version_info < (3, 6) else + "'<' not supported between instances of 'int' and 'datetime.datetime'")) def test_datetime_now2(self): select(s for s in Student if 1 < datetime.now()) def test_datetime_now3(self): result = set(select(s for s in Student if s.dob < datetime.today())) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) def test_decimal_func(self): result = set(select(s for s in Student if s.scholarship >= Decimal("303.3"))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_concat_1(self): result = set(select(concat(s.name, ':', s.dob.year, ':', s.scholarship) for s in Student)) - self.assertEqual(result, set(['AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'])) + if db.provider.dialect == 'PostgreSQL': + self.assertEqual(result, {'AA:1981:0.00', 'BB:1982:202.20', 'CC:1983:303.30', 'DD:1984:404.40', 'EE:1985:505.50'}) + else: + self.assertEqual(result, {'AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'}) @raises_exception(TranslationError, 'Invalid argument of concat() function: g.students') def test_concat_2(self): result = set(select(concat(g.number, g.students) for g in Group)) diff --git a/pony/orm/tests/test_declarative_join_optimization.py b/pony/orm/tests/test_declarative_join_optimization.py index 89c14de73..98eb8b99a 100644 --- a/pony/orm/tests/test_declarative_join_optimization.py +++ b/pony/orm/tests/test_declarative_join_optimization.py @@ -5,8 +5,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): name = Required(str) @@ -37,9 +38,13 @@ class Student(db.Entity): courses = Set(Course) -db.generate_mapping(create_tables=True) - class TestM2MOptimization(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -61,15 +66,24 @@ def test4(self): self.assertEqual(Group._table_ not in flatten(q._translator.conditions), True) def test5(self): q = select(s for s in Student if s.group.number == 1 or s.group.major == '1') - self.assertEqual(Group._table_ in 
flatten(q._translator.subquery.from_ast), True) + self.assertEqual(Group._table_ in flatten(q._translator.sqlquery.from_ast), True) # def test6(self): ### Broken with ExprEvalError: Group[101] raises ObjectNotFound: Group[101] # q = select(s for s in Student if s.group == Group[101]) - # self.assertEqual(Group._table_ not in flatten(q._translator.subquery.from_ast), True) + # self.assertEqual(Group._table_ not in flatten(q._translator.sqlquery.from_ast), True) def test7(self): q = select(s for s in Student if sum(c.credits for c in Course if s.group.dept == c.dept) > 10) objects = q[:] - self.assertEqual(str(q._translator.subquery.from_ast), - "['FROM', ['s', 'TABLE', 'Student'], ['group-1', 'TABLE', 'Group', ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group-1', 'number']]]]") + student_table_name = 'Student' + group_table_name = 'Group' + if not (db.provider.dialect == 'SQLite' and pony.__version__ < '0.9'): + student_table_name = student_table_name.lower() + group_table_name = group_table_name.lower() + self.assertEqual(q._translator.sqlquery.from_ast, [ + 'FROM', ['s', 'TABLE', student_table_name], + ['group', 'TABLE', group_table_name, + ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group', 'number']] + ] + ]) if __name__ == '__main__': diff --git a/pony/orm/tests/test_declarative_object_flat_monad.py b/pony/orm/tests/test_declarative_object_flat_monad.py index f374962e8..4b695b76b 100644 --- a/pony/orm/tests/test_declarative_object_flat_monad.py +++ b/pony/orm/tests/test_declarative_object_flat_monad.py @@ -2,8 +2,9 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -28,41 +29,47 @@ class Mark(db.Entity): subject = Required(Subject) PrimaryKey(student, subject) -db.generate_mapping(create_tables=True) -with db_session: - Math = Subject(name="Math") - Physics = Subject(name="Physics") - History = Subject(name="History") +class TestObjectFlatMonad(unittest.TestCase): + @classmethod + def setUpClass(self): + setup_database(db) + with db_session: + Math = Subject(name="Math") + Physics = Subject(name="Physics") + History = Subject(name="History") - g41 = Group(number=41, department=101, subjects=[ Math, Physics, History ]) - g42 = Group(number=42, department=102, subjects=[ Math, Physics ]) - g43 = Group(number=43, department=102, subjects=[ Physics ]) + g41 = Group(number=41, department=101, subjects=[Math, Physics, History]) + g42 = Group(number=42, department=102, subjects=[Math, Physics]) + g43 = Group(number=43, department=102, subjects=[Physics]) - s1 = Student(id=1, name="Joe", scholarship=None, group=g41) - s2 = Student(id=2, name="Bob", scholarship=100, group=g41) - s3 = Student(id=3, name="Beth", scholarship=500, group=g41) - s4 = Student(id=4, name="Jon", scholarship=500, group=g42) - s5 = Student(id=5, name="Pete", scholarship=700, group=g42) + s1 = Student(id=1, name="Joe", scholarship=None, group=g41) + s2 = Student(id=2, name="Bob", scholarship=100, group=g41) + s3 = Student(id=3, name="Beth", scholarship=500, group=g41) + s4 = Student(id=4, name="Jon", scholarship=500, group=g42) + s5 = Student(id=5, name="Pete", scholarship=700, group=g42) - Mark(value=5, student=s1, subject=Math) - Mark(value=4, student=s2, subject=Physics) - Mark(value=3, student=s2, subject=Math) - Mark(value=2, student=s2, subject=History) - Mark(value=1, student=s3, subject=History) - Mark(value=2, student=s3, subject=Math) - 
Mark(value=2, student=s4, subject=Math) + Mark(value=5, student=s1, subject=Math) + Mark(value=4, student=s2, subject=Physics) + Mark(value=3, student=s2, subject=Math) + Mark(value=2, student=s2, subject=History) + Mark(value=1, student=s3, subject=History) + Mark(value=2, student=s3, subject=Math) + Mark(value=2, student=s4, subject=Math) + + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestObjectFlatMonad(unittest.TestCase): @db_session def test1(self): result = set(select(s.groups for s in Subject if len(s.name) == 4)) - self.assertEqual(result, set([Group[41], Group[42]])) + self.assertEqual(result, {Group[41], Group[42]}) @db_session def test2(self): result = set(select(g.students for g in Group if g.department == 102)) - self.assertEqual(result, set([Student[5], Student[4]])) + self.assertEqual(result, {Student[5], Student[4]}) if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_declarative_orderby_limit.py b/pony/orm/tests/test_declarative_orderby_limit.py index a880634b4..bb906f392 100644 --- a/pony/orm/tests/test_declarative_orderby_limit.py +++ b/pony/orm/tests/test_declarative_orderby_limit.py @@ -4,24 +4,31 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) scholarship = Optional(int) group = Required(int) -db.generate_mapping(create_tables=True) - -with db_session: - Student(id=1, name="B", scholarship=None, group=41) - Student(id=2, name="C", scholarship=700, group=41) - Student(id=3, name="A", scholarship=500, group=42) - Student(id=4, name="D", scholarship=500, group=43) - Student(id=5, name="E", scholarship=700, group=42) class TestOrderbyLimit(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Student(id=1, name="B", scholarship=None, group=41) + Student(id=2, name="C", scholarship=700, group=41) + Student(id=3, name="A", scholarship=500, group=42) + Student(id=4, name="D", scholarship=500, group=43) + Student(id=5, name="E", scholarship=700, group=42) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -32,34 +39,34 @@ def tearDown(self): def test1(self): students = set(select(s for s in Student).order_by(Student.name)) - self.assertEqual(students, set([Student[3], Student[1], Student[2], Student[4], Student[5]])) + self.assertEqual(students, {Student[3], Student[1], Student[2], Student[4], Student[5]}) def test2(self): students = set(select(s for s in Student).order_by(Student.name.asc)) - self.assertEqual(students, set([Student[3], Student[1], Student[2], Student[4], Student[5]])) + self.assertEqual(students, {Student[3], Student[1], Student[2], Student[4], Student[5]}) def test3(self): students = set(select(s for s in Student).order_by(Student.id.desc)) - self.assertEqual(students, set([Student[5], Student[4], Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[5], Student[4], Student[3], Student[2], Student[1]}) def test4(self): students = set(select(s for s in Student).order_by(Student.scholarship.asc, Student.group.desc)) - self.assertEqual(students, set([Student[1], Student[4], Student[3], Student[5], Student[2]])) + self.assertEqual(students, {Student[1], Student[4], Student[3], Student[5], Student[2]}) def test5(self): students = set(select(s for s in 
Student).order_by(Student.name).limit(3)) - self.assertEqual(students, set([Student[3], Student[1], Student[2]])) + self.assertEqual(students, {Student[3], Student[1], Student[2]}) def test6(self): students = set(select(s for s in Student).order_by(Student.name).limit(3, 1)) - self.assertEqual(students, set([Student[1], Student[2], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[4]}) def test7(self): q = select(s for s in Student).order_by(Student.name).limit(3, 1) students = set(q) - self.assertEqual(students, set([Student[1], Student[2], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[4]}) students = set(q) - self.assertEqual(students, set([Student[1], Student[2], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[4]}) # @raises_exception(TypeError, "query.order_by() arguments must be attributes. Got: 'name'") # now generate: ExprEvalError: name raises NameError: name 'name' is not defined @@ -68,21 +75,21 @@ def test7(self): def test9(self): students = set(select(s for s in Student).order_by(Student.id)[1:4]) - self.assertEqual(students, set([Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test10(self): students = set(select(s for s in Student).order_by(Student.id)[:4]) - self.assertEqual(students, set([Student[1], Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4]}) - @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified") - def test11(self): - students = select(s for s in Student).order_by(Student.id)[4:] + # @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified") + # def test11(self): + # students = select(s for s in Student).order_by(Student.id)[4:] @raises_exception(TypeError, "Parameter 'start' of slice object cannot be negative") def test12(self): students = select(s for s in Student).order_by(Student.id)[-3:2] - @raises_exception(TypeError, 'If you want apply index to query, convert it to list first') + @raises_exception(TypeError, 'If you want apply index to a query, convert it to list first') def test13(self): students = select(s for s in Student).order_by(Student.id)[3] self.assertEqual(students, Student[4]) @@ -93,19 +100,19 @@ def test13(self): def test15(self): students = set(select(s for s in Student).order_by(Student.id)[0:4][1:3]) - self.assertEqual(students, set([Student[2], Student[3]])) + self.assertEqual(students, {Student[2], Student[3]}) def test16(self): students = set(select(s for s in Student).order_by(Student.id)[0:4][1:]) - self.assertEqual(students, set([Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test17(self): students = set(select(s for s in Student).order_by(Student.id)[:4][1:]) - self.assertEqual(students, set([Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test18(self): students = set(select(s for s in Student).order_by(Student.id)[:]) - self.assertEqual(students, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) + self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4], Student[5]}) def test19(self): q = select(s for s in Student).order_by(Student.id) @@ -116,5 +123,36 @@ def test19(self): students = q[:] self.assertEqual(students, [Student[1], Student[2], Student[3], Student[4], Student[5]]) + def test20(self): + q = select(s for s 
in Student).limit(offset=2) + self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) + last_sql = db.last_sql + if db.provider.dialect == 'PostgreSQL': + self.assertTrue('LIMIT null OFFSET 2' in last_sql) + else: + self.assertTrue('LIMIT -1 OFFSET 2' in last_sql) + + def test21(self): + q = select(s for s in Student).limit(0, offset=2) + self.assertEqual(set(q), set()) + + def test22(self): + q = select(s for s in Student).order_by(Student.id).limit(offset=1) + self.assertEqual(set(q), {Student[2], Student[3], Student[4], Student[5]}) + + def test23(self): + q = select(s for s in Student)[2:2] + self.assertEqual(set(q), set()) + self.assertTrue('LIMIT 0' in db.last_sql) + + def test24(self): + q = select(s for s in Student)[2:] + self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) + + def test25(self): + q = select(s for s in Student)[:2] + self.assertEqual(set(q), {Student[2], Student[1]}) + + if __name__ == "__main__": unittest.main() diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 07555ecd0..3edc8f52e 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Group(db.Entity): id = PrimaryKey(int) @@ -13,6 +14,7 @@ class Group(db.Entity): class Student(db.Entity): name = Required(unicode) + age = Required(int) group = Required('Group') scholarship = Required(int, default=0) courses = Set('Course') @@ -23,20 +25,25 @@ class Course(db.Entity): PrimaryKey(name, semester) students = Set('Student') -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(id=1) - g2 = Group(id=2) - s1 = Student(id=1, name='S1', group=g1, scholarship=0) - s2 = Student(id=2, name='S2', group=g1, scholarship=100) - s3 = Student(id=3, name='S3', group=g2, scholarship=500) - c1 = Course(name='C1', semester=1, students=[s1, s2]) - c2 = Course(name='C2', semester=1, students=[s2, s3]) - c3 = Course(name='C3', semester=2, students=[s3]) - class TestQuerySetMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1) + g2 = Group(id=2) + s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) + s2 = Student(id=2, name='S2', age=23, group=g1, scholarship=100) + s3 = Student(id=3, name='S3', age=23, group=g2, scholarship=500) + c1 = Course(name='C1', semester=1, students=[s1, s2]) + c2 = Course(name='C2', semester=1, students=[s2, s3]) + c3 = Course(name='C3', semester=2, students=[s3]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -47,52 +54,113 @@ def tearDown(self): def test_len(self): result = set(select(g for g in Group if len(g.students) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_len_2(self): result = set(select(g for g in Group if len(s for s in Student if s.group == g) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_len_3(self): result = set(select(g for g in Group if len(s.name for s in Student if s.group == g) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_count_1(self): result = set(select(g for g in Group if count(s.name for s in g.students) 
> 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_count_2(self): result = set(select(g for g in Group if select(s.name for s in g.students).count() > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_count_3(self): result = set(select(s for s in Student if count(c for c in s.courses) > 1)) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) + + def test_count_3a(self): + result = set(select(s for s in Student if select(c for c in s.courses).count() > 1)) + self.assertEqual(result, {Student[2], Student[3]}) + self.assertTrue('DISTINCT' in db.last_sql) + + def test_count_3b(self): + result = set(select(s for s in Student if select(c for c in s.courses).count(distinct=False) > 1)) + self.assertEqual(result, {Student[2], Student[3]}) + self.assertTrue('DISTINCT' not in db.last_sql) def test_count_4(self): result = set(select(c for c in Course if count(s for s in c.students) > 1)) - self.assertEqual(result, set([Course['C1', 1], Course['C2', 1]])) + self.assertEqual(result, {Course['C1', 1], Course['C2', 1]}) + + def test_count_5(self): + result = select(c.semester for c in Course).count(distinct=True) + self.assertEqual(result, 2) + + def test_count_6(self): + result = select(c for c in Course).count() + self.assertEqual(result, 3) + self.assertTrue('DISTINCT' not in db.last_sql) + + def test_count_7(self): + result = select(c for c in Course).count(distinct=True) + self.assertEqual(result, 3) + self.assertTrue('DISTINCT' in db.last_sql) + + def test_count_8(self): + select(count(c.semester, distinct=False) for c in Course)[:] + self.assertTrue('DISTINCT' not in db.last_sql) + + @raises_exception(TypeError, "`distinct` value should be True or False. 
Got: s.name.startswith('P')") + def test_count_9(self): + select(count(s, distinct=s.name.startswith('P')) for s in Student) + + def test_count_10(self): + select(count('*', distinct=True) for s in Student)[:] + self.assertTrue('DISTINCT' not in db.last_sql) @raises_exception(TypeError) def test_sum_1(self): result = set(select(g for g in Group if sum(s for s in Student if s.group == g) > 1)) - self.assertEqual(result, set([])) @raises_exception(TypeError) def test_sum_2(self): select(g for g in Group if sum(s.name for s in Student if s.group == g) > 1) def test_sum_3(self): - result = set(select(g for g in Group if sum(s.scholarship for s in Student if s.group == g) > 500)) - self.assertEqual(result, set([])) + result = sum(s.scholarship for s in Student) + self.assertEqual(result, 600) def test_sum_4(self): + result = sum(s.scholarship for s in Student if s.name == 'Unnamed') + self.assertEqual(result, 0) + + def test_sum_5(self): + result = select(c.semester for c in Course).sum() + self.assertEqual(result, 4) + + def test_sum_6(self): + result = select(c.semester for c in Course).sum(distinct=True) + self.assertEqual(result, 3) + + def test_sum_7(self): + result = set(select(g for g in Group if sum(s.scholarship for s in Student if s.group == g) > 500)) + self.assertEqual(result, set()) + + def test_sum_8(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum() > 200)) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) + self.assertTrue('DISTINCT' not in db.last_sql) + + def test_sum_9(self): + result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum(distinct=True) > 200)) + self.assertEqual(result, {Group[2]}) + self.assertTrue('DISTINCT' in db.last_sql) + + def test_sum_10(self): + select(sum(s.scholarship, distinct=True) for s in Student)[:] + self.assertTrue('SUM(DISTINCT' in db.last_sql) def test_min_1(self): result = set(select(g for g in Group if min(s.name for s in Student if s.group == g) == 'S1')) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) @raises_exception(TypeError) def test_min_2(self): @@ -100,11 +168,15 @@ def test_min_2(self): def test_min_3(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).min() == 0)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) + + def test_min_4(self): + result = select(s.scholarship for s in Student).min() + self.assertEqual(0, result) def test_max_1(self): result = set(select(g for g in Group if max(s.scholarship for s in Student if s.group == g) > 100)) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) @raises_exception(TypeError) def test_max_2(self): @@ -112,7 +184,11 @@ def test_max_2(self): def test_max_3(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).max() == 100)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) + + def test_max_4(self): + result = select(s.scholarship for s in Student).max() + self.assertEqual(result, 500) def test_avg_1(self): result = select(g for g in Group if avg(s.scholarship for s in Student if s.group == g) == 50)[:] @@ -120,40 +196,178 @@ def test_avg_1(self): def test_avg_2(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg() == 50)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) - def test_exists(self): - result = 
set(select(g for g in Group if exists(s for s in g.students if s.name == 'S1'))) - self.assertEqual(result, set([Group[1]])) + def test_avg_3(self): + result = select(c.semester for c in Course).avg() + self.assertAlmostEqual(1.33, result, places=2) + + def test_avg_4(self): + result = select(c.semester for c in Course).avg(distinct=True) + self.assertAlmostEqual(1.5, result) + + def test_avg_5(self): + result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg(distinct=True) == 50)) + self.assertEqual(result, {Group[1]}) + self.assertTrue('AVG(DISTINCT' in db.last_sql) + + def test_avg_6(self): + select(avg(s.scholarship, distinct=True) for s in Student)[:] + self.assertTrue('AVG(DISTINCT' in db.last_sql) + + def test_exists_1(self): + result = set(select(g for g in Group if exists(s for s in g.students if s.age < 23))) + self.assertEqual(result, {Group[1]}) + + def test_exists_2(self): + result = set(select(g for g in Group if exists(s.age < 23 for s in g.students))) + self.assertEqual(result, {Group[1]}) + + def test_exists_3(self): + result = set(select(g for g in Group if (s.age < 23 for s in g.students))) + self.assertEqual(result, {Group[1]}) def test_negate(self): result = set(select(g for g in Group if not(s.scholarship for s in Student if s.group == g))) - self.assertEqual(result, set([])) + self.assertEqual(result, set()) def test_no_conditions(self): students = set(select(s for s in Student if s.group in (g for g in Group))) - self.assertEqual(students, set([Student[1], Student[2], Student[3]])) + self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test_no_conditions_2(self): students = set(select(s for s in Student if s.scholarship == max(s.scholarship for s in Student))) - self.assertEqual(students, set([Student[3]])) + self.assertEqual(students, {Student[3]}) def test_hint_join_1(self): result = set(select(s for s in Student if JOIN(s.group in select(g for g in Group if g.id < 2)))) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_hint_join_2(self): result = set(select(s for s in Student if JOIN(s.group not in select(g for g in Group if g.id < 2)))) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) def test_hint_join_3(self): result = set(select(s for s in Student if JOIN(s.scholarship in select(s.scholarship + 100 for s in Student if s.name != 'S2')))) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_hint_join_4(self): result = set(select(g for g in Group if JOIN(g in select(s.group for s in g.students)))) - self.assertEqual(result, set([Group[1], Group[2]])) + self.assertEqual(result, {Group[1], Group[2]}) + + def test_group_concat_1(self): + result = select(s.name for s in Student).group_concat() + self.assertEqual(result, 'S1,S2,S3') + + def test_group_concat_2(self): + result = select(s.name for s in Student).group_concat('-') + self.assertEqual(result, 'S1-S2-S3') + + def test_group_concat_3(self): + result = select(s for s in Student if s.name in group_concat(s.name for s in Student))[:] + self.assertEqual(set(result), {Student[1], Student[2], Student[3]}) + + def test_group_concat_4(self): + result = Student.select().group_concat() + self.assertEqual(result, '1,2,3') + + def test_group_concat_5(self): + result = Student.select().group_concat('.') + self.assertEqual(result, '1.2.3') + + @raises_exception(TypeError, '`group_concat` cannot be used with entity with composite 
primary key') + def test_group_concat_6(self): + select(group_concat(s.courses, '-') for s in Student) + + def test_group_concat_7(self): + result = select(group_concat(c.semester) for c in Course)[:] + self.assertEqual(result[0], '1,1,2') + + def test_group_concat_8(self): + result = select(group_concat(c.semester, '-') for c in Course)[:] + self.assertEqual(result[0], '1-1-2') + + def test_group_concat_9(self): + result = select(group_concat(c.semester, distinct=True) for c in Course)[:] + self.assertEqual(result[0], '1,2') + + def test_group_concat_10(self): + result = group_concat((s.name for s in Student if int(s.name[1]) > 1), sep='-') + self.assertEqual(result, 'S2-S3') + + def test_group_concat_11(self): + result = group_concat((c.semester for c in Course), distinct=True) + self.assertEqual(result, '1,2') + + + @raises_exception(TypeError, 'Query can only iterate over entity or another query (not a list of objects)') + def test_select_from_select_1(self): + query = select(s for s in Student if s.scholarship > 0)[:] + result = set(select(x for x in query)) + self.assertEqual(result, {}) + + def test_select_from_select_2(self): + p, q = 50, 400 + query = select(s for s in Student if s.scholarship > p) + result = select(x.id for x in query if x.scholarship < q)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_3(self): + p, q = 50, 400 + g = (s for s in Student if s.scholarship > p) + result = select(x.id for x in g if x.scholarship < q)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_4(self): + p, q = 50, 400 + result = select(x.id for x in (s for s in Student if s.scholarship > p) + if x.scholarship < q)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_5(self): + p, q = 50, 400 + result = select(x.id for x in select(s for s in Student if s.scholarship > 0) + if x.scholarship < 400)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_6(self): + query = select(s.name for s in Student if s.scholarship > 0) + result = select(x for x in query if not x.endswith('3')) + self.assertEqual(set(result), {'S2'}) + + @raises_exception(TranslationError, 'Too many values to unpack "for a, b in select(s for ...)" (expected 2, got 1)') + def test_select_from_select_7(self): + query = select(s for s in Student if s.scholarship > 0) + result = select(a for a, b in query) + + @raises_exception(NotImplementedError, 'Please unpack a tuple of (s.name, s.group) in for-loop ' + 'to individual variables (like: "for x, y in ...")') + def test_select_from_select_8(self): + query = select((s.name, s.group) for s in Student if s.scholarship > 0) + result = select(x for x in query) + + @raises_exception(TranslationError, 'Not enough values to unpack "for x, y in ' + 'select(s.name, s.group, s.scholarship for ...)" (expected 2, got 3)') + def test_select_from_select_9(self): + query = select((s.name, s.group, s.scholarship) for s in Student if s.scholarship > 0) + result = select(x for x, y in query) + + def test_select_from_select_10(self): + query = select((s.name, s.age) for s in Student if s.scholarship > 0) + result = select(n for n, a in query if n.endswith('2') and a > 20) + self.assertEqual(set(x for x in result), {'S2'}) + + def test_aggregations_1(self): + query = select((min(s.age), max(s.scholarship)) for s in Student) + result = query[:] + self.assertEqual(result, [(20, 500)]) + + def test_aggregations_2(self): + query = select((min(s.age), max(s.scholarship)) for s in Student for g in Group) + result = query[:] + 
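# --- Illustrative sketch, not part of the patch: the distinct-aware aggregates and
# --- group_concat forms covered by the tests above, using the Student/Course entities
# --- of this module and the names brought in by `from pony.orm.core import *`.
with db_session:
    n_semesters = select(c.semester for c in Course).count(distinct=True)    # COUNT(DISTINCT ...)
    n_rows      = select(c.semester for c in Course).count(distinct=False)   # count rows without deduplication
    total       = select(s.scholarship for s in Student).sum(distinct=True)  # SUM(DISTINCT ...)
    joined      = select(s.name for s in Student).group_concat('-')          # 'S1-S2-S3'
    semesters   = group_concat((c.semester for c in Course), distinct=True)  # '1,2'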
self.assertEqual(result, [(20, 500)]) + if __name__ == "__main__": unittest.main() diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 012906195..477d5c1ec 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -6,8 +6,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): number = PrimaryKey(int) @@ -53,42 +54,48 @@ class Room(db.Entity): name = PrimaryKey(unicode) groups = Set(Group) -db.generate_mapping(create_tables=True) -with db_session: - d1 = Department(number=44) - d2 = Department(number=43) - g1 = Group(id=1, dept=d1) - g2 = Group(id=2, dept=d2) - s1 = Student(id=1, name='S1', group=g1, scholarship=0) - s2 = Student(id=2, name='S2', group=g1, scholarship=100) - s3 = Student(id=3, name='S3', group=g2, scholarship=500) - c1 = Course(name='Math', semester=1, dept=d1) - c2 = Course(name='Economics', semester=1, dept=d1, credits=3) - c3 = Course(name='Physics', semester=2, dept=d2) - t1 = Teacher(id=101, name="T1") - t2 = Teacher(id=102, name="T2") - Grade(student=s1, course=c1, value='C', teacher=t2, date=date(2011, 1, 1)) - Grade(student=s1, course=c3, value='A', teacher=t1, date=date(2011, 2, 1)) - Grade(student=s2, course=c2, value='B', teacher=t1) - r1 = Room(name='Room1') - r2 = Room(name='Room2') - r3 = Room(name='Room3') - g1.rooms = [ r1, r2 ] - g2.rooms = [ r2, r3 ] - c1.students.add(s1) - c2.students.add(s2) - -db2 = Database('sqlite', ':memory:') +db2 = Database() class Room2(db2.Entity): name = PrimaryKey(unicode) -db2.generate_mapping(create_tables=True) - name1 = 'S1' + class TestSQLTranslator(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=44) + d2 = Department(number=43) + g1 = Group(id=1, dept=d1) + g2 = Group(id=2, dept=d2) + s1 = Student(id=1, name='S1', group=g1, scholarship=0) + s2 = Student(id=2, name='S2', group=g1, scholarship=100) + s3 = Student(id=3, name='S3', group=g2, scholarship=500) + c1 = Course(name='Math', semester=1, dept=d1) + c2 = Course(name='Economics', semester=1, dept=d1, credits=3) + c3 = Course(name='Physics', semester=2, dept=d2) + t1 = Teacher(id=101, name="T1") + t2 = Teacher(id=102, name="T2") + Grade(student=s1, course=c1, value='C', teacher=t2, date=date(2011, 1, 1)) + Grade(student=s1, course=c3, value='A', teacher=t1, date=date(2011, 2, 1)) + Grade(student=s2, course=c2, value='B', teacher=t1) + r1 = Room(name='Room1') + r2 = Room(name='Room2') + r3 = Room(name='Room3') + g1.rooms = [r1, r2] + g2.rooms = [r2, r3] + c1.students.add(s1) + c1.students.add(s2) + c2.students.add(s2) + setup_database(db2) + @classmethod + def tearDownClass(cls): + teardown_database(db) + teardown_database(db2) def setUp(self): rollback() db_session.__enter__() @@ -97,14 +104,14 @@ def tearDown(self): db_session.__exit__() def test_select1(self): result = set(select(s for s in Student)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_select_param(self): result = select(s for s in Student if s.name == name1)[:] self.assertEqual(result, [Student[1]]) def test_select_object_param(self): stud1 = Student[1] result = set(select(s for s in Student if s != stud1)) - self.assertEqual(result, set([Student[2], 
Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) def test_select_deref(self): x = 'S1' result = select(s for s in Student if s.name == x)[:] @@ -132,8 +139,7 @@ def test_function_min2(self): def test_min3(self): d = date(2011, 1, 1) result = set(select(g for g in Grade if min(g.date, d) == d and g.date is not None)) - self.assertEqual(result, set([Grade[Student[1], Course[u'Math', 1]], - Grade[Student[1], Course[u'Physics', 2]]])) + self.assertEqual(result, {Grade[Student[1], Course[u'Math', 1]], Grade[Student[1], Course[u'Physics', 2]]}) def test_function_len1(self): result = select(s for s in Student if len(s.grades) == 1)[:] self.assertEqual(result, [Student[2]]) @@ -168,13 +174,13 @@ def test_builtin_in_locals(self): # select(s for s in Student for g in g.subjects) def test_chain1(self): result = set(select(g for g in Group for s in g.students if s.name.endswith('3'))) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_chain2(self): result = set(select(s for g in Group if g.dept.number == 44 for s in g.students if s.name.startswith('S'))) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_chain_m2m(self): result = set(select(g for g in Group for r in g.rooms if r.name == 'Room2')) - self.assertEqual(result, set([Group[1], Group[2]])) + self.assertEqual(result, {Group[1], Group[2]}) @raises_exception(TranslationError, 'All entities in a query must belong to the same database') def test_two_diagrams(self): select(g for g in Group for r in Room2 if r.name == 'Room2') @@ -183,10 +189,10 @@ def test_add_sub_mul_etc(self): self.assertEqual(result, [Student[2]]) def test_subscript(self): result = set(select(s for s in Student if s.name[1] == '2')) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_slice(self): result = set(select(s for s in Student if s.name[:1] == 'S')) - self.assertEqual(result, set([Student[3], Student[2], Student[1]])) + self.assertEqual(result, {Student[3], Student[2], Student[1]}) def test_attr_chain(self): s1 = Student[1] result = select(s for s in Student if s == s1)[:] @@ -207,9 +213,9 @@ def test_list_monad3(self): grade1 = Grade[Student[1], Course['Physics', 2]] grade2 = Grade[Student[1], Course['Math', 1]] result = set(select(g for g in Grade if g in [grade1, grade2])) - self.assertEqual(result, set([grade1, grade2])) + self.assertEqual(result, {grade1, grade2}) result = set(select(g for g in Grade if g not in [grade1, grade2])) - self.assertEqual(result, set([Grade[Student[2], Course['Economics', 1]]])) + self.assertEqual(result, {Grade[Student[2], Course['Economics', 1]]}) def test_tuple_monad1(self): n1 = 'S1' n2 = 'S2' @@ -235,7 +241,7 @@ def test_expr1(self): result = select(a for s in Student) def test_expr2(self): result = set(select(s.group for s in Student)) - self.assertEqual(result, set([Group[1], Group[2]])) + self.assertEqual(result, {Group[1], Group[2]}) def test_numeric_binop(self): i = 100 f = 2.0 @@ -246,19 +252,19 @@ def test_string_const_monad(self): self.assertEqual(result, []) def test_numeric_to_bool1(self): result = set(select(s for s in Student if s.name != 'John' or s.scholarship)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_numeric_to_bool2(self): result = set(select(s for s in Student if not s.scholarship)) - self.assertEqual(result, set([Student[1]])) + 
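# --- Illustrative sketch, not part of the patch: external parameters and generator
# --- chaining across relationships, as exercised by the translator tests above
# --- (Student, Group and Room are the entities declared in this module).
with db_session:
    prefix  = 'S'
    named   = select(s for s in Student if s.name.startswith(prefix))[:]          # plain Python variable as a parameter
    someone = Student[1]
    others  = select(s for s in Student if s != someone)[:]                       # entity instance as a parameter
    groups  = select(g for g in Group for r in g.rooms if r.name == 'Room2')[:]   # chained loop joins via the m2m relation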
self.assertEqual(result, {Student[1]}) def test_not_monad1(self): result = set(select(s for s in Student if not (s.scholarship > 0 and s.name != 'S1'))) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_not_monad2(self): result = set(select(s for s in Student if not not (s.scholarship > 0 and s.name != 'S1'))) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) def test_subquery_with_attr(self): result = set(select(s for s in Student if max(g.value for g in s.grades) == 'C')) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_query_reuse(self): q = select(s for s in Student if s.scholarship > 0) q.count() @@ -275,47 +281,47 @@ def test_order_by(self): self.assertEqual(result, [Student[1], Student[2], Student[3]]) def test_read_inside_query(self): result = set(select(s for s in Student if Group[1].dept.number == 44)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_crud_attr_chain(self): result = set(select(s for s in Student if Group[1].dept.number == s.group.dept.number)) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_composite_key1(self): result = set(select(t for t in Teacher if Grade[Student[1], Course['Physics', 2]] in t.grades)) - self.assertEqual(result, set([Teacher.get(name='T1')])) + self.assertEqual(result, {Teacher.get(name='T1')}) def test_composite_key2(self): result = set(select(s for s in Student if Course['Math', 1] in s.courses)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_composite_key3(self): result = set(select(s for s in Student if Course['Math', 1] not in s.courses)) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[3]}) def test_composite_key4(self): result = set(select(s for s in Student if len(c for c in Course if c not in s.courses) == 2)) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1]}) def test_composite_key5(self): result = set(select(s for s in Student if not (c for c in Course if c not in s.courses))) self.assertEqual(result, set()) def test_composite_key6(self): result = set(select(c for c in Course if c not in (c2 for s in Student for c2 in s.courses))) - self.assertEqual(result, set([Course['Physics', 2]])) + self.assertEqual(result, {Course['Physics', 2]}) def test_composite_key7(self): result = set(select(c for s in Student for c in s.courses)) - self.assertEqual(result, set([Course['Math', 1], Course['Economics', 1]])) + self.assertEqual(result, {Course['Math', 1], Course['Economics', 1]}) def test_contains1(self): s1 = Student[1] result = set(select(g for g in Group if s1 in g.students)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_contains2(self): s1 = Student[1] result = set(select(g for g in Group if s1.name in g.students.name)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_contains3(self): s1 = Student[1] result = set(select(g for g in Group if s1 not in g.students)) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_contains4(self): s1 = Student[1] result = set(select(g for g in Group if s1.name not in g.students.name)) - self.assertEqual(result, 
set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_buffer_monad1(self): try: select(s for s in Student if s.picture == buffer('abc')) except TypeError as e: self.assertTrue(not PY2 and str(e) == 'string argument without an encoding') @@ -324,32 +330,129 @@ def test_buffer_monad2(self): select(s for s in Student if s.picture == buffer('abc', 'ascii')) def test_database_monad(self): result = set(select(s for s in db.Student if db.Student[1] == s)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_duplicate_name(self): result = set(select(x for x in Student if x.group in (x for x in Group))) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_hint_join1(self): result = set(select(s for s in Student if JOIN(max(s.courses.credits) == 3))) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_hint_join2(self): result = set(select(c for c in Course if JOIN(len(c.students) == 1))) - self.assertEqual(result, set([Course['Math', 1], Course['Economics', 1]])) + self.assertEqual(result, {Course['Economics', 1]}) def test_tuple_param(self): x = Student[1], Student[2] result = set(select(s for s in Student if s not in x)) - self.assertEqual(result, set([Student[3]])) - @raises_exception(TypeError, "Expression `x` should not contain None values") + self.assertEqual(result, {Student[3]}) + @raises_exception(TypeError, "Expression `x` should not contain None values") def test_tuple_param_2(self): x = Student[1], None result = set(select(s for s in Student if s not in x)) - self.assertEqual(result, set([Student[3]])) - @raises_exception(TypeError, "f(s)") - def test_unknown_func(self): - def f(x): return x - select(s for s in Student if f(s)) + self.assertEqual(result, {Student[3]}) def test_method_monad(self): result = set(select(s for s in Student if s not in Student.select(lambda s: s.scholarship > 0))) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) + def test_lambda_1(self): + q = select(s for s in Student) + q = q.filter(lambda s: s.name == 'S1') + self.assertEqual(list(q), [Student[1]]) + def test_lambda_2(self): + q = select(s for s in Student) + q = q.filter(lambda stud: stud.name == 'S1') + self.assertEqual(list(q), [Student[1]]) + def test_lambda_3(self): + q = select(s for s in Student) + q = q.filter(lambda stud: exists(x for x in Student if stud.name < x.name)) + self.assertEqual(set(q), {Student[1], Student[2]}) + def test_lambda_4(self): + q = select(s for s in Student) + q = q.filter(lambda stud: exists(s for s in Student if stud.name < s.name)) + self.assertEqual(set(q), {Student[1], Student[2]}) + def test_optimized_1(self): + q = select((g, count(g.students)) for g in Group if count(g.students) > 1) + self.assertEqual(set(q), {(Group[1], 2)}) + def test_optimized_2(self): + q = select((s, count(s.courses)) for s in Student if count(s.courses) > 1) + self.assertEqual(set(q), {(Student[2], 2)}) + def test_delete(self): + q = select(g for g in Grade if g.teacher.id == 101).delete() + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) + def test_delete_2(self): + delete(g for g in Grade if g.teacher.id == 101) + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) + def test_delete_3(self): + select(g for g in Grade if g.teacher.id == 101).delete(bulk=True) + q2 = select(g for g in Grade)[:] + 
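# --- Illustrative sketch, not part of the patch: lambda-based filtering and the three
# --- delete forms used in the tests above (Grade and Student are this module's entities).
with db_session:
    rich = select(s for s in Student).filter(lambda s: s.scholarship > 0)[:]   # filter added after the query is built
    select(g for g in Grade if g.teacher.id == 101).delete()                   # delete through a query
    delete(g for g in Grade if g.teacher.id == 102)                            # top-level delete() over a generator
    select(g for g in Grade if g.value == 'A').delete(bulk=True)               # bulk form: issued as a single DELETE
    rollback()                                                                 # the surrounding tests roll back between cases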
self.assertEqual([g.value for g in q2], ['C']) + def test_delete_4(self): + select(g for g in Grade if exists(g2 for g2 in Grade if g2.value > g.value)).delete(bulk=True) + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) + def test_select_2(self): + result = select(s for s in Student)[:] + self.assertEqual(result, [Student[1], Student[2], Student[3]]) + def test_select_add(self): + result = [None] + select(s for s in Student)[:] + self.assertEqual(result, [None, Student[1], Student[2], Student[3]]) + def test_query_result_radd(self): + result = select(s for s in Student)[:] + [None] + self.assertEqual(result, [Student[1], Student[2], Student[3], None]) + def test_query_result_sort(self): + result = select(s for s in Student)[:] + result.sort() + self.assertEqual(result, [Student[1], Student[2], Student[3]]) + def test_query_result_reverse(self): + result = select(s for s in Student)[:] + items = list(result) + result.reverse() + self.assertEqual(items, list(reversed(result))) + def test_query_result_shuffle(self): + result = select(s for s in Student)[:] + items = set(result) + result.shuffle() + self.assertEqual(items, set(result)) + def test_query_result_to_list(self): + result = select(s for s in Student)[:] + items = result.to_list() + self.assertTrue(type(items) is list) + @raises_exception(TypeError, 'In order to do item assignment, cast QueryResult to list first') + def test_query_result_setitem(self): + result = select(s for s in Student)[:] + result[0] = None + @raises_exception(TypeError, 'In order to do item deletion, cast QueryResult to list first') + def test_query_result_delitem(self): + result = select(s for s in Student)[:] + del result[0] + @raises_exception(TypeError, 'In order to do +=, cast QueryResult to list first') + def test_query_result_iadd(self): + result = select(s for s in Student)[:] + result += None + @raises_exception(TypeError, 'In order to do append, cast QueryResult to list first') + def test_query_result_append(self): + result = select(s for s in Student)[:] + result.append(None) + @raises_exception(TypeError, 'In order to do clear, cast QueryResult to list first') + def test_query_result_clear(self): + result = select(s for s in Student)[:] + result.clear() + @raises_exception(TypeError, 'In order to do extend, cast QueryResult to list first') + def test_query_result_extend(self): + result = select(s for s in Student)[:] + result.extend([]) + @raises_exception(TypeError, 'In order to do insert, cast QueryResult to list first') + def test_query_result_insert(self): + result = select(s for s in Student)[:] + result.insert(0, None) + @raises_exception(TypeError, 'In order to do pop, cast QueryResult to list first') + def test_query_result_pop(self): + result = select(s for s in Student)[:] + result.pop() + @raises_exception(TypeError, 'In order to do remove, cast QueryResult to list first') + def test_query_result_remove(self): + result = select(s for s in Student)[:] + result.remove(None) if __name__ == "__main__": diff --git a/pony/orm/tests/test_declarative_sqltranslator2.py b/pony/orm/tests/test_declarative_sqltranslator2.py index 016df1fc5..ccddad377 100644 --- a/pony/orm/tests/test_declarative_sqltranslator2.py +++ b/pony/orm/tests/test_declarative_sqltranslator2.py @@ -7,8 +7,9 @@ from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = 
Database() class Department(db.Entity): number = PrimaryKey(int, auto=True) @@ -43,51 +44,56 @@ class Student(db.Entity): group = Required(Group) courses = Set(Course) -db.generate_mapping(create_tables=True) - -with db_session: - d1 = Department(name="Department of Computer Science") - d2 = Department(name="Department of Mathematical Sciences") - d3 = Department(name="Department of Applied Physics") - - c1 = Course(name="Web Design", semester=1, dept=d1, - lect_hours=30, lab_hours=30, credits=3) - c2 = Course(name="Data Structures and Algorithms", semester=3, dept=d1, - lect_hours=40, lab_hours=20, credits=4) - - c3 = Course(name="Linear Algebra", semester=1, dept=d2, - lect_hours=30, lab_hours=30, credits=4) - c4 = Course(name="Statistical Methods", semester=2, dept=d2, - lect_hours=50, lab_hours=25, credits=5) - - c5 = Course(name="Thermodynamics", semester=2, dept=d3, - lect_hours=25, lab_hours=40, credits=4) - c6 = Course(name="Quantum Mechanics", semester=3, dept=d3, - lect_hours=40, lab_hours=30, credits=5) - - g101 = Group(number=101, major='B.E. in Computer Engineering', dept=d1) - g102 = Group(number=102, major='B.S./M.S. in Computer Science', dept=d2) - g103 = Group(number=103, major='B.S. in Applied Mathematics and Statistics', dept=d2) - g104 = Group(number=104, major='B.S./M.S. in Pure Mathematics', dept=d2) - g105 = Group(number=105, major='B.E in Electronics', dept=d3) - g106 = Group(number=106, major='B.S./M.S. in Nuclear Engineering', dept=d3) - - Student(name='John Smith', dob=date(1991, 3, 20), tel='123-456', gpa=3, group=g101, phd=True, - courses=[c1, c2, c4, c6]) - Student(name='Matthew Reed', dob=date(1990, 11, 26), gpa=3.5, group=g101, phd=True, - courses=[c1, c3, c4, c5]) - Student(name='Chuan Qin', dob=date(1989, 2, 5), gpa=4, group=g101, - courses=[c3, c5, c6]) - Student(name='Rebecca Lawson', dob=date(1990, 4, 18), tel='234-567', gpa=3.3, group=g102, - courses=[c1, c4, c5, c6]) - Student(name='Maria Ionescu', dob=date(1991, 4, 23), gpa=3.9, group=g102, - courses=[c1, c2, c4, c6]) - Student(name='Oliver Blakey', dob=date(1990, 9, 8), gpa=3.1, group=g102, - courses=[c1, c2, c5]) - Student(name='Jing Xia', dob=date(1988, 12, 30), gpa=3.2, group=g102, - courses=[c1, c3, c5, c6]) - class TestSQLTranslator2(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=1, name="Department of Computer Science") + d2 = Department(number=2, name="Department of Mathematical Sciences") + d3 = Department(number=3, name="Department of Applied Physics") + + c1 = Course(name="Web Design", semester=1, dept=d1, + lect_hours=30, lab_hours=30, credits=3) + c2 = Course(name="Data Structures and Algorithms", semester=3, dept=d1, + lect_hours=40, lab_hours=20, credits=4) + + c3 = Course(name="Linear Algebra", semester=1, dept=d2, + lect_hours=30, lab_hours=30, credits=4) + c4 = Course(name="Statistical Methods", semester=2, dept=d2, + lect_hours=50, lab_hours=25, credits=5) + + c5 = Course(name="Thermodynamics", semester=2, dept=d3, + lect_hours=25, lab_hours=40, credits=4) + c6 = Course(name="Quantum Mechanics", semester=3, dept=d3, + lect_hours=40, lab_hours=30, credits=5) + + g101 = Group(number=101, major='B.E. in Computer Engineering', dept=d1) + g102 = Group(number=102, major='B.S./M.S. in Computer Science', dept=d2) + g103 = Group(number=103, major='B.S. in Applied Mathematics and Statistics', dept=d2) + g104 = Group(number=104, major='B.S./M.S. 
in Pure Mathematics', dept=d2) + g105 = Group(number=105, major='B.E in Electronics', dept=d3) + g106 = Group(number=106, major='B.S./M.S. in Nuclear Engineering', dept=d3) + + Student(id=1, name='John Smith', dob=date(1991, 3, 20), tel='123-456', gpa=3, group=g101, phd=True, + courses=[c1, c2, c4, c6]) + Student(id=2, name='Matthew Reed', dob=date(1990, 11, 26), gpa=3.5, group=g101, phd=True, + courses=[c1, c3, c4, c5]) + Student(id=3, name='Chuan Qin', dob=date(1989, 2, 5), gpa=4, group=g101, + courses=[c3, c5, c6]) + Student(id=4, name='Rebecca Lawson', dob=date(1990, 4, 18), tel='234-567', gpa=3.3, group=g102, + courses=[c1, c4, c5, c6]) + Student(id=5, name='Maria Ionescu', dob=date(1991, 4, 23), gpa=3.9, group=g102, + courses=[c1, c2, c4, c6]) + Student(id=6, name='Oliver Blakey', dob=date(1990, 9, 8), gpa=3.1, group=g102, + courses=[c1, c2, c5]) + Student(id=7, name='Jing Xia', dob=date(1988, 12, 30), gpa=3.2, group=g102, + courses=[c1, c3, c5, c6]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -100,35 +106,35 @@ def test_distinct1(self): self.assertEqual(q.count(), 7) def test_distinct3(self): q = select(d for d in Department if len(s for c in d.courses for s in c.students) > len(s for s in Student)) - self.assertEqual("DISTINCT" in flatten(q._translator.conditions), True) self.assertEqual(q[:], []) + self.assertTrue('DISTINCT' in db.last_sql) def test_distinct4(self): q = select(d for d in Department if len(d.groups.students) > 3) - self.assertEqual("DISTINCT" not in flatten(q._translator.conditions), True) self.assertEqual(q[:], [Department[2]]) + self.assertTrue("DISTINCT" not in db.last_sql) def test_distinct5(self): result = set(select(s for s in Student)) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) def test_distinct6(self): result = set(select(s for s in Student).distinct()) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) def test_not_null1(self): q = select(g for g in Group if '123-45-67' not in g.students.tel and g.dept == Department[1]) - not_null = "IS_NOT_NULL COLUMN student-1 tel" in (" ".join(str(i) for i in flatten(q._translator.conditions))) + not_null = "IS_NOT_NULL COLUMN student tel" in (" ".join(str(i) for i in flatten(q._translator.conditions))) self.assertEqual(not_null, True) self.assertEqual(q[:], [Group[101]]) def test_not_null2(self): q = select(g for g in Group if 'John' not in g.students.name and g.dept == Department[1]) - not_null = "IS_NOT_NULL COLUMN student-1 name" in (" ".join(str(i) for i in flatten(q._translator.conditions))) + not_null = "IS_NOT_NULL COLUMN student name" in (" ".join(str(i) for i in flatten(q._translator.conditions))) self.assertEqual(not_null, False) self.assertEqual(q[:], [Group[101]]) def test_chain_of_attrs_inside_for1(self): result = set(select(s for d in Department if d.number == 2 for s in d.groups.students)) - self.assertEqual(result, set([Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]}) def test_chain_of_attrs_inside_for2(self): pony.options.SIMPLE_ALIASES = False result = set(select(s for d in Department 
if d.number == 2 for s in d.groups.students)) - self.assertEqual(result, set([Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]}) pony.options.SIMPLE_ALIASES = True def test_non_entity_result1(self): result = select((s.name, s.group.number) for s in Student if s.name.startswith("J"))[:] @@ -146,7 +152,7 @@ def test_non_entity_result3a(self): self.assertEqual(sorted(result), [1988, 1989, 1990, 1991]) def test_non_entity_result4(self): result = set(select(s.name for s in Student if s.name.startswith('M'))) - self.assertEqual(result, set([u'Matthew Reed', u'Maria Ionescu'])) + self.assertEqual(result, {u'Matthew Reed', u'Maria Ionescu'}) def test_non_entity_result5(self): result = select((s.group, s.dob) for s in Student if s.group == Group[101])[:] self.assertEqual(sorted(result), [(Group[101], date(1989, 2, 5)), (Group[101], date(1990, 11, 26)), (Group[101], date(1991, 3, 20))]) @@ -156,7 +162,7 @@ def test_non_entity_result6(self): Student[2]), (Course[u'Web Design',1], Student[1]), (Course[u'Web Design',1], Student[2])])) def test_non_entity7(self): result = set(select(s for s in Student if (s.name, s.dob) not in (((s2.name, s2.dob) for s2 in Student if s.group.number == 101)))) - self.assertEqual(result, set([Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]}) @raises_exception(IncomparableTypesError, "Incomparable types 'int' and 'Set of Student' in expression: g.number == g.students") def test_incompartible_types(self): select(g for g in Group if g.number == g.students) @@ -167,7 +173,7 @@ def test_external_param1(self): def test_external_param2(self): x = Student[1] result = set(select(s for s in Student if s.name != x.name)) - self.assertEqual(result, set([Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) @raises_exception(TypeError, "Use select(...) function or Group.select(...) 
method for iteration") def test_exception1(self): for g in Group: @@ -177,18 +183,18 @@ def test_exception2(self): get(s for s in Student) def test_exists(self): result = exists(s for s in Student) - @raises_exception(ExprEvalError, "db.FooBar raises AttributeError: 'Database' object has no attribute 'FooBar'") + @raises_exception(ExprEvalError, "`db.FooBar` raises AttributeError: 'Database' object has no attribute 'FooBar'") def test_entity_not_found(self): select(s for s in db.Student for g in db.FooBar) def test_keyargs1(self): result = set(select(s for s in Student if s.dob < date(year=1990, month=10, day=20))) - self.assertEqual(result, set([Student[3], Student[4], Student[6], Student[7]])) + self.assertEqual(result, {Student[3], Student[4], Student[6], Student[7]}) def test_query_as_string1(self): result = set(select('s for s in Student if 3 <= s.gpa < 4')) - self.assertEqual(result, set([Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]}) def test_query_as_string2(self): result = set(select('s for s in db.Student if 3 <= s.gpa < 4')) - self.assertEqual(result, set([Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]}) def test_str_subclasses(self): result = select(d for d in Department for g in d.groups for c in d.courses if g.number == 106 and c.name.startswith('T'))[:] self.assertEqual(result, [Department[3]]) @@ -199,7 +205,7 @@ class Unicode2(unicode): select(s for s in Student if len(u2) == 1) def test_bool(self): result = set(select(s for s in Student if s.phd == True)) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_bool2(self): result = list(select(s for s in Student if s.phd + 1 == True)) self.assertEqual(result, []) @@ -212,7 +218,7 @@ def test_bool4(self): def test_bool5(self): x = True result = set(select(s for s in Student if s.phd == True and (False or (True and x)))) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_bool6(self): x = False result = list(select(s for s in Student if s.phd == (False or (True and x)) and s.phd is True)) diff --git a/pony/orm/tests/test_declarative_strings.py b/pony/orm/tests/test_declarative_strings.py index bb0f72795..3c6210bd1 100644 --- a/pony/orm/tests/test_declarative_strings.py +++ b/pony/orm/tests/test_declarative_strings.py @@ -4,24 +4,31 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database, only_for -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): - name = Required(unicode, autostrip=False) - foo = Optional(unicode) - bar = Optional(unicode) - -db.generate_mapping(create_tables=True) - -with db_session: - Student(id=1, name="Jon", foo='Abcdef', bar='b%d') - Student(id=2, name=" Bob ", foo='Ab%def', bar='b%d') - Student(id=3, name=" Beth ", foo='Ab_def', bar='b%d') - Student(id=4, name="Jonathan") - Student(id=5, name="Pete") + name = Required(str) + unstripped = Required(str, autostrip=False) + foo = Optional(str) + bar = Optional(str) class TestStringMethods(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Student(id=1, name="Ann", unstripped="Ann", foo='Abcdef', bar='b%d') + Student(id=2, 
name="Bob", unstripped=" Bob ", foo='Ab%def', bar='b%d') + Student(id=3, name="Beth", unstripped=" Beth ", foo='Ab_def', bar='b%d') + Student(id=4, name="Jonathan", unstripped="\nJonathan\n") + Student(id=5, name="Pete", unstripped="\n Pete\n ") + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -30,148 +37,267 @@ def tearDown(self): rollback() db_session.__exit__() - def test_nonzero(self): - result = set(select(s for s in Student if s.foo)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + def test_getitem_1(self): + result = set(select(s for s in Student if s.name[1] == 'o')) + self.assertEqual(result, {Student[2], Student[4]}) + x = 1 + result = set(select(s for s in Student if s.name[x] == 'o')) + self.assertEqual(result, {Student[2], Student[4]}) - def test_add(self): - name = 'Jonny' - result = set(select(s for s in Student if s.name + "ny" == name)) - self.assertEqual(result, set([Student[1]])) + def test_getitem_2(self): + result = set(select(s for s in Student if s.name[-1] == 'n')) + self.assertEqual(result, {Student[1], Student[4]}) + x = -1 + result = set(select(s for s in Student if s.name[x] == 'n')) + self.assertEqual(result, {Student[1], Student[4]}) + + def test_getitem_3(self): + result = set(select(s for s in Student if s.name[-2] == 't')) + self.assertEqual(result, {Student[3], Student[5]}) + x = -2 + result = set(select(s for s in Student if s.name[x] == 't')) + self.assertEqual(result, {Student[3], Student[5]}) + + def test_getitem_4(self): + result = set(select(s for s in Student if s.name[-s.id] == 'n')) + self.assertEqual(result, {Student[1]}) def test_slice_1(self): - result = set(select(s for s in Student if s.name[0:3] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + result = set(select(s for s in Student if s.name[:] == "Ann")) + self.assertEqual(result, {Student[1]}) + result = set(select(s for s in Student if s.name[0:] == "Ann")) + self.assertEqual(result, {Student[1]}) def test_slice_2(self): result = set(select(s for s in Student if s.name[:3] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[4]}) + result = set(select(s for s in Student if s.name[0:3] == "Jon")) + self.assertEqual(result, {Student[4]}) + x = 0 + y = 3 + result = set(select(s for s in Student if s.name[:y] == "Jon")) + self.assertEqual(result, {Student[4]}) + result = set(select(s for s in Student if s.name[x:y] == "Jon")) + self.assertEqual(result, {Student[4]}) + result = set(select(s for s in Student if s.name[x:3] == "Jon")) + self.assertEqual(result, {Student[4]}) def test_slice_3(self): - x = 3 - result = set(select(s for s in Student if s.name[:x] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) - - def test_slice_4(self): - x = 3 - result = set(select(s for s in Student if s.name[0:x] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) - - def test_slice_5(self): - result = set(select(s for s in Student if s.name[0:10] == "Jon")) - self.assertEqual(result, set([Student[1]])) - - def test_slice_6(self): - result = set(select(s for s in Student if s.name[0:] == "Jon")) - self.assertEqual(result, set([Student[1]])) - - def test_slice_7(self): - result = set(select(s for s in Student if s.name[:] == "Jon")) - self.assertEqual(result, set([Student[1]])) + result = set(select(s for s in Student if s.name[0:10] == "Ann")) + self.assertEqual(result, {Student[1]}) + x = 10 + result = 
set(select(s for s in Student if s.name[0:x] == "Ann")) + self.assertEqual(result, {Student[1]}) + result = set(select(s for s in Student if s.name[:x] == "Ann")) + self.assertEqual(result, {Student[1]}) def test_slice_8(self): - result = set(select(s for s in Student if s.name[1:] == "on")) - self.assertEqual(result, set([Student[1]])) - - def test_slice_9(self): + result = set(select(s for s in Student if s.name[1:] == "nn")) + self.assertEqual(result, {Student[1]}) x = 1 - result = set(select(s for s in Student if s.name[x:] == "on")) - self.assertEqual(result, set([Student[1]])) + result = set(select(s for s in Student if s.name[x:] == "nn")) + self.assertEqual(result, {Student[1]}) def test_slice_10(self): - x = 0 - result = set(select(s for s in Student if s.name[x:3] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) - - def test_slice_11(self): + result = set(select(s for s in Student if s.name[1:3] == "et")) + self.assertEqual(result, {Student[3], Student[5]}) x = 1 y = 3 - result = set(select(s for s in Student if s.name[x:y] == "on")) - self.assertEqual(result, set([Student[1], Student[4]])) + result = set(select(s for s in Student if s.name[x:y] == "et")) + self.assertEqual(result, {Student[3], Student[5]}) - def test_slice_12(self): + def test_slice_11(self): x = 10 y = 20 result = set(select(s for s in Student if s.name[x:y] == '')) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) - - def test_getitem_1(self): - result = set(select(s for s in Student if s.name[1] == 'o')) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) - def test_getitem_2(self): + def test_slice_12(self): + result = set(select(s for s in Student if s.name[-2:] == "nn")) + self.assertEqual(result, {Student[1]}) + + def test_slice_13(self): + result = set(select(s for s in Student if s.name[:-1] == "Ann")) + self.assertEqual(result, {Student[1]}) + result = set(select(s for s in Student if s.name[0:-1] == "Ann")) + self.assertEqual(result, {Student[1]}) + x = 0 + y = -1 + result = set(select(s for s in Student if s.name[x:y] == "Ann")) + self.assertEqual(result, {Student[1]}) + + def test_slice_14(self): + result = set(select(s for s in Student if s.name[-4:-2] == "th")) + self.assertEqual(result, {Student[4]}) + x = -4 + y = -2 + result = set(select(s for s in Student if s.name[x:y] == "th")) + self.assertEqual(result, {Student[4]}) + + def test_slice_15(self): + result = set(select(s for s in Student if s.name[4:-2] == "th")) + self.assertEqual(result, {Student[4]}) + x = 4 + y = -2 + result = set(select(s for s in Student if s.name[x:y] == "th")) + self.assertEqual(result, {Student[4]}) + + def test_slice_16(self): + result = list(select(s.name[-2:3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'nn', 'ob', 't', 't']) + x = -2 + y = 3 + result = list(select(s.name[x:y] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'nn', 'ob', 't', 't']) + + def test_slice_17(self): + result = list(select(s.name[s.id:5] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 't']) + x = 5 + result = list(select(s.name[s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 't']) + + def test_slice_18(self): + result = list(select(s.name[-s.id:5] for s in Student).without_distinct()) + self.assertEqual(sorted(result), 
['Pete', 'eth', 'n', 'ob', 't']) + x = 5 + result = list(select(s.name[-s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 'ob', 't']) + + def test_slice_19a(self): + result = list(select(s.name[s.id:] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 'than']) + + def test_slice_19b(self): + result = list(select(s.name[s.id:-1] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', 'n', 'tha']) + x = -1 + result = list(select(s.name[s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', 'n', 'tha']) + + def test_slice_19c(self): + result = list(select(s.name[s.id:-2] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', '', 'th']) + x = -2 + result = list(select(s.name[s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', '', 'th']) + + def test_slice_20a(self): + result = list(select(s.name[-s.id:] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 'ob', 'than']) + + def test_slice_20b(self): + result = list(select(s.name[-s.id:-1] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'Pet', 'et', 'o', 'tha']) + x = -1 + result = list(select(s.name[-s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'Pet', 'et', 'o', 'tha']) + + def test_slice_20c(self): + result = list(select(s.name[-s.id:-2] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', 'Pe', 'e', 'th']) + x = -2 + result = list(select(s.name[-s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', 'Pe', 'e', 'th']) + + def test_slice_21(self): + result = list(select(s.name[1:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'et', 'ete', 'o', 'ona']) x = 1 - result = set(select(s for s in Student if s.name[x] == 'o')) - self.assertEqual(result, set([Student[1], Student[4]])) + result = list(select(s.name[x:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'et', 'ete', 'o', 'ona']) - def test_getitem_3(self): - result = set(select(s for s in Student if s.name[-1] == 'n')) - self.assertEqual(result, set([Student[1], Student[4]])) + def test_slice_22(self): + result = list(select(s.name[-3:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'A', 'Bo', 'et', 'ete']) + x = -3 + result = list(select(s.name[x:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'A', 'Bo', 'et', 'ete']) - def test_getitem_4(self): - x = -1 - result = set(select(s for s in Student if s.name[x] == 'n')) - self.assertEqual(result, set([Student[1], Student[4]])) + def test_slice_23(self): + result = list(select(s.name[s.id:s.id+3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 'tha']) + + def test_slice_24(self): + result = list(select(s.name[-s.id*2:-s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'B', 'B', 'Jona', 'n']) + + def test_slice_25(self): + result = list(select(s.name[s.id:-s.id+3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', 'n', 'tha']) + + def test_slice_26(self): + result = list(select(s.name[-s.id:s.id+3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 
'ob', 'tha']) + + def test_nonzero(self): + result = set(select(s for s in Student if s.foo)) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) + + def test_add(self): + name = 'Bethy' + result = set(select(s for s in Student if s.name + "y" == name)) + self.assertEqual(result, {Student[3]}) def test_contains_1(self): result = set(select(s for s in Student if 'o' in s.name)) - self.assertEqual(result, set([Student[1], Student[2], Student[4]])) + self.assertEqual(result, {Student[2], Student[4]}) def test_contains_2(self): - result = set(select(s for s in Student if 'on' in s.name)) - self.assertEqual(result, set([Student[1], Student[4]])) + result = set(select(s for s in Student if 'an' in s.name)) + self.assertEqual(result, {Student[4]}) def test_contains_3(self): - x = 'on' + x = 'an' result = set(select(s for s in Student if x in s.name)) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[4]}) def test_contains_4(self): - x = 'on' + x = 'an' result = set(select(s for s in Student if x not in s.name)) - self.assertEqual(result, set([Student[2], Student[3], Student[5]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[5]}) def test_contains_5(self): result = set(select(s for s in Student if '%' in s.foo)) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_contains_6(self): x = '%' result = set(select(s for s in Student if x in s.foo)) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_contains_7(self): result = set(select(s for s in Student if '_' in s.foo)) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) def test_contains_8(self): x = '_' result = set(select(s for s in Student if x in s.foo)) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) def test_contains_9(self): result = set(select(s for s in Student if s.foo in 'Abcdef')) - self.assertEqual(result, set([Student[1], Student[4], Student[5]])) + self.assertEqual(result, {Student[1], Student[4], Student[5]}) def test_contains_10(self): result = set(select(s for s in Student if s.bar in s.foo)) - self.assertEqual(result, set([Student[2], Student[4], Student[5]])) + self.assertEqual(result, {Student[2], Student[4], Student[5]}) def test_startswith_1(self): - students = set(select(s for s in Student if s.name.startswith('J'))) - self.assertEqual(students, set([Student[1], Student[4]])) + students = set(select(s for s in Student if s.name.startswith('B'))) + self.assertEqual(students, {Student[2], Student[3]}) def test_startswith_2(self): - students = set(select(s for s in Student if not s.name.startswith('J'))) - self.assertEqual(students, set([Student[2], Student[3], Student[5]])) + students = set(select(s for s in Student if not s.name.startswith('B'))) + self.assertEqual(students, {Student[1], Student[4], Student[5]}) def test_startswith_3(self): - students = set(select(s for s in Student if not not s.name.startswith('J'))) - self.assertEqual(students, set([Student[1], Student[4]])) + students = set(select(s for s in Student if not not s.name.startswith('B'))) + self.assertEqual(students, {Student[2], Student[3]}) def test_startswith_4(self): - students = set(select(s for s in Student if not not not s.name.startswith('J'))) - self.assertEqual(students, set([Student[2], Student[3], Student[5]])) + students = set(select(s for s in Student if not not not s.name.startswith('B'))) + 
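# --- Illustrative sketch, not part of the patch: string operations inside queries,
# --- as covered by the tests above (Student is the entity declared in this module).
with db_session:
    has_o    = select(s for s in Student if 'o' in s.name)[:]              # substring membership
    percent  = select(s for s in Student if '%' in s.foo)[:]               # '%' is matched literally, not as a wildcard
    initials = select(s.name[:1] for s in Student)[:]                      # slicing a string attribute in the query
    shouting = select(s for s in Student if s.name.upper() == 'ANN')[:]    # upper()/lower() are translated as well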
self.assertEqual(students, {Student[1], Student[4], Student[5]}) def test_startswith_5(self): x = "Pe" @@ -180,7 +306,7 @@ def test_startswith_5(self): def test_endswith_1(self): students = set(select(s for s in Student if s.name.endswith('n'))) - self.assertEqual(students, set([Student[1], Student[4]])) + self.assertEqual(students, {Student[1], Student[4]}) def test_endswith_2(self): x = "te" @@ -191,8 +317,13 @@ def test_strip_1(self): students = select(s for s in Student if s.name.strip() == 'Beth')[:] self.assertEqual(students, [Student[3]]) - def test_rstrip(self): - students = select(s for s in Student if s.name.rstrip('n') == 'Jo')[:] + def test_rstrip_1(self): + students = select(s for s in Student if s.name.rstrip('n') == 'A')[:] + self.assertEqual(students, [Student[1]]) + + def test_rstrip_2(self): + x = 'n' + students = select(s for s in Student if s.name.rstrip(x) == 'A')[:] self.assertEqual(students, [Student[1]]) def test_lstrip(self): @@ -200,11 +331,11 @@ def test_lstrip(self): self.assertEqual(students, [Student[5]]) def test_upper(self): - result = select(s for s in Student if s.name.upper() == "JON")[:] + result = select(s for s in Student if s.name.upper() == "ANN")[:] self.assertEqual(result, [Student[1]]) def test_lower(self): - result = select(s for s in Student if s.name.lower() == "jon")[:] + result = select(s for s in Student if s.name.lower() == "ann")[:] self.assertEqual(result, [Student[1]]) if __name__ == "__main__": diff --git a/pony/orm/tests/test_decompiler.py b/pony/orm/tests/test_decompiler.py new file mode 100644 index 000000000..e0c7e6d74 --- /dev/null +++ b/pony/orm/tests/test_decompiler.py @@ -0,0 +1,103 @@ +import unittest + +from pony.thirdparty.compiler.transformer import parse +from pony.orm.decompiling import Decompiler +from pony.orm.asttranslation import ast2src + + +def generate_gens(): + patterns = [ + '(x * y) * [z * j)', + '([x * y) * z) * j', + '(x * [y * z)) * j', + 'x * ([y * z) * j)', + 'x * (y * [z * j))' + ] + + ops = ('and', 'or') + nots = (True, False) + + result = [] + + for pat in patterns: + cur = pat + for op1 in ops: + for op2 in ops: + for op3 in ops: + res = cur.replace('*', op1, 1) + res = res.replace('*', op2, 1) + res = res.replace('*', op3, 1) + result.append(res) + + final = [] + + for res in result: + for par1 in nots: + for par2 in nots: + for a in nots: + for b in nots: + for c in nots: + for d in nots: + cur = res.replace('(', 'not(') if not par1 else res + if not par2: + cur = cur.replace('[', 'not(') + else: + cur = cur.replace('[', '(') + if not a: cur = cur.replace('x', 'not x') + if not b: cur = cur.replace('y', 'not y') + if not c: cur = cur.replace('z', 'not z') + if not d: cur = cur.replace('j', 'not j') + final.append(cur) + + return final + +def create_test(gen): + def wrapped_test(self): + def get_condition_values(cond): + result = [] + vals = (True, False) + for x in vals: + for y in vals: + for z in vals: + for j in vals: + result.append(eval(cond, {'x': x, 'y': y, 'z': z, 'j': j})) + return result + src1 = '(a for a in [] if %s)' % gen + src2 = 'lambda x, y, z, j: (%s)' % gen + src3 = '(m for m in [] if %s for n in [] if %s)' % (gen, gen) + + code1 = compile(src1, '', 'eval').co_consts[0] + ast1 = Decompiler(code1).ast + res1 = ast2src(ast1).replace('.0', '[]') + res1 = res1[res1.find('if')+2:-1] + + code2 = compile(src2, '', 'eval').co_consts[0] + ast2 = Decompiler(code2).ast + res2 = ast2src(ast2).replace('.0', '[]') + res2 = res2[res2.find(':')+1:] + + code3 = compile(src3, '', 'eval').co_consts[0] 
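# --- Illustrative sketch, not part of the patch: the decompile-and-compare round trip
# --- that the generated tests above rely on.
from pony.orm.decompiling import Decompiler
from pony.orm.asttranslation import ast2src

src  = '(a for a in [] if x and not (y or z))'
code = compile(src, '', 'eval').co_consts[0]        # code object of the generator expression
node = Decompiler(code).ast                         # reconstruct an AST from the bytecode
text = ast2src(node).replace('.0', '[]')            # render it back to source; '.0' is the hidden iterator argument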
+ ast3 = Decompiler(code3).ast + res3 = ast2src(ast3).replace('.0', '[]') + res3 = res3[res3.find('if')+2: res3.rfind('for')-1] + + if get_condition_values(gen) != get_condition_values(res1): + self.fail("Incorrect generator decompilation: %s -> %s" % (gen, res1)) + + if get_condition_values(gen) != get_condition_values(res2): + self.fail("Incorrect lambda decompilation: %s -> %s" % (gen, res2)) + + if get_condition_values(gen) != get_condition_values(res3): + self.fail("Incorrect multi-for generator decompilation: %s -> %s" % (gen, res3)) + + return wrapped_test + + +class TestDecompiler(unittest.TestCase): + pass + + +for i, gen in enumerate(generate_gens()): + test_method = create_test(gen) + test_method.__name__ = 'test_decompiler_%d' % i + setattr(TestDecompiler, test_method.__name__, test_method) diff --git a/pony/orm/tests/test_deduplication.py b/pony/orm/tests/test_deduplication.py new file mode 100644 index 000000000..146a33181 --- /dev/null +++ b/pony/orm/tests/test_deduplication.py @@ -0,0 +1,53 @@ +from pony.py23compat import StringIO +from pony.orm import * +from pony.orm.tests import setup_database, teardown_database + +import unittest + +db = Database() + +class A(db.Entity): + id = PrimaryKey(int) + x = Required(bool) + y = Required(float) + + +class TestDeduplication(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + a1 = A(id=1, x=False, y=3.0) + a2 = A(id=2, x=True, y=4.0) + a3 = A(id=3, x=False, y=1.0) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + @db_session + def test_1(self): + a2 = A.get(id=2) + a1 = A.get(id=1) + self.assertIs(a1.id, 1) + + @db_session + def test_2(self): + a3 = A.get(id=3) + a1 = A.get(id=1) + self.assertIs(a1.id, 1) + + @db_session + def test_3(self): + q = A.select().order_by(-1) + stream = StringIO() + q.show(stream=stream) + s = stream.getvalue() + self.assertEqual(s, 'id|x |y \n' + '--+-----+---\n' + '3 |False|1.0\n' + '2 |True |4.0\n' + '1 |False|3.0\n') + + + diff --git a/pony/orm/tests/test_diagram.py b/pony/orm/tests/test_diagram.py index 235de72af..a812b440f 100644 --- a/pony/orm/tests/test_diagram.py +++ b/pony/orm/tests/test_diagram.py @@ -5,12 +5,13 @@ from pony.orm.core import * from pony.orm.core import Entity from pony.orm.tests.testutils import * +from pony.orm.tests import db_params -class TestDiag(unittest.TestCase): +class TestDiag(unittest.TestCase): @raises_exception(ERDiagramError, 'Entity Entity1 already exists') def test_entity_duplicate(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) class Entity1(db.Entity): @@ -19,66 +20,75 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, 'Interrelated entities must belong to same database.' 
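For orientation, a minimal standalone sketch of the round-trip idea the generated decompiler tests above rely on; it mirrors the test code itself, and Decompiler/ast2src are the same helpers imported there (exact decompiler output may vary by Python version).

from pony.orm.decompiling import Decompiler
from pony.orm.asttranslation import ast2src

# Compile a generator expression, decompile its code object back to source,
# and check that the reconstructed condition is logically equivalent.
code = compile('(a for a in [] if x and not y)', '', 'eval').co_consts[0]
src = ast2src(Decompiler(code).ast).replace('.0', '[]')
cond = src[src.find('if') + 2:-1]          # keep only the condition part
for x in (True, False):
    for y in (True, False):
        assert eval('x and not y') == eval(cond)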
' Entities Entity2 and Entity1 belongs to different databases') def test_diagram1(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') - db = Database('sqlite', ':memory:') + db.bind(**db_params) + db = Database() class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) + db.bind(**db_params) db.generate_mapping() @raises_exception(ERDiagramError, 'Entity definition Entity2 was not found') def test_diagram2(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') + db.bind(**db_params) db.generate_mapping() @raises_exception(TypeError, 'Entity1._table_ property must be a string. Got: 123') def test_diagram3(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): _table_ = 123 id = PrimaryKey(int) + db.bind(**db_params) db.generate_mapping() def test_diagram4(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2', table='Table1') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, table='Table1') + db.bind(**db_params) db.generate_mapping(create_tables=True) + db.drop_all_tables(with_all_data=True) def test_diagram5(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1) + db.bind(**db_params) db.generate_mapping(create_tables=True) + db.drop_all_tables(with_all_data=True) @raises_exception(MappingError, "Parameter 'table' for Entity1.attr1 and Entity2.attr2 do not match") def test_diagram6(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2', table='Table1') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, table='Table2') + db.bind(**db_params) db.generate_mapping() - @raises_exception(MappingError, "Table name 'Table1' is already in use") + @raises_exception(MappingError, 'Table name "Table1" is already in use') def test_diagram7(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): _table_ = 'Table1' id = PrimaryKey(int) @@ -86,24 +96,35 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, table='Table1') + db.bind(**db_params) db.generate_mapping() def test_diagram8(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1) + db.bind(**db_params) db.generate_mapping(create_tables=True) - m2m_table = db.schema.tables['Entity1_Entity2'] - col_names = set([ col.name for col in m2m_table.column_list ]) - self.assertEqual(col_names, set(['entity1', 'entity2'])) - self.assertEqual(Entity1.attr1.get_m2m_columns(), ['entity1']) + if pony.__version__ >= '0.9': + m2m_table = db.schema.tables['entity1_attr1'] + col_names = set(m2m_table.columns) + self.assertEqual(col_names, {'entity1_id', 'entity2_id'}) + m2m_columns = [c.name for c in Entity1.attr1.meta.m2m_columns] + self.assertEqual(m2m_columns, ['entity1_id']) + else: + table_name = 'Entity1_Entity2' if db.provider.dialect == 'SQLite' else 'entity1_entity2' + m2m_table = db.schema.tables[table_name] + col_names = {col.name for col in m2m_table.column_list} + 
self.assertEqual(col_names, {'entity1', 'entity2'}) + self.assertEqual(Entity1.attr1.get_m2m_columns(), ['entity1']) + db.drop_all_tables(with_all_data=True) def test_diagram9(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -112,13 +133,21 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1) + db.bind(**db_params) db.generate_mapping(create_tables=True) - m2m_table = db.schema.tables['Entity1_Entity2'] - col_names = set([ col.name for col in m2m_table.column_list ]) - self.assertEqual(col_names, set(['entity1_a', 'entity1_b', 'entity2'])) + if pony.__version__ >= '0.9': + m2m_table = db.schema.tables['entity1_attr1'] + col_names = {col for col in m2m_table.columns} + self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2_id'}) + else: + table_name = 'Entity1_Entity2' if db.provider.dialect == 'SQLite' else 'entity1_entity2' + m2m_table = db.schema.tables[table_name] + col_names = set([col.name for col in m2m_table.column_list]) + self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2'}) + db.drop_all_tables(with_all_data=True) def test_diagram10(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -127,11 +156,13 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, columns=['x', 'y']) + db.bind(**db_params) db.generate_mapping(create_tables=True) + db.drop_all_tables(with_all_data=True) @raises_exception(MappingError, 'Invalid number of columns for Entity2.attr2') def test_diagram11(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -140,6 +171,7 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, columns=['x']) + db.bind(**db_params) db.generate_mapping() @raises_exception(ERDiagramError, 'Base Entity does not belong to any database') @@ -149,9 +181,10 @@ class Test(Entity): @raises_exception(ERDiagramError, 'Entity class name should start with a capital letter. Got: entity1') def test_diagram13(self): - db = Database('sqlite', ':memory:') + db = Database() class entity1(db.Entity): a = Required(int) + db.bind(**db_params) db.generate_mapping() if __name__ == '__main__': diff --git a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index d5dc5660b..f7c29582b 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -6,44 +6,51 @@ from pony.orm.core import * from pony.orm.core import Attribute from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, only_for, setup_database, teardown_database + class TestAttribute(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + + def tearDown(self): + teardown_database(self.db) @raises_exception(TypeError, "Attribute Entity1.id has unknown option 'another_option'") def test_attribute1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, another_option=3) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, 'Cannot link attribute Entity1.b to abstract Entity class. 
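As a reference for the pattern these converted tests switch to, here is a minimal sketch. The entity name Thing is illustrative; db_params, setup_database and teardown_database come from pony.orm.tests and are assumed to target whichever backend the test run is configured for.

from pony.orm import Database, PrimaryKey, Required
from pony.orm.tests import setup_database, teardown_database

db = Database()                # no provider yet, so entities can be declared first

class Thing(db.Entity):        # illustrative entity
    id = PrimaryKey(int)
    name = Required(str)

setup_database(db)             # assumed to bind db to the configured backend and create tables
# ... assertions under db_session go here ...
teardown_database(db)          # assumed to drop the tables created above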
Use specific Entity subclass instead') def test_attribute2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) b = Required(db.Entity) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(TypeError, 'Default value for required attribute Entity1.b cannot be None') def test_attribute3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) b = Required(int, default=None) def test_attribute4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2', reverse='attr2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.attr1.reverse, Entity2.attr2) def test_attribute5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') @@ -54,35 +61,35 @@ class Entity2(db.Entity): @raises_exception(TypeError, "Value of 'reverse' option must be name of reverse attribute). Got: 123") def test_attribute6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2', reverse=123) @raises_exception(TypeError, "Reverse option cannot be set for this type: %r" % str) def test_attribute7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required(str, reverse='attr1') @raises_exception(TypeError, "'Attribute' is abstract type") def test_attribute8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Attribute(str) @raises_exception(ERDiagramError, "Attribute name cannot both start and end with underscore. Got: _attr1_") def test_attribute9(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) _attr1_ = Required(str) @raises_exception(ERDiagramError, "Duplicate use of attribute Entity1.attr1 in entity Entity2") def test_attribute10(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required(str) @@ -92,7 +99,7 @@ class Entity2(db.Entity): @raises_exception(ERDiagramError, "Invalid use of attribute Entity1.a in entity Entity2") def test_attribute11(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(str) class Entity2(db.Entity): @@ -102,33 +109,33 @@ class Entity2(db.Entity): @raises_exception(ERDiagramError, "Cannot create default primary key attribute for Entity1 because name 'id' is already in use." 
" Please create a PrimaryKey attribute for entity Entity1 or rename the 'id' attribute") def test_attribute12(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = Optional(str) @raises_exception(ERDiagramError, "Reverse attribute for Entity1.attr1 not found") def test_attribute13(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Reverse attribute Entity1.attr1 not found") def test_attribute14(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1, reverse='attr1') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Inconsistent reverse attributes Entity3.attr3 and Entity2.attr2") def test_attribute15(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') @@ -138,11 +145,11 @@ class Entity2(db.Entity): class Entity3(db.Entity): id = PrimaryKey(int) attr3 = Required(Entity2, reverse='attr2') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Inconsistent reverse attributes Entity3.attr3 and Entity2.attr2") def test_attribute16(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') @@ -152,21 +159,21 @@ class Entity2(db.Entity): class Entity3(db.Entity): id = PrimaryKey(int) attr3 = Required(Entity2, reverse=Entity2.attr2) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, 'Reverse attribute for Entity2.attr2 not found') def test_attribute18(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required('Entity1') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Ambiguous reverse attribute for Entity1.a. Use the 'reverse' parameter for pointing to right attribute") def test_attribute19(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2') @@ -175,11 +182,11 @@ class Entity2(db.Entity): id = PrimaryKey(int) c = Set(Entity1) d = Set(Entity1) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Ambiguous reverse attribute for Entity1.c. 
Use the 'reverse' parameter for pointing to right attribute") def test_attribute20(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) c = Set('Entity2') @@ -187,10 +194,10 @@ class Entity2(db.Entity): id = PrimaryKey(int) a = Required(Entity1, reverse='c') b = Optional(Entity1, reverse='c') - db.generate_mapping() + db.generate_mapping(check_tables=False) def test_attribute21(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2', reverse='c') @@ -201,7 +208,7 @@ class Entity2(db.Entity): d = Set(Entity1) def test_attribute22(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2', reverse='c') @@ -213,72 +220,81 @@ class Entity2(db.Entity): @raises_exception(ERDiagramError, 'Inconsistent reverse attributes Entity1.a and Entity2.b') def test_attribute23(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required('Entity2', reverse='b') class Entity2(db.Entity): b = Optional('Entity3') class Entity3(db.Entity): c = Required('Entity2') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, 'Inconsistent reverse attributes Entity1.a and Entity2.c') def test_attribute23(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required('Entity2', reverse='c') b = Required('Entity2', reverse='d') class Entity2(db.Entity): c = Optional('Entity1', reverse='b') d = Optional('Entity1', reverse='a') - db.generate_mapping() + db.generate_mapping(check_tables=False) + + def test_attribute24(self): + db = self.db + class Entity1(db.Entity): + a = PrimaryKey(str, auto=True) + db.generate_mapping(create_tables=True) + table_name = 'Entity1' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'entity1' + self.assertTrue('AUTOINCREMENT' not in db.schema.tables[table_name].get_create_command()) @raises_exception(TypeError, "Parameters 'column' and 'columns' cannot be specified simultaneously") def test_columns1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a', columns=['b', 'c']) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, column='a') + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.id.columns, ['a']) def test_columns3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, columns=['a']) self.assertEqual(Entity1.id.column, 'a') @raises_exception(MappingError, "Too many columns were specified for Entity1.id") def test_columns5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, columns=['a', 'b']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) - @raises_exception(TypeError, "Parameter 'columns' must be a list. Got: %r'" % set(['a'])) + @raises_exception(TypeError, "Parameter 'columns' must be a list. 
Got: %r'" % {'a'}) def test_columns6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): - id = PrimaryKey(int, columns=set(['a'])) - db.generate_mapping(create_tables=True) + id = PrimaryKey(int, columns={'a'}) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'column' must be a string. Got: 4") def test_columns7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, column=4) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -286,12 +302,13 @@ class Entity1(db.Entity): PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=['x', 'y']) + db.generate_mapping(check_tables=False) self.assertEqual(Entity2.attr2.column, None) self.assertEqual(Entity2.attr2.columns, ['x', 'y']) @raises_exception(MappingError, 'Invalid number of columns specified for Entity2.attr2') def test_columns9(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -299,11 +316,11 @@ class Entity1(db.Entity): PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=['x', 'y', 'z']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(MappingError, 'Invalid number of columns specified for Entity2.attr2') def test_columns10(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -311,11 +328,11 @@ class Entity1(db.Entity): PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, column='x') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Items of parameter 'columns' must be strings. Got: [1, 2]") def test_columns11(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -325,97 +342,97 @@ class Entity2(db.Entity): attr2 = Required(Entity1, columns=[1, 2]) def test_columns12(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column='column2', reverse_columns=['column2']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameters 'reverse_column' and 'reverse_columns' cannot be specified simultaneously") def test_columns13(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column='column2', reverse_columns=['column3']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'reverse_column' must be a string. Got: 5") def test_columns14(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column=5) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'reverse_columns' must be a list. 
Got: 'column3'") def test_columns15(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns='column3') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'reverse_columns' must be a list of strings. Got: [5]") def test_columns16(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns=[5]) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns17(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns=['column2']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns18(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table='T1') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'table' must be a string. Got: 5") def test_columns19(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table=5) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Each part of table name must be a string. Got: 1") def test_columns20(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table=[1, 'T1']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns_21(self): - db = Database('sqlite', ':memory:') + db = self.db class Stat(db.Entity): webinarshow = Optional('WebinarShow') class WebinarShow(db.Entity): stats = Required('Stat') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Stat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') - + def test_columns_22(self): - db = Database('sqlite', ':memory:') + db = self.db class ZStat(db.Entity): webinarshow = Optional('WebinarShow') class WebinarShow(db.Entity): stats = Required('ZStat') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(ZStat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') def test_nullable1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(unicode, unique=True) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.a.nullable, True) def test_nullable2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(unicode, unique=True) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: Entity1() commit() @@ -423,23 +440,23 @@ class Entity1(db.Entity): commit() def test_lambda_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(lambda: db.Entity2) class Entity2(db.Entity): b = Set(lambda: db.Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.a.py_type, Entity2) self.assertEqual(Entity2.b.py_type, Entity1) @raises_exception(TypeError, "Invalid type of 
attribute Entity1.a: expected entity class, got 'Entity2'") def test_lambda_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(lambda: 'Entity2') class Entity2(db.Entity): b = Set(lambda: db.Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, 'Interrelated entities must belong to same database. ' 'Entities Entity1 and Entity2 belongs to different databases') @@ -450,47 +467,47 @@ class Entity1(db1.Entity): db2 = Database('sqlite', ':memory:') class Entity2(db2.Entity): b = Set(lambda: db1.Entity1) - db1.generate_mapping(create_tables=True) + db1.generate_mapping(check_tables=False) @raises_exception(ValueError, 'Check for attribute Entity1.a failed. Value: 1') def test_py_check_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int, py_check=lambda val: val > 5 and val < 10) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=1) def test_py_check_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int, py_check=lambda val: val > 5 and val < 10) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=7) def test_py_check_3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=None) @raises_exception(ValueError, 'Check for attribute Entity1.a failed. Value: datetime.date(1999, 1, 1)') def test_py_check_4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=date(1999, 1, 1)) def test_py_check_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=date(2010, 1, 1)) @@ -498,10 +515,10 @@ class Entity1(db.Entity): def test_py_check_6(self): def positive_number(val): if val <= 0: raise ValueError('Should be positive number') - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(int, py_check=positive_number) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=-1) @@ -509,118 +526,127 @@ def test_py_check_7(self): def positive_number(val): if val <= 0: raise ValueError('Should be positive number') return True - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(int, py_check=positive_number) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=1) @raises_exception(NotImplementedError, "'py_check' parameter is not supported for collection attributes") def test_py_check_8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required('Entity2') class Entity2(db.Entity): a = Set('Entity1', py_check=lambda val: True) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) - @raises_exception(ValueError, "Check for attribute Entity1.a failed. 
Value: " + ( - "u'12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345..." if PY2 - else "'123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456..." - )) def test_py_check_truncate(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(str, py_check=lambda val: False) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: - obj = Entity1(a='1234567890' * 1000) + try: + obj = Entity1(a='1234567890' * 1000) + except ValueError as e: + error_message = "Check for attribute Entity1.a failed. Value: " + ( + "u'12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345..." if PY2 + else "'123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456...") + self.assertEqual(str(e), error_message) + else: + self.assert_(False) @raises_exception(ValueError, 'Value for attribute Entity1.a is too long. Max length is 10, value length is 10000') def test_str_max_len(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(str, 10) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a='1234567890' * 1000) + @only_for('sqlite') def test_foreign_key_sql_type_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SOME_TYPE') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'SOME_TYPE') + @only_for('sqlite') def test_foreign_key_sql_type_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SOME_TYPE') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo, sql_type='ANOTHER_TYPE') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'ANOTHER_TYPE') + @only_for('sqlite') def test_foreign_key_sql_type_3(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SERIAL') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo, sql_type='ANOTHER_TYPE') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'ANOTHER_TYPE') def test_foreign_key_sql_type_4(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SERIAL') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type - self.assertEqual(sql_type, 'INTEGER') + required_type = 'INT8' if db.provider_name == 'cockroach' else 'INTEGER' + self.assertEqual(required_type, sql_type) def test_foreign_key_sql_type_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='serial') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) - db.generate_mapping(create_tables=True) + 
db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type - self.assertEqual(sql_type, 'integer') + required_type = 'int8' if db.provider_name == 'cockroach' else 'integer' + self.assertEqual(required_type, sql_type) def test_self_referenced_m2m_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Node(db.Entity): id = PrimaryKey(int) prev_nodes = Set("Node") next_nodes = Set("Node") - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_implicit_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): name = Required(str) bar = Required("Bar") @@ -628,7 +654,7 @@ class Bar(db.Entity): id = PrimaryKey(int) name = Optional(str) foos = Set("Foo") - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertTrue(Foo.id.is_implicit) self.assertFalse(Foo.name.is_implicit) @@ -639,12 +665,12 @@ class Bar(db.Entity): self.assertFalse(Bar.foos.is_implicit) def test_implicit_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): x = Required(str) class Bar(Foo): y = Required(str) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertTrue(Foo.id.is_implicit) self.assertTrue(Foo.classtype.is_implicit) @@ -657,19 +683,19 @@ class Bar(Foo): @raises_exception(TypeError, 'Attribute Foo.x has invalid type NoneType') def test_none_type(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): x = Required(type(None)) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "'sql_default' option value cannot be empty string, " "because it should be valid SQL literal or expression. 
" "Try to use \"''\", or just specify default='' instead.") def test_none_type(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): x = Required(str, sql_default='') - + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_diagram_keys.py b/pony/orm/tests/test_diagram_keys.py index 510a497bb..653656053 100644 --- a/pony/orm/tests/test_diagram_keys.py +++ b/pony/orm/tests/test_diagram_keys.py @@ -4,11 +4,15 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, setup_database, teardown_database + class TestKeys(unittest.TestCase): + def tearDown(self): + teardown_database(self.db) def test_keys1(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Required(str) @@ -20,7 +24,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._composite_keys_, []) def test_keys2(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -34,14 +38,14 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, 'Only one primary key can be defined in each entity class') def test_keys3(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = PrimaryKey(int) @raises_exception(ERDiagramError, 'Only one primary key can be defined in each entity class') def test_keys4(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Required(int) @@ -49,7 +53,7 @@ class Entity1(db.Entity): PrimaryKey(b, c) def test_unique1(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Required(int, unique=True) @@ -58,7 +62,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._composite_keys_, []) def test_unique2(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Optional(int, unique=True) @@ -67,7 +71,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._composite_keys_, []) def test_unique2_1(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Optional(int) @@ -79,28 +83,28 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'composite_key() must receive at least two attributes as arguments') def test_unique3(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) composite_key() @raises_exception(TypeError, 'composite_key() arguments must be attributes. Got: 123') def test_unique4(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) composite_key(123, 456) @raises_exception(TypeError, "composite_key() arguments must be attributes. 
Got: %r" % int) def test_unique5(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) composite_key(int, a) @raises_exception(TypeError, 'Set attribute Entity1.b cannot be part of unique index') def test_unique6(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Set('Entity2') @@ -108,14 +112,14 @@ class Entity1(db.Entity): @raises_exception(TypeError, "'unique' option cannot be set for attribute Entity1.b because it is collection") def test_unique7(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Set('Entity2', unique=True) @raises_exception(TypeError, 'Optional attribute Entity1.b cannot be part of primary key') def test_unique8(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Optional(int) @@ -123,13 +127,13 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'PrimaryKey attribute Entity1.a cannot be of type float') def test_float_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(float) @raises_exception(TypeError, 'Attribute Entity1.b of type float cannot be part of primary key') def test_float_composite_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(float) @@ -137,7 +141,7 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'Attribute Entity1.b of type float cannot be part of unique index') def test_float_composite_key(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(float) @@ -145,33 +149,33 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'Unique attribute Entity1.a cannot be of type float') def test_float_unique(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(float, unique=True) @raises_exception(TypeError, 'PrimaryKey attribute Entity1.a cannot be volatile') def test_volatile_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int, volatile=True) @raises_exception(TypeError, 'Set attribute Entity1.b cannot be volatile') def test_volatile_set(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Set('Entity2', volatile=True) @raises_exception(TypeError, 'Volatile attribute Entity1.b cannot be part of primary key') def test_volatile_composite_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(int, volatile=True) PrimaryKey(a, b) def test_composite_key_update(self): - db = Database('sqlite', ':memory:') + db = self.db = Database() class Entity1(db.Entity): s = Set('Entity3') class Entity2(db.Entity): @@ -180,7 +184,8 @@ class Entity3(db.Entity): a = Required(Entity1) b = Required(Entity2) composite_key(a, b) - db.generate_mapping(create_tables=True) + setup_database(db) + with db_session: x = Entity1(id=1) y = Entity2(id=1) diff --git a/pony/orm/tests/test_distinct.py b/pony/orm/tests/test_distinct.py new file mode 100644 index 000000000..a53f67b4c --- /dev/null 
+++ b/pony/orm/tests/test_distinct.py @@ -0,0 +1,92 @@ +from __future__ import absolute_import, print_function, division + +import unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() + +class Department(db.Entity): + number = PrimaryKey(int) + groups = Set('Group') + +class Group(db.Entity): + id = PrimaryKey(int) + dept = Required('Department') + students = Set('Student') + +class Student(db.Entity): + name = Required(unicode) + age = Required(int) + group = Required('Group') + scholarship = Required(int, default=0) + courses = Set('Course') + +class Course(db.Entity): + name = Required(unicode) + semester = Required(int) + PrimaryKey(name, semester) + students = Set('Student') + + +class TestDistinct(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=1) + d2 = Department(number=2) + g1 = Group(id=1, dept=d1) + g2 = Group(id=2, dept=d2) + s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) + s2 = Student(id=2, name='S2', age=21, group=g1, scholarship=100) + s3 = Student(id=3, name='S3', age=23, group=g1, scholarship=200) + s4 = Student(id=4, name='S4', age=21, group=g1, scholarship=100) + s5 = Student(id=5, name='S5', age=23, group=g2, scholarship=0) + s6 = Student(id=6, name='S6', age=23, group=g2, scholarship=200) + c1 = Course(name='C1', semester=1, students=[s1, s2, s3]) + c2 = Course(name='C2', semester=1, students=[s2, s3, s5, s6]) + c3 = Course(name='C3', semester=2, students=[s4, s5, s6]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + db_session.__enter__() + + def tearDown(self): + db_session.__exit__() + + def test_group_by(self): + result = set(select((s.age, sum(s.scholarship)) for s in Student if s.scholarship > 0)) + self.assertEqual(result, {(21, 200), (23, 400)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_group_by_having(self): + result = set(select((s.age, sum(s.scholarship)) for s in Student if sum(s.scholarship) < 300)) + self.assertEqual(result, {(20, 0), (21, 200)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_aggregation_no_group_by_1(self): + result = set(select(sum(s.scholarship) for s in Student if s.age < 23)) + self.assertEqual(result, {200}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_aggregation_no_group_by_2(self): + result = set(select((sum(s.scholarship), min(s.scholarship)) for s in Student if s.age < 23)) + self.assertEqual(result, {(200, 0)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_aggregation_no_group_by_3(self): + result = set(select((sum(s.scholarship), min(s.scholarship)) + for s in Student for g in Group + if s.group == g and g.dept.number == 1)) + self.assertEqual(result, {(400, 0)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + +if __name__ == "__main__": + unittest.main() diff --git a/pony/orm/tests/test_entity_init.py b/pony/orm/tests/test_entity_init.py index 7f1a2b896..37d62ec1b 100644 --- a/pony/orm/tests/test_entity_init.py +++ b/pony/orm/tests/test_entity_init.py @@ -6,23 +6,32 @@ from pony.orm.tests.testutils import raises_exception from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -class TestCustomInit(unittest.TestCase): - def test1(self): - db = Database('sqlite', ':memory:') +db = Database() - class User(db.Entity): - name = Required(str) - password = Required(str) - 
created_at = Required(datetime) - def __init__(self, name, password): - password = md5(password.encode('utf8')).hexdigest() - super(User, self).__init__(name=name, password=password, created_at=datetime.now()) - self.uppercase_name = name.upper() +class User(db.Entity): + name = Required(str) + password = Required(str) + created_at = Required(datetime) - db.generate_mapping(create_tables=True) + def __init__(self, name, password): + password = md5(password.encode('utf8')).hexdigest() + super(User, self).__init__(name=name, password=password, created_at=datetime.now()) + self.uppercase_name = name.upper() + +class TestCustomInit(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(self): + teardown_database(db) + + def test1(self): with db_session: u1 = User('John', '123') u2 = User('Mike', '456') diff --git a/pony/orm/tests/test_entity_proxy.py b/pony/orm/tests/test_entity_proxy.py new file mode 100644 index 000000000..4cc319c9d --- /dev/null +++ b/pony/orm/tests/test_entity_proxy.py @@ -0,0 +1,154 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() + +class Country(db.Entity): + id = PrimaryKey(int) + name = Required(str) + persons = Set("Person") + +class Person(db.Entity): + id = PrimaryKey(int) + name = Required(str) + country = Required(Country) + +class TestProxy(unittest.TestCase): + def setUp(self): + setup_database(db) + with db_session: + c1 = Country(id=1, name='Russia') + c2 = Country(id=2, name='Japan') + Person(id=1, name='Alexander Nevskiy', country=c1) + Person(id=2, name='Raikou Minamoto', country=c2) + Person(id=3, name='Ibaraki Douji', country=c2) + + def tearDown(self): + teardown_database(db) + + def test_1(self): + with db_session: + p = make_proxy(Person[2]) + + with db_session: + x1 = db.local_stats[None].db_count # number of queries + # it is possible to access p attributes in a new db_session + name = p.name + country = p.country + x2 = db.local_stats[None].db_count + + # p.name and p.country are loaded with a single query + self.assertEqual(x1, x2-1) + + def test_2(self): + with db_session: + p = make_proxy(Person[2]) + name = p.name + country = p.country + + with db_session: + x1 = db.local_stats[None].db_count + name = p.name + country = p.country + x2 = db.local_stats[None].db_count + + # attribute values from the first db_session should be ignored and loaded again + self.assertEqual(x1, x2-1) + + def test_3(self): + with db_session: + p = Person[2] + proxy = make_proxy(p) + + with db_session: + p2 = Person[2] + name1 = 'Tamamo no Mae' + # It is possible to assign new attribute values to a proxy object + p2.name = name1 + name2 = proxy.name + + self.assertEqual(name1, name2) + + + def test_4(self): + with db_session: + p = Person[2] + proxy = make_proxy(p) + + with db_session: + p2 = Person[2] + name1 = 'Tamamo no Mae' + p2.name = name1 + + with db_session: + # new attribute value was successfully stored in the database + name2 = proxy.name + + self.assertEqual(name1, name2) + + def test_5(self): + with db_session: + p = Person[2] + r = repr(p) + self.assertEqual(r, 'Person[2]') + + proxy = make_proxy(p) + r = repr(proxy) + # proxy object has specific repr + self.assertEqual(r, '') + + r = repr(proxy) + # repr of proxy object can be used outside of db_session + self.assertEqual(r, '') + + del p + r = repr(proxy) + # repr works even if the original object was deleted + 
self.assertEqual(r, '') + + + def test_6(self): + with db_session: + p = Person[2] + proxy = make_proxy(p) + proxy.name = 'Okita Souji' + # after assignment, the attribute value is the same for the proxy and for the original object + self.assertEqual(proxy.name, 'Okita Souji') + self.assertEqual(p.name, 'Okita Souji') + + + def test_7(self): + with db_session: + p = Person[2] + proxy = make_proxy(p) + proxy.name = 'Okita Souji' + # after assignment, the attribute value is the same for the proxy and for the original object + self.assertEqual(proxy.name, 'Okita Souji') + self.assertEqual(p.name, 'Okita Souji') + + + def test_8(self): + with db_session: + c1 = Country[1] + c1_proxy = make_proxy(c1) + p2 = Person[2] + self.assertNotEqual(p2.country, c1) + self.assertNotEqual(p2.country, c1_proxy) + # proxy can be used in attribute assignment + p2.country = c1_proxy + self.assertEqual(p2.country, c1_proxy) + self.assertIs(p2.country, c1) + + + def test_9(self): + with db_session: + c2 = Country[2] + c2_proxy = make_proxy(c2) + persons = select(p for p in Person if p.country == c2_proxy) + self.assertEqual({p.id for p in persons}, {2, 3}) + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_exists.py b/pony/orm/tests/test_exists.py new file mode 100644 index 000000000..6c0f3a17b --- /dev/null +++ b/pony/orm/tests/test_exists.py @@ -0,0 +1,79 @@ +import unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() + +class Group(db.Entity): + students = Set('Student') + +class Student(db.Entity): + first_name = Required(str) + last_name = Required(str) + login = Optional(str, nullable=True) + graduated = Optional(bool, default=False) + group = Required(Group) + passport = Optional('Passport', column='passport') + +class Passport(db.Entity): + student = Optional(Student) + + +class TestExists(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1) + g2 = Group(id=2) + + p = Passport(id=1) + + Student(id=1, first_name='Mashu', last_name='Kyrielight', login='Shielder', group=g1) + Student(id=2, first_name='Okita', last_name='Souji', login='Sakura', group=g1) + Student(id=3, first_name='Francis', last_name='Drake', group=g2, graduated=True) + Student(id=4, first_name='Oda', last_name='Nobunaga', group=g2, graduated=True) + Student(id=5, first_name='William', last_name='Shakespeare', group=g2, graduated=True, passport=p) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_1(self): + q = select(g for g in Group if exists(s.login for s in g.students))[:] + self.assertEqual(q[0], Group[1]) + + def test_2(self): + q = select(g for g in Group if exists(s.graduated for s in g.students))[:] + self.assertEqual(q[0], Group[2]) + + def test_3(self): + q = select(s for s in Student if + exists(len(s2.first_name) == len(s.first_name) and s != s2 for s2 in Student))[:] + self.assertEqual(set(q), {Student[1], Student[2], Student[3], Student[5]}) + + def test_4(self): + q = select(g for g in Group if not exists(not s.graduated for s in g.students))[:] + self.assertEqual(q[0], Group[2]) + + def test_5(self): + q = select(g for g in Group if exists(s for s in g.students))[:] + self.assertEqual(set(q), {Group[1], Group[2]}) + + def test_6(self): + q = select(g for g in Group if 
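A condensed sketch of the make_proxy behaviour the proxy tests above verify, reusing the Person entity and data defined there.

from pony.orm import db_session, make_proxy

with db_session:
    proxy = make_proxy(Person[2])     # create a proxy inside one db_session

with db_session:
    name = proxy.name                 # attributes are (re)loaded in the new session
    proxy.name = 'Tamamo no Mae'      # assignments through the proxy are persisted on commit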
exists(s.login for s in g.students if s.first_name != 'Okita') and g.id != 10)[:] + self.assertEqual(q[0], Group[1]) + + def test_7(self): + q = select(g for g in Group if exists(s.passport for s in g.students))[:] + self.assertEqual(q[0], Group[2]) \ No newline at end of file diff --git a/pony/orm/tests/test_f_strings.py b/pony/orm/tests/test_f_strings.py new file mode 100644 index 000000000..fa6414e49 --- /dev/null +++ b/pony/orm/tests/test_f_strings.py @@ -0,0 +1,4 @@ +from sys import version_info + +if version_info[:2] >= (3, 6): + from pony.orm.tests.py36_test_f_strings import * \ No newline at end of file diff --git a/pony/orm/tests/test_filter.py b/pony/orm/tests/test_filter.py index f6b6724a5..be412b7ab 100644 --- a/pony/orm/tests/test_filter.py +++ b/pony/orm/tests/test_filter.py @@ -3,8 +3,17 @@ import unittest from pony.orm.tests.model1 import * +from pony.orm.tests import setup_database, teardown_database + class TestFilter(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + populate_db() + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -14,17 +23,17 @@ def tearDown(self): def test_filter_1(self): q = select(s for s in Student) result = set(q.filter(scholarship=0)) - self.assertEqual(result, set([Student[101], Student[103]])) + self.assertEqual(result, {Student[101], Student[103]}) def test_filter_2(self): q = select(s for s in Student) q2 = q.filter(scholarship=500) result = set(q2.filter(group=Group['3132'])) - self.assertEqual(result, set([Student[104]])) + self.assertEqual(result, {Student[104]}) def test_filter_3(self): q = select(s for s in Student) q2 = q.filter(lambda s: s.scholarship > 500) result = set(q2.filter(lambda s: count(s.marks) > 0)) - self.assertEqual(result, set([Student[102]])) + self.assertEqual(result, {Student[102]}) def test_filter_4(self): q = select(s for s in Student) q2 = q.filter(lambda s: s.scholarship != 500) @@ -47,11 +56,11 @@ def test_filter_7(self): q = select(s for s in Student) q2 = q.filter(scholarship=0) result = set(q2.filter(lambda s: count(s.marks) > 1)) - self.assertEqual(result, set([Student[103], Student[101]])) + self.assertEqual(result, {Student[103], Student[101]}) def test_filter_8(self): q = select(s for s in Student) q2 = q.filter(lambda s: s.scholarship != 500) q3 = q2.order_by(lambda s: s.name) q4 = q3.order_by(None) result = set(q4.filter(lambda s: count(s.marks) > 1)) - self.assertEqual(result, set([Student[103], Student[101]])) + self.assertEqual(result, {Student[103], Student[101]}) diff --git a/pony/orm/tests/test_flush.py b/pony/orm/tests/test_flush.py index aab318169..bd5c6c1eb 100644 --- a/pony/orm/tests/test_flush.py +++ b/pony/orm/tests/test_flush.py @@ -4,21 +4,25 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -class TestFlush(unittest.TestCase): - def setUp(self): - self.db = Database('sqlite', ':memory:') - class Person(self.db.Entity): - name = Required(unicode) +class Person(db.Entity): + name = Required(unicode) - self.db.generate_mapping(create_tables=True) - def tearDown(self): - self.db = None +class TestFlush(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(self): + teardown_database(db) def test1(self): - Person = self.db.Person with db_session: a = Person(name='A') b = Person(name='B') @@ -29,10 +33,12 @@ def test1(self): b.flush() 
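A short sketch of the query filtering style used in the updated test_filter tests; Student with its marks and scholarship attributes comes from pony.orm.tests.model1, and the database is assumed to be set up and populated as in the test's setUpClass.

from pony.orm import db_session, select, count
from pony.orm.tests.model1 import Student

with db_session:
    q = select(s for s in Student)
    q = q.filter(scholarship=0)                   # keyword filter: equality on an attribute
    q = q.filter(lambda s: count(s.marks) > 1)    # lambda filter, may use aggregates
    students = set(q)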
self.assertEqual(a.id, None) - self.assertEqual(b.id, 1) + self.assertIsNotNone(b.id) + b_id = b.id self.assertEqual(c.id, None) flush() - self.assertEqual(a.id, 2) - self.assertEqual(b.id, 1) - self.assertEqual(c.id, 3) + self.assertIsNotNone(a.id) + self.assertEqual(b.id, b_id) + self.assertIsNotNone(c.id) + self.assertEqual(len({a.id, b.id, c.id}), 3) diff --git a/pony/orm/tests/test_frames.py b/pony/orm/tests/test_frames.py index 67a56dbef..94a3bc1f1 100644 --- a/pony/orm/tests/test_frames.py +++ b/pony/orm/tests/test_frames.py @@ -3,45 +3,53 @@ import unittest from pony.orm.core import * +import pony.orm.decompiling +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Person(db.Entity): name = Required(unicode) age = Required(int) -db.generate_mapping(create_tables=True) - -with db_session: - p1 = Person(name='John', age=22) - p2 = Person(name='Mary', age=18) - p3 = Person(name='Mike', age=25) class TestFrames(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p1 = Person(id=1, name='John', age=22) + p2 = Person(id=2, name='Mary', age=18) + p3 = Person(id=3, name='Mike', age=25) + + @classmethod + def tearDownClass(cls): + db.drop_all_tables(with_all_data=True) @db_session def test_select(self): x = 20 result = select(p.id for p in Person if p.age > x)[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_select_str(self): x = 20 result = select('p.id for p in Person if p.age > x')[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_left_join(self): x = 20 result = left_join(p.id for p in Person if p.age > x)[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_left_join_str(self): x = 20 result = left_join('p.id for p in Person if p.age > x')[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_get(self): @@ -107,13 +115,13 @@ def test_entity_exists_str(self): def test_entity_select(self): x = 20 result = Person.select(lambda p: p.age > x)[:] - self.assertEqual(set(result), set([Person[1], Person[3]])) + self.assertEqual(set(result), {Person[1], Person[3]}) @db_session def test_entity_select_str(self): x = 20 result = Person.select('lambda p: p.age > x')[:] - self.assertEqual(set(result), set([Person[1], Person[3]])) + self.assertEqual(set(result), {Person[1], Person[3]}) @db_session def test_order_by(self): @@ -167,5 +175,19 @@ def test_db_exists(self): result = db.exists('name from Person where age = $x') self.assertEqual(result, True) + @raises_exception(pony.orm.decompiling.InvalidQuery, + 'Use generator expression (... for ... in ...) ' + 'instead of list comprehension [... for ... in ...] 
inside query') + @db_session + def test_inner_list_comprehension(self): + result = select(p.id for p in Person if p.age not in [ + p2.age for p2 in Person if p2.name.startswith('M')])[:] + + @db_session + def test_outer_list_comprehension(self): + names = ['John', 'Mary', 'Mike'] + persons = [ Person.select(lambda p: p.name == name).first() for name in names ] + self.assertEqual(set(p.name for p in persons), {'John', 'Mary', 'Mike'}) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_generator_db_session.py b/pony/orm/tests/test_generator_db_session.py index a76df966c..d232def19 100644 --- a/pony/orm/tests/test_generator_db_session.py +++ b/pony/orm/tests/test_generator_db_session.py @@ -5,14 +5,16 @@ from pony.orm.core import * from pony.orm.core import local from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + class TestGeneratorDbSession(unittest.TestCase): def setUp(self): - db = Database('sqlite', ':memory:') + db = Database() class Account(db.Entity): id = PrimaryKey(int) amount = Required(int) - db.generate_mapping(create_tables=True) + setup_database(db) self.db = db self.Account = Account @@ -23,6 +25,7 @@ class Account(db.Entity): a3 = Account(id=3, amount=3000) def tearDown(self): + teardown_database(self.db) assert local.db_session is None self.db = self.Account = None @@ -119,7 +122,7 @@ def f(id1): a2 = self.Account[2] self.assertEqual(a2.amount, 2100) - @raises_exception(TransactionError, 'You need to manually commit() changes before yielding from the generator') + @raises_exception(TransactionError, 'You need to manually commit() changes before suspending the generator') def test8(self): @db_session def f(id1): @@ -141,7 +144,6 @@ def f(id1): for amount in f(1): pass - @raises_exception(TransactionError, 'You need to manually commit() changes before exiting from the generator') def test10(self): @db_session def f(id1): @@ -149,19 +151,14 @@ def f(id1): yield a1.amount a1.amount += 100 + with db_session: + a = self.Account[1].amount for amount in f(1): pass + with db_session: + b = self.Account[1].amount - def test11(self): - @db_session - def f(id1): - a1 = self.Account[id1] - yield a1.amount - a1.amount += 100 - commit() - - for amount in f(1): - pass + self.assertEqual(b, a + 100) def test12(self): @db_session diff --git a/pony/orm/tests/test_get_pk.py b/pony/orm/tests/test_get_pk.py new file mode 100644 index 000000000..000091af5 --- /dev/null +++ b/pony/orm/tests/test_get_pk.py @@ -0,0 +1,67 @@ +import unittest +from datetime import date + +from pony.orm import * +from pony.orm.tests import setup_database, teardown_database + +day = date.today() + +db = Database() + +class A(db.Entity): + b = Required("B") + c = Required("C") + PrimaryKey(b, c) + +class B(db.Entity): + id = PrimaryKey(date) + a_set = Set(A) + +class C(db.Entity): + x = Required("X") + y = Required("Y") + a_set = Set(A) + PrimaryKey(x, y) + +class X(db.Entity): + id = PrimaryKey(int) + c_set = Set(C) + +class Y(db.Entity): + id = PrimaryKey(int) + c_set = Set(C) + + +class Test(unittest.TestCase): + def setUp(self): + setup_database(db) + with db_session: + x1 = X(id=123) + y1 = Y(id=456) + b1 = B(id=day) + c1 = C(x=x1, y=y1) + A(b=b1, c=c1) + + def tearDown(self): + teardown_database(db) + + @db_session + def test_1(self): + a1 = A.select().first() + a2 = A[a1.get_pk()] + self.assertEqual(a1, a2) + + @db_session + def test2(self): + a = A.select().first() + b = B.select().first() + c = C.select().first() + pk = (b.get_pk(), 
c._get_raw_pkval_()) + self.assertTrue(a is A[pk]) + + @db_session + def test3(self): + a = A.select().first() + c = C.select().first() + pk = (day, c.get_pk()) + self.assertTrue(a is A[pk]) diff --git a/pony/orm/tests/test_getattr.py b/pony/orm/tests/test_getattr.py new file mode 100644 index 000000000..226271b52 --- /dev/null +++ b/pony/orm/tests/test_getattr.py @@ -0,0 +1,110 @@ +from pony.py23compat import basestring + +import unittest + +from pony.orm import * +from pony import orm +from pony.utils import cached_property +from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import db_params, setup_database, teardown_database + +class Test(unittest.TestCase): + + @cached_property + def db(self): + return orm.Database() + + def setUp(self): + db = self.db + + class Genre(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') + + class Hobby(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') + + class Artist(db.Entity): + name = orm.Required(str) + age = orm.Optional(int) + hobbies = orm.Set(Hobby) + genres = orm.Set(Genre) + + setup_database(db) + + with orm.db_session: + pop = Genre(name='Pop') + Artist(name='Sia', age=40, genres=[pop]) + Hobby(name='Swimming') + + pony.options.INNER_JOIN_SYNTAX = True + + def tearDown(self): + teardown_database(self.db) + + @db_session + def test_no_caching(self): + for attr_name, attr_type in zip(['name', 'age'], [basestring, int]): + val = select(getattr(x, attr_name) for x in self.db.Artist).first() + self.assertIsInstance(val, attr_type) + + @db_session + def test_simple(self): + val = select(getattr(x, 'age') for x in self.db.Artist).first() + self.assertIsInstance(val, int) + + @db_session + def test_expr(self): + val = select(getattr(x, ''.join(['ag', 'e'])) for x in self.db.Artist).first() + self.assertIsInstance(val, int) + + @db_session + def test_external(self): + class data: + id = 1 + val = select(x.id for x in self.db.Artist if x.id >= getattr(data, 'id')).first() + self.assertIsNotNone(val) + + @db_session + def test_related(self): + val = select(getattr(x.genres, 'name') for x in self.db.Artist).first() + self.assertIsNotNone(val) + + @db_session + def test_not_instance_iter(self): + val = select(getattr(x.name, 'startswith')('S') for x in self.db.Artist).first() + self.assertTrue(val) + + @raises_exception(TranslationError, 'Expression `getattr(x, x.name)` cannot be translated into SQL ' + 'because x.name will be different for each row') + @db_session + def test_not_external(self): + select(getattr(x, x.name) for x in self.db.Artist) + + @raises_exception(TypeError, 'In `getattr(x, 1)` second argument should be a string. Got: 1') + @db_session + def test_not_string(self): + select(getattr(x, 1) for x in self.db.Artist) + + @raises_exception(TypeError, 'In `getattr(x, name)` second argument should be a string. 
Got: 1') + @db_session + def test_not_string_2(self): + name = 1 + select(getattr(x, name) for x in self.db.Artist) + + @db_session + def test_lambda_1(self): + for name, value in [('name', 'Sia'), ('age', 40), ('name', 'Sia')]: + result = self.db.Artist.select(lambda a: getattr(a, name) == value) + self.assertEqual(set(obj.name for obj in result), {'Sia'}) + + @db_session + def test_lambda_2(self): + for entity, name, value in [ + (self.db.Genre, 'name', 'Pop'), + (self.db.Artist, 'age', 40), + (self.db.Hobby, 'name', 'Swimming'), + ]: + result = entity.select(lambda a: getattr(a, name) == value) + self.assertEqual(set(result[:]), {entity.select().first()}) diff --git a/pony/orm/tests/test_hooks.py b/pony/orm/tests/test_hooks.py index 23e9047ec..af3c2002f 100644 --- a/pony/orm/tests/test_hooks.py +++ b/pony/orm/tests/test_hooks.py @@ -3,48 +3,68 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database, db_params logged_events = [] -db = Database('sqlite', ':memory:') +db = Database() + class Person(db.Entity): id = PrimaryKey(int) name = Required(unicode) age = Required(int) + def before_insert(self): logged_events.append('BI_' + self.name) do_before_insert(self) + def before_update(self): logged_events.append('BU_' + self.name) do_before_update(self) + def before_delete(self): logged_events.append('BD_' + self.name) do_before_delete(self) + def after_insert(self): logged_events.append('AI_' + self.name) do_after_insert(self) + def after_update(self): logged_events.append('AU_' + self.name) do_after_update(self) + def after_delete(self): logged_events.append('AD_' + self.name) do_after_delete(self) + def do_nothing(person): pass + def set_hooks_to_do_nothing(): global do_before_insert, do_before_update, do_before_delete global do_after_insert, do_after_update, do_after_delete do_before_insert = do_before_update = do_before_delete = do_nothing do_after_insert = do_after_update = do_after_delete = do_nothing + +db.bind(**db_params) +db.generate_mapping(check_tables=False) + set_hooks_to_do_nothing() -db.generate_mapping(create_tables=True) class TestHooks(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): set_hooks_to_do_nothing() @@ -90,7 +110,15 @@ def flush_for(*objects): for obj in objects: obj.flush() + class ObjectFlushTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): set_hooks_to_do_nothing() diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py new file mode 100644 index 000000000..e5e146285 --- /dev/null +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -0,0 +1,245 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() + +sep = ' ' + + +class Person(db.Entity): + id = PrimaryKey(int) + first_name = Required(str) + last_name = Required(str) + favorite_color = Optional(str) + cars = Set(lambda: Car) + + @property + def full_name(self): + return self.first_name + sep + self.last_name + + @property + def full_name_2(self): + return concat(self.first_name, sep, self.last_name) # tests use of the function `concat` from external scope + + @property + def has_car(self): + return not self.cars.is_empty() + + def 
cars_by_color1(self, color): + return select(car for car in self.cars if car.color == color) + + def cars_by_color2(self, color): + return self.cars.select(lambda car: car.color == color) + + @property + def cars_price(self): + return sum(c.price for c in self.cars) + + @property + def incorrect_full_name(self): + return self.first_name + ' ' + p.last_name # p is FakePerson instance here + + @classmethod + def find_by_full_name(cls, full_name): + return cls.select(lambda p: p.full_name_2 == full_name) + + def complex_method(self): + result = '' + for i in range(10): + result += str(i) + return result + + def simple_method(self): + return self.complex_method() + + +class FakePerson(object): + pass + + +p = FakePerson() +p.last_name = '***' + + +class Car(db.Entity): + brand = Required(str) + model = Required(str) + owner = Optional(Person) + year = Required(int) + price = Required(int) + color = Required(str) + + +def simple_func(person): + return person.full_name + + +def complex_func(person): + return person.complex_method() + + + +class TestHybridsAndProperties(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p1 = Person(id=1, first_name='Alexander', last_name='Kozlovsky', favorite_color='white') + p2 = Person(id=2, first_name='Alexei', last_name='Malashkevich', favorite_color='green') + p3 = Person(id=3, first_name='Vitaliy', last_name='Abetkin') + p4 = Person(id=4, first_name='Alexander', last_name='Tischenko', favorite_color='blue') + + c1 = Car(id=1, brand='Peugeot', model='306', owner=p1, year=2006, price=14000, color='red') + c2 = Car(id=2, brand='Honda', model='Accord', owner=p1, year=2007, price=13850, color='white') + c3 = Car(id=3, brand='Nissan', model='Skyline', owner=p2, year=2008, price=29900, color='black') + c4 = Car(id=4, brand='Volkswagen', model='Passat', owner=p1, year=2012, price=9400, color='blue') + c5 = Car(id=5, brand='Koenigsegg', model='CCXR', owner=p4, year=2016, price=4850000, color='white') + c6 = Car(id=6, brand='Lada', model='Kalina', owner=p4, year=2015, price=5000, color='white') + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + @db_session + def test1(self): + persons = select(p.full_name for p in Person if p.has_car)[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexei Malashkevich', 'Alexander Tischenko'}) + + @db_session + def test2(self): + cars_prices = select(p.cars_price for p in Person)[:] + self.assertEqual(set(cars_prices), {0, 29900, 37250, 4855000}) + + @db_session + def test3(self): + persons = select(p.full_name for p in Person if p.cars_price > 100000)[:] + self.assertEqual(set(persons), {'Alexander Tischenko'}) + + @db_session + def test4(self): + persons = select(p.full_name for p in Person if not p.cars_price)[:] + self.assertEqual(set(persons), {'Vitaliy Abetkin'}) + + @db_session + def test5(self): + persons = select(p.full_name for p in Person if exists(c for c in p.cars_by_color2('white') if c.price > 10000))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexander Tischenko'}) + + @db_session + def test6(self): + persons = select(p.full_name for p in Person if exists(c for c in p.cars_by_color1('white') if c.price > 10000))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexander Tischenko'}) + + @db_session + def test7(self): + c1 = Car[1] + persons = select(p.full_name for p in Person if c1 in p.cars_by_color2('red'))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test8(self): + 
c1 = Car[1] + persons = select(p.full_name for p in Person if c1 in p.cars_by_color1('red'))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test9(self): + persons = select(p.full_name for p in Person if p.cars_by_color1(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test10(self): + persons = select(p.full_name for p in Person if not p.cars_by_color1(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Tischenko', 'Alexei Malashkevich', 'Vitaliy Abetkin'}) + + @db_session + def test11(self): + persons = select(p.full_name for p in Person if p.cars_by_color2(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test12(self): + persons = select(p.full_name for p in Person if not p.cars_by_color2(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Tischenko', 'Alexei Malashkevich', 'Vitaliy Abetkin'}) + + @db_session + def test13(self): + persons = select(p.full_name for p in Person if count(p.cars_by_color1('white')) > 1) + self.assertEqual(set(persons), {'Alexander Tischenko'}) + + @db_session + def test14(self): + # This test checks if accessing function-specific globals works correctly + persons = select(p.incorrect_full_name for p in Person if p.has_car)[:] + self.assertEqual(set(persons), {'Alexander ***', 'Alexei ***', 'Alexander ***'}) + + @db_session + def test15(self): + # Test repeated use of the same generator with a hybrid method/property that uses a function from external scope + result = Person.find_by_full_name('Alexander Kozlovsky') + self.assertEqual(set(obj.last_name for obj in result), {'Kozlovsky'}) + result = Person.find_by_full_name('Alexander Kozlovsky') + self.assertEqual(set(obj.last_name for obj in result), {'Kozlovsky'}) + result = Person.find_by_full_name('Alexander Tischenko') + self.assertEqual(set(obj.last_name for obj in result), {'Tischenko'}) + + @db_session + def test16(self): + result = Person.select(lambda p: p.full_name == 'Alexander Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + + @db_session + def test17(self): + global sep + sep = '.' + try: + result = Person.select(lambda p: p.full_name == 'Alexander.Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + finally: + sep = ' ' + + @db_session + def test18(self): + result = Person.select().filter(lambda p: p.full_name == 'Alexander Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + + @db_session + def test19(self): + global sep + sep = '.' + try: + result = Person.select().filter(lambda p: p.full_name == 'Alexander.Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + finally: + sep = ' ' + + @db_session + @raises_exception(TranslationError, 'p.complex_method(...) is too complex to decompile') + def test_20(self): + q = select(p.complex_method() for p in Person)[:] + + @db_session + @raises_exception(TranslationError, 'p.to_dict(...) is too complex to decompile') + def test_21(self): + q = select(p.to_dict() for p in Person)[:] + + @db_session + @raises_exception(TranslationError, 'self.complex_method(...) is too complex to decompile (inside Person.simple_method)') + def test_22(self): + q = select(p.simple_method() for p in Person)[:] + + @db_session + def test_23(self): + q = select(simple_func(p) for p in Person)[:] + + @db_session + @raises_exception(TranslationError, 'person.complex_method(...) 
is too complex to decompile (inside complex_func)') + def test_24(self): + q = select(complex_func(p) for p in Person)[:] + + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py index 6db92d0a7..9f8b3f453 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -4,17 +4,24 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, teardown_database class TestIndexes(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + + def tearDown(self): + teardown_database(self.db) + def test_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Person(db.Entity): name = Required(str) age = Required(int) composite_key(name, 'age') db.generate_mapping(create_tables=True) - [ i1, i2 ] = Person._indexes_ + i1, i2 = Person._indexes_ self.assertEqual(i1.attrs, (Person.id,)) self.assertEqual(i1.is_pk, True) self.assertEqual(i1.is_unique, True) @@ -22,7 +29,8 @@ class Person(db.Entity): self.assertEqual(i2.is_pk, False) self.assertEqual(i2.is_unique, True) - table = db.schema.tables['Person'] + table_name = 'Person' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'person' + table = db.schema.tables[table_name] name_column = table.column_dict['name'] age_column = table.column_dict['age'] self.assertEqual(len(table.indexes), 2) @@ -31,14 +39,14 @@ class Person(db.Entity): self.assertEqual(db_index.is_unique, True) def test_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Person(db.Entity): name = Required(str) age = Required(int) composite_index(name, 'age') db.generate_mapping(create_tables=True) - [ i1, i2 ] = Person._indexes_ + i1, i2 = Person._indexes_ self.assertEqual(i1.attrs, (Person.id,)) self.assertEqual(i1.is_pk, True) self.assertEqual(i1.is_unique, True) @@ -46,7 +54,8 @@ class Person(db.Entity): self.assertEqual(i2.is_pk, False) self.assertEqual(i2.is_unique, False) - table = db.schema.tables['Person'] + table_name = 'Person' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'person' + table = db.schema.tables[table_name] name_column = table.column_dict['name'] age_column = table.column_dict['age'] self.assertEqual(len(table.indexes), 2) @@ -55,11 +64,26 @@ class Person(db.Entity): self.assertEqual(db_index.is_unique, False) create_script = db.schema.generate_create_script() - index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")' - self.assertTrue(index_sql in create_script) - - def test_2(self): - db = Database('sqlite', ':memory:') + + + dialect = self.db.provider.dialect + if pony.__version__ < '0.9': + if dialect == 'SQLite': + index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")' + else: + index_sql = 'CREATE INDEX "idx_person__name_age" ON "person" ("name", "age")' + elif dialect == 'MySQL' or dialect == 'SQLite': + index_sql = 'CREATE INDEX `idx_person__name__age` ON `person` (`name`, `age`)' + elif dialect == 'PostgreSQL': + index_sql = 'CREATE INDEX "idx_person__name__age" ON "person" ("name", "age")' + elif dialect == 'Oracle': + index_sql = 'CREATE INDEX "IDX_PERSON__NAME__AGE" ON "PERSON" ("NAME", "AGE")' + else: + raise NotImplementedError + self.assertIn(index_sql, create_script) + + def test_3(self): + db = self.db class User(db.Entity): name = Required(str, unique=True) @@ -76,5 +100,40 @@ class User(db.Entity): u = User[1] self.assertEqual(u.name, 'B') + def test_4(self): # issue 321 + db = self.db + 
class Person(db.Entity): + name = Required(str) + age = Required(int) + composite_key(name, age) + + db.generate_mapping(create_tables=True) + with db_session: + p1 = Person(id=1, name='John', age=19) + + with db_session: + p1 = Person[1] + p1.set(name='John', age=19) + p1.delete() + + def test_5(self): + db = self.db + + class Table1(db.Entity): + name = Required(str) + table2s = Set('Table2') + + class Table2(db.Entity): + height = Required(int) + length = Required(int) + table1 = Optional('Table1') + composite_key(height, length, table1) + + db.generate_mapping(create_tables=True) + + with db_session: + Table2(height=2, length=1) + Table2.exists(height=2, length=1) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_inheritance.py b/pony/orm/tests/test_inheritance.py index 565b0750b..78badb3d8 100644 --- a/pony/orm/tests/test_inheritance.py +++ b/pony/orm/tests/test_inheritance.py @@ -4,11 +4,19 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, teardown_database + class TestInheritance(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + + def tearDown(self): + if self.db.schema: + teardown_database(self.db) def test_0(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) @@ -17,7 +25,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._discriminator_, None) def test_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(Entity1): @@ -39,11 +47,11 @@ class Entity4(Entity2, Entity3): self.assertEqual(Entity2._discriminator_, 'Entity2') self.assertEqual(Entity3._discriminator_, 'Entity3') self.assertEqual(Entity4._discriminator_, 'Entity4') - + @raises_exception(ERDiagramError, "Multiple inheritance graph must be diamond-like. " "Entity Entity3 inherits from Entity1 and Entity2 entities which don't have common base class.") def test_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(int) class Entity2(db.Entity): @@ -55,7 +63,7 @@ class Entity3(Entity1, Entity2): 'because both entities inherit from Entity1. 
' 'To fix this, move attribute definition to base class') def test_3a(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(Entity1): @@ -64,7 +72,7 @@ class Entity3(Entity1): a = Required(int) def test3b(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(Entity1): @@ -80,7 +88,7 @@ class Entity4(db.Entity): @raises_exception(ERDiagramError, "Name 'a' hides base attribute Entity1.a") def test_4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required(int) @@ -89,27 +97,27 @@ class Entity2(Entity1): @raises_exception(ERDiagramError, "Primary key cannot be redefined in derived classes") def test_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(int) class Entity2(Entity1): b = PrimaryKey(int) def test_6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Discriminator(str) b = Required(int) class Entity2(Entity1): c = Required(int) - + self.assertTrue(Entity1._discriminator_attr_ is Entity1.a) self.assertTrue(Entity2._discriminator_attr_ is Entity1.a) @raises_exception(TypeError, "Discriminator value for entity Entity1 " "with custom discriminator column 'a' of 'int' type is not set") def test_7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Discriminator(int) b = Required(int) @@ -117,7 +125,7 @@ class Entity2(Entity1): c = Required(int) def test_8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = 1 a = Discriminator(int) @@ -127,8 +135,8 @@ class Entity2(Entity1): c = Required(int) db.generate_mapping(create_tables=True) with db_session: - x = Entity1(b=10) - y = Entity2(b=10, c=20) + x = Entity1(id=1, b=10) + y = Entity2(id=2, b=10, c=20) with db_session: x = Entity1[1] y = Entity1[2] @@ -138,7 +146,7 @@ class Entity2(Entity1): self.assertEqual(y.a, 2) def test_9(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = '1' a = Discriminator(int) @@ -148,8 +156,8 @@ class Entity2(Entity1): c = Required(int) db.generate_mapping(create_tables=True) with db_session: - x = Entity1(b=10) - y = Entity2(b=10, c=20) + x = Entity1(id=1, b=10) + y = Entity2(id=2, b=10, c=20) with db_session: x = Entity1[1] y = Entity1[2] @@ -160,7 +168,7 @@ class Entity2(Entity1): @raises_exception(TypeError, "Incorrect discriminator value is set for Entity2 attribute 'a' of 'int' type: 'zzz'") def test_10(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = 1 a = Discriminator(int) @@ -170,7 +178,7 @@ class Entity2(Entity1): c = Required(int) def test_11(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = 1 a = Discriminator(int) @@ -179,7 +187,7 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, 'Invalid use of attribute Entity1.a in entity Entity2') def test_12(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) class Entity2(db.Entity): @@ -187,7 +195,7 @@ class Entity2(db.Entity): composite_index(Entity1.a, b) def test_13(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) class Entity2(Entity1): @@ -197,7 +205,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.a, 
Entity2.b) ]) def test_14(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): d = Discriminator(str) a = Required(int) @@ -208,7 +216,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.d, Entity2.a, Entity2.b) ]) def test_15(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): d = Discriminator(str) a = Required(int) @@ -219,7 +227,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.d, Entity2.id, Entity2.a, Entity2.b) ]) def test_16(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) class Entity2(Entity1): @@ -229,7 +237,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.classtype, Entity2.id, Entity2.a, Entity2.b) ]) def test_17(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): t = Discriminator(str) a = Required(int) @@ -241,21 +249,58 @@ class Entity2(Entity1): [ (Entity1.id,), (Entity1.t, Entity1.a, Entity1.b) ]) def test_18(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) - class Entity2(db.Entity1): + class Entity2(Entity1): b = Required(int) class Entity3(Entity1): c = Required(int) db.generate_mapping(create_tables=True) with db_session: - x = Entity1(a=10) - y = Entity2(a=20, b=30) - z = Entity3(a=40, c=50) + x = Entity1(id=1, a=10) + y = Entity2(id=2, a=20, b=30) + z = Entity3(id=3, a=40, c=50) with db_session: result = select(e for e in Entity1 if e.b == 30 or e.c == 50) self.assertEqual([ e.id for e in result ], [ 2, 3 ]) + def test_discriminator_1(self): + db = self.db + class Entity1(db.Entity): + a = Discriminator(str) + b = Required(int) + PrimaryKey(a, b) + class Entity2(db.Entity1): + c = Required(int) + db.generate_mapping(create_tables=True) + with db_session: + x = Entity1(b=10) + y = Entity2(b=20, c=30) + with db_session: + obj = Entity1.get(b=20) + self.assertEqual(obj.a, 'Entity2') + self.assertEqual(obj.b, 20) + self.assertEqual(obj._pkval_, ('Entity2', 20)) + with db_session: + obj = Entity1['Entity2', 20] + self.assertIsInstance(obj, Entity2) + self.assertEqual(obj.a, 'Entity2') + self.assertEqual(obj.b, 20) + self.assertEqual(obj._pkval_, ('Entity2', 20)) + + @raises_exception(TypeError, "Invalid discriminator attribute value for Foo. 
Expected: 'Foo', got: 'Baz'") + def test_discriminator_2(self): + db = self.db + class Foo(db.Entity): + id = PrimaryKey(int) + a = Discriminator(str) + b = Required(int) + class Bar(db.Entity): + c = Required(int) + db.generate_mapping(create_tables=True) + with db_session: + x = Foo(id=1, a='Baz', b=100) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_inner_join_syntax.py b/pony/orm/tests/test_inner_join_syntax.py new file mode 100644 index 000000000..cfdbf56c2 --- /dev/null +++ b/pony/orm/tests/test_inner_join_syntax.py @@ -0,0 +1,96 @@ +import unittest + +from pony.orm import * +from pony import orm +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() + + +class Genre(db.Entity): + name = orm.Optional(str) # TODO primary key + artists = orm.Set('Artist') + favorite = orm.Optional(bool) + index = orm.Optional(int) + + +class Hobby(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') + + +class Artist(db.Entity): + name = orm.Required(str) + age = orm.Optional(int) + hobbies = orm.Set(Hobby) + genres = orm.Set(Genre) + +pony.options.INNER_JOIN_SYNTAX = True + + +@only_for('sqlite') +class TestJoin(unittest.TestCase): + exclude_fixtures = {'test': ['clear_tables']} + @classmethod + def setUpClass(cls): + setup_database(db) + + with orm.db_session: + pop = Genre(name='pop') + rock = Genre(name='rock') + Artist(name='Sia', age=40, genres=[pop, rock]) + Artist(name='Lady GaGa', age=30, genres=[pop]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + @db_session + def test_join_1(self): + result = select(g.id for g in db.Genre for a in g.artists if a.name.startswith('S'))[:] + self.assertEqual(db.last_sql, """SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre" + INNER JOIN "Artist" "a" + ON "t-1"."artist" = "a"."id" +WHERE "a"."name" LIKE 'S%'""") + + @db_session + def test_join_2(self): + result = select(g.id for g in db.Genre for a in db.Artist + if JOIN(a in g.artists) and a.name.startswith('S'))[:] + self.assertEqual(db.last_sql, """SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre", "Artist" "a" +WHERE "t-1"."artist" = "a"."id" + AND "a"."name" LIKE 'S%'""") + + + @db_session + def test_join_3(self): + result = select(g.id for g in db.Genre for x in db.Artist for a in db.Artist + if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:] + self.assertEqual(db.last_sql, '''SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre", "Artist" "x", "Artist" "a" +WHERE "t-1"."artist" = "a"."id" + AND "a"."name" LIKE 'S%' + AND "g"."id" = "x"."id"''') + + @db_session + def test_join_4(self): + result = select(g.id for g in db.Genre for a in db.Artist for x in db.Artist + if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:] + self.assertEqual(db.last_sql, '''SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre", "Artist" "a", "Artist" "x" +WHERE "t-1"."artist" = "a"."id" + AND "a"."name" LIKE 'S%' + AND "g"."id" = "x"."id"''') + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_interleave.py b/pony/orm/tests/test_interleave.py new file mode 100644 index 000000000..1662b50ae --- /dev/null +++ b/pony/orm/tests/test_interleave.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import, print_function, division + +import 
unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import db_params, only_for + +@only_for(providers=['cockroach']) +class TestDiag(unittest.TestCase): + @raises_exception(TypeError, '`interleave` option cannot be specified for Set attribute Foo.x') + def test_1(self): + db = Database() + class Foo(db.Entity): + x = Set('Bar', interleave=True) + class Bar(db.Entity): + y = Required('Foo') + + @raises_exception(TypeError, "`interleave` option value should be True, False or None. Got: 'yes'") + def test_2(self): + db = Database() + class Foo(db.Entity): + x = Required('Bar', interleave='yes') + class Bar(db.Entity): + y = Set('Foo') + + @raises_exception(TypeError, 'only one attribute may be marked as interleave. Got: Foo.x, Foo.y') + def test_3(self): + db = Database() + class Foo(db.Entity): + x = Required(int, interleave=True) + y = Required(int, interleave=True) + + @raises_exception(TypeError, 'Interleave attribute should be part of relationship. Got: Foo.x') + def test_4(self): + db = Database() + class Foo(db.Entity): + x = Required(int, interleave=True) + + def test_5(self): + db = Database(**db_params) + class Bar(db.Entity): + y = Set('Foo') + + class Foo(db.Entity): + x = Required('Bar', interleave=True) + id = Required(int) + PrimaryKey(x, id) + + db.generate_mapping(create_tables=True) + s = ') INTERLEAVE IN PARENT "bar" ("x")' + self.assertIn(s, db.schema.tables['foo'].get_create_command()) + db.drop_all_tables() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/pony/orm/tests/test_isinstance.py b/pony/orm/tests/test_isinstance.py new file mode 100644 index 000000000..e7597a74c --- /dev/null +++ b/pony/orm/tests/test_isinstance.py @@ -0,0 +1,119 @@ +from datetime import date +from decimal import Decimal + +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() + + +class Person(db.Entity): + id = PrimaryKey(int, auto=True) + name = Required(str) + dob = Optional(date) + ssn = Required(str, unique=True) + + +class Student(Person): + group = Required("Group") + mentor = Optional("Teacher") + attend_courses = Set("Course") + + +class Teacher(Person): + teach_courses = Set("Course") + apprentices = Set("Student") + salary = Required(Decimal) + + +class Assistant(Student, Teacher): + pass + + +class Professor(Teacher): + position = Required(str) + + +class Group(db.Entity): + number = PrimaryKey(int) + students = Set("Student") + + +class Course(db.Entity): + name = Required(str) + semester = Required(int) + students = Set(Student) + teachers = Set(Teacher) + PrimaryKey(name, semester) + + +class TestVolatile(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p = Person(name='Person1', ssn='SSN1') + g = Group(number=123) + prof = Professor(name='Professor1', salary=1000, position='position1', ssn='SSN5') + a1 = Assistant(name='Assistant1', group=g, salary=100, ssn='SSN4', mentor=prof) + a2 = Assistant(name='Assistant2', group=g, salary=200, ssn='SSN6', mentor=prof) + s1 = Student(name='Student1', group=g, ssn='SSN2', mentor=a1) + s2 = Student(name='Student2', group=g, ssn='SSN3') + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + @db_session + def test_1(self): + q = select(p.name for p in Person if isinstance(p, Student)) + self.assertEqual(set(q), {'Student1', 'Student2', 'Assistant1', 
'Assistant2'}) + + @db_session + def test_2(self): + q = select(p.name for p in Person if not isinstance(p, Student)) + self.assertEqual(set(q), {'Person1', 'Professor1'}) + + @db_session + def test_3(self): + q = select(p.name for p in Student if isinstance(p, Professor)) + self.assertEqual(set(q), set()) + + @db_session + def test_4(self): + q = select(p.name for p in Person if not isinstance(p, Person)) + self.assertEqual(set(q), set()) + + @db_session + def test_5(self): + q = select(p.name for p in Person if isinstance(p, (Student, Teacher))) + self.assertEqual(set(q), {'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'}) + + @db_session + def test_6(self): + q = select(p.name for p in Person if isinstance(p, Student) and isinstance(p, Teacher)) + self.assertEqual(set(q), {'Assistant1', 'Assistant2'}) + + @db_session + def test_7(self): + q = select(p.name for p in Person + if (isinstance(p, Student) and p.ssn == 'SSN2') + or (isinstance(p, Professor) and p.salary > 500)) + self.assertEqual(set(q), {'Student1', 'Professor1'}) + + @db_session + def test_8(self): + q = select(p.name for p in Person if isinstance(p, Person)) + self.assertEqual(set(q), {'Person1', 'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'}) + + @db_session + def test_9(self): + q = select(g.number for g in Group if isinstance(g, Group)) + self.assertEqual(set(q), {123}) + + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py new file mode 100644 index 000000000..4c1dfa2cc --- /dev/null +++ b/pony/orm/tests/test_json.py @@ -0,0 +1,686 @@ +from pony.py23compat import basestring, pickle + +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import raises_exception, raises_if +from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict +from pony.orm.tests import setup_database, teardown_database + +db = Database() + + +class Product(db.Entity): + name = Required(str) + info = Optional(Json) + tags = Optional(Json) + + +class TestJson(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + def setUp(self): + with db_session: + Product.select().delete(bulk=True) + flush() + Product( + name='Apple iPad Air 2', + info={ + 'name': 'Apple iPad Air 2', + 'display': { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }, + 'os': { + 'type': 'iOS', + 'version': '8' + }, + 'cpu': 'Apple A8X', + 'ram': '8GB', + 'colors': ['Gold', 'Silver', 'Space Gray'], + 'models': [ + { + 'name': 'Wi-Fi', + 'capacity': ['16GB', '64GB'], + 'height': 240, + 'width': 169.5, + 'depth': 6.1, + 'weight': 437, + }, + { + 'name': 'Wi-Fi + Cellular', + 'capacity': ['16GB', '64GB'], + 'height': 240, + 'width': 169.5, + 'depth': 6.1, + 'weight': 444, + }, + ], + 'discontinued': False, + 'videoUrl': None, + 'non-ascii-attr': u'\u0442\u0435\u0441\u0442' + }, + tags=['Tablets', 'Apple', 'Retina']) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test(self): + with db_session: + result = select(p for p in Product)[:] + self.assertEqual(len(result), 1) + p = result[0] + p.info['os']['version'] = '9' + with db_session: + result = select(p for p in Product)[:] + self.assertEqual(len(result), 1) + p = result[0] + self.assertEqual(p.info['os']['version'], '9') + + @db_session + def test_query_int(self): + val = get(p.info['display']['resolution'][0] for p in Product) + self.assertEqual(val, 2048) + + @db_session + def test_query_float(self): 
+ val = get(p.info['display']['size'] for p in Product) + self.assertAlmostEqual(val, 9.7) + + @db_session + def test_query_true(self): + val = get(p.info['display']['multi-touch'] for p in Product) + self.assertIs(val, True) + + @db_session + def test_query_false(self): + val = get(p.info['discontinued'] for p in Product) + self.assertIs(val, False) + + @db_session + def test_query_null(self): + val = get(p.info['videoUrl'] for p in Product) + self.assertIs(val, None) + + @db_session + def test_query_list(self): + val = get(p.info['colors'] for p in Product) + self.assertListEqual(val, ['Gold', 'Silver', 'Space Gray']) + self.assertNotIsInstance(val, TrackedValue) + + @db_session + def test_query_dict(self): + val = get(p.info['display'] for p in Product) + self.assertDictEqual(val, { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }) + self.assertNotIsInstance(val, TrackedValue) + + @db_session + def test_query_json_field(self): + val = get(p.info for p in Product) + self.assertDictEqual(val['display'], { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }) + self.assertNotIsInstance(val['display'], TrackedDict) + val = get(p.tags for p in Product) + self.assertListEqual(val, ['Tablets', 'Apple', 'Retina']) + self.assertNotIsInstance(val, TrackedList) + + @db_session + def test_get_object(self): + p = get(p for p in Product) + self.assertDictEqual(p.info['display'], { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }) + self.assertEqual(p.info['discontinued'], False) + self.assertEqual(p.info['videoUrl'], None) + self.assertListEqual(p.tags, ['Tablets', 'Apple', 'Retina']) + self.assertIsInstance(p.info, TrackedDict) + self.assertIsInstance(p.info['display'], TrackedDict) + self.assertIsInstance(p.info['colors'], TrackedList) + self.assertIsInstance(p.tags, TrackedList) + + def test_set_str(self): + with db_session: + p = get(p for p in Product) + p.info['os']['version'] = '9' + with db_session: + p = get(p for p in Product) + self.assertEqual(p.info['os']['version'], '9') + + def test_set_int(self): + with db_session: + p = get(p for p in Product) + p.info['display']['resolution'][0] += 1 + with db_session: + p = get(p for p in Product) + self.assertEqual(p.info['display']['resolution'][0], 2049) + + def test_set_true(self): + with db_session: + p = get(p for p in Product) + p.info['discontinued'] = True + with db_session: + p = get(p for p in Product) + self.assertIs(p.info['discontinued'], True) + + def test_set_false(self): + with db_session: + p = get(p for p in Product) + p.info['display']['multi-touch'] = False + with db_session: + p = get(p for p in Product) + self.assertIs(p.info['display']['multi-touch'], False) + + def test_set_null(self): + with db_session: + p = get(p for p in Product) + p.info['display'] = None + with db_session: + p = get(p for p in Product) + self.assertIs(p.info['display'], None) + + def test_set_list(self): + with db_session: + p = get(p for p in Product) + p.info['colors'] = ['Pink', 'Black'] + with db_session: + p = get(p for p in Product) + self.assertListEqual(p.info['colors'], ['Pink', 'Black']) + + def test_list_del(self): + with db_session: + p = get(p for p in Product) + del p.info['colors'][1] + with db_session: + p = get(p for p in Product) + self.assertListEqual(p.info['colors'], ['Gold', 'Space Gray']) + + def test_list_append(self): + with db_session: + p = get(p for p in Product) + 
p.info['colors'].append('White') + with db_session: + p = get(p for p in Product) + self.assertListEqual(p.info['colors'], ['Gold', 'Silver', 'Space Gray', 'White']) + + def test_list_set_slice(self): + with db_session: + p = get(p for p in Product) + p.info['colors'][1:] = ['White'] + with db_session: + p = get(p for p in Product) + self.assertListEqual(p.info['colors'], ['Gold', 'White']) + + def test_list_set_item(self): + with db_session: + p = get(p for p in Product) + p.info['colors'][1] = 'White' + with db_session: + p = get(p for p in Product) + self.assertListEqual(p.info['colors'], ['Gold', 'White', 'Space Gray']) + + def test_set_dict(self): + with db_session: + p = get(p for p in Product) + p.info['display']['resolution'] = {'width': 2048, 'height': 1536} + with db_session: + p = get(p for p in Product) + self.assertDictEqual(p.info['display']['resolution'], {'width': 2048, 'height': 1536}) + + def test_dict_del(self): + with db_session: + p = get(p for p in Product) + del p.info['os']['version'] + with db_session: + p = get(p for p in Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS'}) + + def test_dict_pop(self): + with db_session: + p = get(p for p in Product) + p.info['os'].pop('version') + with db_session: + p = get(p for p in Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS'}) + + def test_dict_update(self): + with db_session: + p = get(p for p in Product) + p.info['os'].update(version='9') + with db_session: + p = get(p for p in Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) + + def test_dict_set_item(self): + with db_session: + p = get(p for p in Product) + p.info['os']['version'] = '9' + with db_session: + p = get(p for p in Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) + + @db_session + def test_set_same_value(self): + p = get(p for p in Product) + p.info = p.info + + @db_session + def test_len(self): + with raises_if(self, db.provider.dialect == 'Oracle', + TranslationError, 'Oracle does not provide `length` function for JSON arrays'): + val = select(len(p.tags) for p in Product).first() + self.assertEqual(val, 3) + val = select(len(p.info['colors']) for p in Product).first() + self.assertEqual(val, 3) + + @db_session + def test_equal_str(self): + p = get(p for p in Product if p.info['name'] == 'Apple iPad Air 2') + self.assertTrue(p) + + @db_session + def test_unicode_key(self): + p = get(p for p in Product if p.info[u'name'] == 'Apple iPad Air 2') + self.assertTrue(p) + + @db_session + def test_equal_string_attr(self): + p = get(p for p in Product if p.info['name'] == p.name) + self.assertTrue(p) + + @db_session + def test_equal_param(self): + x = 'Apple iPad Air 2' + p = get(p for p in Product if p.name == x) + self.assertTrue(p) + + @db_session + def test_composite_param(self): + with raises_if(self, db.provider.dialect == 'Oracle', + TranslationError, "Oracle doesn't allow parameters in JSON paths"): + key = 'models' + index = 0 + val = get(p.info[key][index]['name'] for p in Product) + self.assertEqual(val, 'Wi-Fi') + + @db_session + def test_composite_param_in_condition(self): + with raises_if(self, db.provider.dialect == 'Oracle', + TranslationError, "Oracle doesn't allow parameters in JSON paths"): + key = 'models' + index = 0 + p = get(p for p in Product if p.info[key][index]['name'] == 'Wi-Fi') + self.assertIsNotNone(p) + + @db_session + def test_equal_json_1(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support 
comparison of json structures: " + "p.info['os'] == {'type':'iOS', 'version':'8'}"): + p = get(p for p in Product if p.info['os'] == {'type': 'iOS', 'version': '8'}) + self.assertTrue(p) + + @db_session + def test_equal_json_2(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: " + "p.info['os'] == Json({'type':'iOS', 'version':'8'})"): + p = get(p for p in Product if p.info['os'] == Json({'type': 'iOS', 'version': '8'})) + self.assertTrue(p) + + @db_session + def test_ne_json_1(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['os'] != {}"): + p = get(p for p in Product if p.info['os'] != {}) + self.assertTrue(p) + p = get(p for p in Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) + self.assertFalse(p) + + @db_session + def test_ne_json_2(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['os'] != Json({})"): + p = get(p for p in Product if p.info['os'] != Json({})) + self.assertTrue(p) + p = get(p for p in Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) + self.assertFalse(p) + + @db_session + def test_equal_list_1(self): + colors = ['Gold', 'Silver', 'Space Gray'] + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"): + p = get(p for p in Product if p.info['colors'] == Json(colors)) + self.assertTrue(p) + + @db_session + @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == ['Gold']") + def test_equal_list_2(self): + p = get(p for p in Product if p.info['colors'] == ['Gold']) + + @db_session + def test_equal_list_3(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"): + p = get(p for p in Product if p.info['colors'] != Json(['Gold'])) + self.assertIsNotNone(p) + + @db_session + def test_equal_list_4(self): + colors = ['Gold', 'Silver', 'Space Gray'] + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"): + p = get(p for p in Product if p.info['colors'] == Json(colors)) + self.assertTrue(p) + + @db_session + @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == []") + def test_equal_empty_list_1(self): + p = get(p for p in Product if p.info['colors'] == []) + + @db_session + def test_equal_empty_list_2(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] == Json([])"): + p = get(p for p in Product if p.info['colors'] == Json([])) + self.assertIsNone(p) + + @db_session + def test_ne_list(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"): + p = get(p for p in Product if p.info['colors'] != Json(['Gold'])) + self.assertTrue(p) + + @db_session + def test_ne_empty_list(self): + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] != Json([])"): 
+ p = get(p for p in Product if p.info['colors'] != Json([])) + self.assertTrue(p) + + @db_session + def test_dbval2val(self): + p = select(p for p in Product)[:][0] + attr = Product.info + val = p._vals_[attr] + dbval = p._dbvals_[attr] + self.assertIsInstance(dbval, basestring) + self.assertIsInstance(val, TrackedValue) + p.info['os']['version'] = '9' + self.assertIs(val, p._vals_[attr]) + self.assertIs(dbval, p._dbvals_[attr]) + p.flush() + self.assertIs(val, p._vals_[attr]) + self.assertNotEqual(dbval, p._dbvals_[attr]) + + @db_session + def test_wildcard_path_1(self): + with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), + TranslationError, '...does not support wildcards in JSON path...'): + names = get(p.info['models'][:]['name'] for p in Product) + self.assertSetEqual(set(names), {'Wi-Fi', 'Wi-Fi + Cellular'}) + + @db_session + def test_wildcard_path_2(self): + with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), + TranslationError, '...does not support wildcards in JSON path...'): + values = get(p.info['os'][...] for p in Product) + self.assertSetEqual(set(values), {'iOS', '8'}) + + @db_session + def test_wildcard_path_3(self): + with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), + TranslationError, '...does not support wildcards in JSON path...'): + names = get(p.info[...][0]['name'] for p in Product) + self.assertSetEqual(set(names), {'Wi-Fi'}) + + @db_session + def test_wildcard_path_4(self): + if db.provider.dialect == 'Oracle': + raise unittest.SkipTest + with raises_if(self, db.provider.dialect != 'MySQL', + TranslationError, '...does not support wildcards in JSON path...'): + values = get(p.info[...][:][...][:] for p in Product)[:] + self.assertSetEqual(set(values), {'16GB', '64GB'}) + + @db_session + def test_wildcard_path_with_params(self): + if db.provider.dialect != 'Oracle': + exc_msg = '...does not support wildcards in JSON path...' + else: + exc_msg = "Oracle doesn't allow parameters in JSON paths" + with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, exc_msg): + key = 'models' + index = 0 + values = get(p.info[key][:]['capacity'][index] for p in Product) + self.assertListEqual(values, ['16GB', '16GB']) + + @db_session + def test_wildcard_path_with_params_as_string(self): + if db.provider.dialect != 'Oracle': + exc_msg = '...does not support wildcards in JSON path...' + else: + exc_msg = "Oracle doesn't allow parameters in JSON paths" + with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, exc_msg): + key = 'models' + index = 0 + values = get("p.info[key][:]['capacity'][index] for p in Product") + self.assertListEqual(values, ['16GB', '16GB']) + + @db_session + def test_wildcard_path_in_condition(self): + errors = { + 'MySQL': 'Wildcards are not allowed in json_contains()', + 'SQLite': '...does not support wildcards in JSON path...', + 'PostgreSQL': '...does not support wildcards in JSON path...' 
+ } + dialect = db.provider.dialect + with raises_if(self, dialect in errors, TranslationError, errors.get(dialect)): + p = get(p for p in Product if '16GB' in p.info['models'][:]['capacity']) + self.assertTrue(p) + + ##### 'key' in json + + @db_session + def test_in_dict(self): + obj = get(p for p in Product if 'resolution' in p.info['display']) + self.assertTrue(obj) + + @db_session + def test_not_in_dict(self): + obj = get(p for p in Product if 'resolution' not in p.info['display']) + self.assertIs(obj, None) + obj = get(p for p in Product if 'xyz' not in p.info['display']) + self.assertTrue(obj) + + @db_session + def test_in_list(self): + obj = get(p for p in Product if 'Gold' in p.info['colors']) + self.assertTrue(obj) + + @db_session + def test_not_in_list(self): + obj = get(p for p in Product if 'White' not in p.info['colors']) + self.assertTrue(obj) + obj = get(p for p in Product if 'Gold' not in p.info['colors']) + self.assertIs(obj, None) + + @db_session + def test_var_in_json(self): + with raises_if(self, db.provider.dialect == 'Oracle', + TypeError, "For `key in JSON` operation Oracle supports literal key values only, " + "parameters are not allowed: key in p.info['colors']"): + key = 'Gold' + obj = get(p for p in Product if key in p.info['colors']) + self.assertTrue(obj) + + @db_session + def test_select_first(self): + # query should not contain ORDER BY + obj = select(p.info for p in Product).first() + self.assertNotIn('order by', db.last_sql.lower()) + + def test_sql_inject(self): + # test quote in json is not causing error + with db_session: + p = select(p for p in Product).first() + p.info['display']['size'] = "0' 9.7\"" + with db_session: + p = select(p for p in Product).first() + self.assertEqual(p.info['display']['size'], "0' 9.7\"") + + @db_session + def test_int_compare(self): + p = get(p for p in Product if p.info['display']['resolution'][0] == 2048) + self.assertTrue(p) + p = get(p for p in Product if p.info['display']['resolution'][0] != 2048) + self.assertIsNone(p) + p = get(p for p in Product if p.info['display']['resolution'][0] < 2048) + self.assertIs(p, None) + p = get(p for p in Product if p.info['display']['resolution'][0] <= 2048) + self.assertTrue(p) + p = get(p for p in Product if p.info['display']['resolution'][0] > 2048) + self.assertIs(p, None) + p = get(p for p in Product if p.info['display']['resolution'][0] >= 2048) + self.assertTrue(p) + + @db_session + def test_float_compare(self): + p = get(p for p in Product if p.info['display']['size'] > 9.5) + self.assertTrue(p) + p = get(p for p in Product if p.info['display']['size'] < 9.8) + self.assertTrue(p) + p = get(p for p in Product if p.info['display']['size'] < 9.5) + self.assertIsNone(p) + p = get(p for p in Product if p.info['display']['size'] > 9.8) + self.assertIsNone(p) + + @db_session + def test_str_compare(self): + p = get(p for p in Product if p.info['ram'] == '8GB') + self.assertTrue(p) + p = get(p for p in Product if p.info['ram'] != '8GB') + self.assertIsNone(p) + p = get(p for p in Product if p.info['ram'] < '9GB') + self.assertTrue(p) + p = get(p for p in Product if p.info['ram'] > '7GB') + self.assertTrue(p) + p = get(p for p in Product if p.info['ram'] > '9GB') + self.assertIsNone(p) + p = get(p for p in Product if p.info['ram'] < '7GB') + self.assertIsNone(p) + + @db_session + def test_bool_compare(self): + p = get(p for p in Product if p.info['display']['multi-touch'] == True) + self.assertTrue(p) + p = get(p for p in Product if p.info['display']['multi-touch'] is True) + 
self.assertTrue(p) + p = get(p for p in Product if p.info['display']['multi-touch'] == False) + self.assertIsNone(p) + p = get(p for p in Product if p.info['display']['multi-touch'] is False) + self.assertIsNone(p) + p = get(p for p in Product if p.info['discontinued'] == False) + self.assertTrue(p) + p = get(p for p in Product if p.info['discontinued'] == True) + self.assertIsNone(p) + + @db_session + def test_none_compare(self): + p = get(p for p in Product if p.info['videoUrl'] is None) + self.assertTrue(p) + p = get(p for p in Product if p.info['videoUrl'] is not None) + self.assertIsNone(p) + + @db_session + def test_none_for_nonexistent_path(self): + p = get(p for p in Product if p.info['some_attr'] is None) + self.assertTrue(p) + + @db_session + def test_str_cast(self): + p = get(coalesce(str(p.name), 'empty') for p in Product) + last_sql = db.last_sql + if db.provider.dialect == 'PostgreSQL': + self.assertTrue(')::text' in last_sql) + else: + self.assertTrue('AS text' in db.last_sql) + + @db_session + def test_int_cast(self): + p = get(coalesce(int(p.info['os']['version']), 0) for p in Product) + last_sql = db.last_sql + if db.provider.dialect == 'PostgreSQL': + self.assertTrue(')::int' in last_sql) + else: + self.assertTrue('as integer' in last_sql) + + + def test_nonzero(self): + with db_session: + delete(p for p in Product) + Product(name='P1', info=dict(id=1, val=True)) + Product(name='P2', info=dict(id=2, val=False)) + Product(name='P3', info=dict(id=3, val=0)) + Product(name='P4', info=dict(id=4, val=1)) + Product(name='P5', info=dict(id=5, val='')) + Product(name='P6', info=dict(id=6, val='x')) + Product(name='P7', info=dict(id=7, val=[])) + Product(name='P8', info=dict(id=8, val=[1, 2, 3])) + Product(name='P9', info=dict(id=9, val={})) + Product(name='P10', info=dict(id=10, val={'a': 'b'})) + Product(name='P11', info=dict(id=11)) + Product(name='P12', info=dict(id=12, val='True')) + Product(name='P13', info=dict(id=13, val='False')) + Product(name='P14', info=dict(id=14, val='0')) + Product(name='P15', info=dict(id=15, val='1')) + Product(name='P16', info=dict(id=16, val='""')) + Product(name='P17', info=dict(id=17, val='[]')) + Product(name='P18', info=dict(id=18, val='{}')) + + with db_session: + val = select(p.info['id'] for p in Product if not p.info['val']) + self.assertEqual(tuple(sorted(val)), (2, 3, 5, 7, 9, 11)) + + @db_session + def test_optimistic_check(self): + p1 = Product.select().first() + p1.info['foo'] = 'bar' + flush() + p1.name = 'name2' + flush() + p1.name = 'name3' + flush() + + @db_session + def test_avg(self): + result = select(avg(p.info['display']['size']) for p in Product).first() + self.assertAlmostEqual(result, 9.7) + + def test_pickle(self): + with db_session: + p1 = Product.select().first() + data = pickle.dumps(p1) + with db_session: + p1 = pickle.loads(data) + p1.name = 'name2' + flush() + rollback() diff --git a/pony/orm/tests/test_lazy.py b/pony/orm/tests/test_lazy.py index 4a0314a46..8144e2a8e 100644 --- a/pony/orm/tests/test_lazy.py +++ b/pony/orm/tests/test_lazy.py @@ -3,19 +3,24 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database + class TestLazy(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + db = self.db = Database() class X(self.db.Entity): a = Required(int) b = Required(unicode, lazy=True) self.X = X - self.db.generate_mapping(create_tables=True) + setup_database(db) with db_session: - x1 = X(a=1, b='first') - x2 = X(a=2, b='second') - x3 = 
X(a=3, b='third') + x1 = X(id=1, a=1, b='first') + x2 = X(id=2, a=2, b='second') + x3 = X(id=3, a=3, b='third') + + def tearDown(self): + teardown_database(self.db) @db_session def test_lazy_1(self): diff --git a/pony/orm/tests/test_mapping.py b/pony/orm/tests/test_mapping.py index a86bfc85d..28a5fb444 100644 --- a/pony/orm/tests/test_mapping.py +++ b/pony/orm/tests/test_mapping.py @@ -5,13 +5,19 @@ from pony.orm.core import * from pony.orm.dbschema import DBSchemaError from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, only_for + +@only_for('sqlite') class TestColumnsMapping(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + # raise exception if mapping table by default is not found @raises_exception(OperationalError, 'no such table: Student') def test_table_check1(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = PrimaryKey(str) sql = "drop table if exists Student;" @@ -21,7 +27,7 @@ class Student(db.Entity): # no exception if table was specified def test_table_check2(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = PrimaryKey(str) sql = """ @@ -38,7 +44,7 @@ class Student(db.Entity): # raise exception if specified mapping table is not found @raises_exception(OperationalError, 'no such table: Table1') def test_table_check3(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): _table_ = 'Table1' name = PrimaryKey(str) @@ -46,7 +52,7 @@ class Student(db.Entity): # no exception if table was specified def test_table_check4(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): _table_ = 'Table1' name = PrimaryKey(str) @@ -64,7 +70,7 @@ class Student(db.Entity): # 'id' field created if primary key is not defined @raises_exception(OperationalError, 'no such column: Student.id') def test_table_check5(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = Required(str) sql = """ @@ -79,7 +85,7 @@ class Student(db.Entity): # 'id' field created if primary key is not defined def test_table_check6(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = Required(str) sql = """ @@ -96,7 +102,7 @@ class Student(db.Entity): @raises_exception(DBSchemaError, "Column 'name' already exists in table 'Student'") def test_table_check7(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = Required(str, column='name') record = Required(str, column='name') @@ -113,7 +119,7 @@ class Student(db.Entity): # user can specify column name for an attribute def test_custom_column_name(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = PrimaryKey(str, column='name1') sql = """ @@ -131,7 +137,7 @@ class Student(db.Entity): @raises_exception(ERDiagramError, 'At least one attribute of one-to-one relationship Entity1.attr1 - Entity2.attr2 must be optional') def test_relations1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required("Entity2") @@ -142,7 +148,7 @@ class Entity2(db.Entity): # no exception Optional-Required def test_relations2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2") @@ -153,7 +159,7 @@ class Entity2(db.Entity): # no exception Optional-Required(column) def test_relations3(self): - db = Database('sqlite', 
':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required("Entity2", column='a') @@ -163,7 +169,7 @@ class Entity2(db.Entity): db.generate_mapping(create_tables=True) def test_relations4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required("Entity2") @@ -176,7 +182,7 @@ class Entity2(db.Entity): # no exception Optional-Optional def test_relations5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2") @@ -187,7 +193,7 @@ class Entity2(db.Entity): # no exception Optional-Optional(column) def test_relations6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a') @@ -197,7 +203,7 @@ class Entity2(db.Entity): db.generate_mapping(create_tables=True) def test_relations7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a') @@ -209,7 +215,7 @@ class Entity2(db.Entity): self.assertEqual(Entity2.attr2.columns, ['a1']) def test_columns1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(int) attr1 = Set("Entity2") @@ -223,7 +229,7 @@ class Entity2(db.Entity): self.assertEqual(column_list[1].name, 'attr2') def test_columns2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -240,7 +246,7 @@ class Entity2(db.Entity): self.assertEqual(column_list[2].name, 'attr2_b') def test_columns3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') @@ -252,7 +258,7 @@ class Entity2(db.Entity): self.assertEqual(Entity2.attr2.columns, []) def test_columns4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional('Entity1') @@ -265,14 +271,14 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, "Mapping is not generated for entity 'E1'") def test_generate_mapping1(self): - db = Database('sqlite', ':memory:') + db = self.db class E1(db.Entity): a1 = Required(int) select(e for e in E1) @raises_exception(ERDiagramError, "Mapping is not generated for entity 'E1'") def test_generate_mapping2(self): - db = Database('sqlite', ':memory:') + db = self.db class E1(db.Entity): a1 = Required(int) e = E1(a1=1) diff --git a/pony/orm/tests/test_objects_to_save_cleanup.py b/pony/orm/tests/test_objects_to_save_cleanup.py index 5c6e5d74a..e00b09b23 100644 --- a/pony/orm/tests/test_objects_to_save_cleanup.py +++ b/pony/orm/tests/test_objects_to_save_cleanup.py @@ -1,34 +1,35 @@ - - import unittest - from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -db = Database() - -class TestPost(db.Entity): - category = Optional('TestCategory') - name = Optional(str, default='Noname') +class EntityStatusTestCase(object): + @classmethod + def setUpClass(cls): + db = cls.db = Database() -class TestCategory(db.Entity): - posts = Set(TestPost) + class TestPost(db.Entity): + category = Optional('TestCategory') + name = Optional(str, default='Noname') -db.bind('sqlite', ':memory:') -db.generate_mapping(create_tables=True) + class TestCategory(db.Entity): + posts = Set(TestPost) + setup_database(db) -class EntityStatusTestCase(object): + @classmethod + def tearDownClass(cls): + 
teardown_database(cls.db) def make_flush(self, obj=None): raise NotImplementedError @db_session def test_delete_updated(self): - p = TestPost() + p = self.db.TestPost() self.make_flush(p) p.name = 'Pony' - self.assertEqual(p._status_, 'modified') + self.assertEqual(p._status_, 'modified') self.make_flush(p) self.assertEqual(p._status_, 'updated') p.delete() @@ -38,7 +39,7 @@ def test_delete_updated(self): @db_session def test_delete_inserted(self): - p = TestPost() + p = self.db.TestPost() self.assertEqual(p._status_, 'created') self.make_flush(p) self.assertEqual(p._status_, 'inserted') @@ -46,7 +47,7 @@ def test_delete_inserted(self): @db_session def test_cancelled(self): - p = TestPost() + p = self.db.TestPost() self.assertEqual(p._status_, 'created') p.delete() self.assertEqual(p._status_, 'cancelled') @@ -54,16 +55,15 @@ def test_cancelled(self): self.assertEqual(p._status_, 'cancelled') - class EntityStatusTestCase_ObjectFlush(EntityStatusTestCase, unittest.TestCase): def make_flush(self, obj=None): obj.flush() - + class EntityStatusTestCase_FullFlush(EntityStatusTestCase, unittest.TestCase): def make_flush(self, obj=None): - flush() # full flush \ No newline at end of file + flush() # full flush diff --git a/pony/orm/tests/test_prefetching.py b/pony/orm/tests/test_prefetching.py index 685e1435a..d77cb2870 100644 --- a/pony/orm/tests/test_prefetching.py +++ b/pony/orm/tests/test_prefetching.py @@ -4,8 +4,10 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(str) @@ -14,30 +16,48 @@ class Student(db.Entity): dob = Optional(date) group = Required('Group') courses = Set('Course') + mentor = Optional('Teacher') biography = Optional(LongStr) + class Group(db.Entity): number = PrimaryKey(int) - major = Required(str) + major = Required(str, lazy=True) students = Set(Student) + class Course(db.Entity): name = Required(str, unique=True) students = Set(Student) -db.generate_mapping(create_tables=True) -with db_session: - g1 = Group(number=1, major='Math') - g2 = Group(number=2, major='Computer Sciense') - c1 = Course(name='Math') - c2 = Course(name='Physics') - c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) +class Teacher(db.Entity): + name = Required(str) + students = Set(Student) + class TestPrefetching(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=1, major='Math') + g2 = Group(number=2, major='Computer Sciense') + c1 = Course(name='Math') + c2 = Course(name='Physics') + c3 = Course(name='Computer Science') + t1 = Teacher(name='T1') + t2 = Teacher(name='T2') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio', mentor=t1) + Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio') + Student(id=3, name='S3', group=g1, gpa=4.7, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3], mentor=t2) + Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3]) + + @classmethod + def tearDownClass(cls): + 
teardown_database(db) + def test_1(self): with db_session: s1 = Student.select().first() @@ -53,13 +73,13 @@ def test_2(self): def test_3(self): with db_session: - s1 = Student.select().prefetch(Group).first() + s1 = Student.select().prefetch(Group, Group.major).first() g = s1.group self.assertEqual(g.major, 'Math') def test_4(self): with db_session: - s1 = Student.select().prefetch(Student.group).first() + s1 = Student.select().prefetch(Student.group, Group.major).first() g = s1.group self.assertEqual(g.major, 'Math') @@ -76,7 +96,7 @@ def test_6(self): def test_7(self): with db_session: - name, group = select((s.name, s.group) for s in Student).prefetch(Group).first() + name, group = select((s.name, s.group) for s in Student).prefetch(Group, Group.major).first() self.assertEqual(group.major, 'Math') @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Student[1].courses: the database session is over') @@ -94,7 +114,7 @@ def test_9(self): def test_10(self): with db_session: s1 = Student.select().prefetch(Student.courses).first() - self.assertEqual(set(s1.courses.name), set(['Math', 'Physics'])) + self.assertEqual(set(s1.courses.name), {'Math', 'Physics'}) @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Student[1].biography: the database session is over') def test_11(self): @@ -105,11 +125,100 @@ def test_11(self): def test_12(self): with db_session: s1 = Student.select().prefetch(Student.biography).first() - self.assertEqual(s1.biography, 'some text') - self.assertEqual(db.last_sql, '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."biography" -FROM "Student" "s" + self.assertEqual(s1.biography, 'S1 bio') + table_name = 'Student' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'student' + expected_sql = '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."mentor", "s"."biography" +FROM "%s" "s" ORDER BY 1 -LIMIT 1''') +LIMIT 1''' % table_name + if db.provider.dialect == 'SQLite' and pony.__version__ >= '0.9': + expected_sql = expected_sql.replace('"', '`') + self.assertEqual(db.last_sql, expected_sql) + + def test_13(self): + db.merge_local_stats() + with db_session: + q = select(g for g in Group) + for g in q: # 1 query + for s in g.students: # 2 query + b = s.biography # 5 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 8) + + def test_14(self): + db.merge_local_stats() + with db_session: + q = select(g for g in Group).prefetch(Group.students) + for g in q: # 1 query + for s in g.students: # 1 query + b = s.biography # 5 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 7) + + def test_15(self): + with db_session: + q = select(g for g in Group).prefetch(Group.students) + q[:] + db.merge_local_stats() + with db_session: + q = select(g for g in Group).prefetch(Group.students, Student.biography) + for g in q: # 1 query + for s in g.students: # 1 query + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 2) + + def test_16(self): + db.merge_local_stats() + with db_session: + q = select(c for c in Course).prefetch(Course.students, Student.biography) + for c in q: # 1 query + for s in c.students: # 2 queries (as it is many-to-many relationship) + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 3) + + def test_17(self): + db.merge_local_stats() + with db_session: + q = select(c for c 
in Course).prefetch(Course.students, Student.biography, Group, Group.major) + for c in q: # 1 query + for s in c.students: # 2 queries (as it is many-to-many relationship) + m = s.group.major # 1 query + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 4) + + def test_18(self): + db.merge_local_stats() + with db_session: + q = Group.select().prefetch(Group.students, Student.biography) + for g in q: # 2 queries + for s in g.students: + m = s.mentor # 0 queries + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 2) + + def test_19(self): + db.merge_local_stats() + with db_session: + q = Group.select().prefetch(Group.students, Student.biography, Student.mentor) + mentors = set() + for g in q: # 3 queries + for s in g.students: + m = s.mentor # 0 queries + if m is not None: + mentors.add(m) + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 3) + + for m in mentors: + n = m.name # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 3) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 46d68f19d..6623463da 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import PYPY2, pickle import unittest from datetime import date @@ -6,8 +7,10 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import teardown_database, setup_database + +db = Database() -db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) @@ -16,49 +19,57 @@ class Student(db.Entity): group = Required('Group') dob = Optional(date) + class Group(db.Entity): number = PrimaryKey(int) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - Student(id=1, name='S1', group=g1, gpa=3.1) - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2)) class TestQuery(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=1) + Student(id=1, name='S1', group=g1, gpa=3.1) + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2)) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() - @raises_exception(TypeError, 'Cannot iterate over non-entity object') + @raises_exception(TypeError, "Query can only iterate over entity or another query (not a list of objects)") def test1(self): select(s for s in []) - @raises_exception(TypeError, 'Cannot iterate over non-entity object X') + @raises_exception(TypeError, "Cannot iterate over non-entity object X") def test2(self): X = [1, 2, 3] select('x for x in X') - @raises_exception(TypeError, "Cannot iterate over non-entity object") def test3(self): g = Group[1] - select(s for s in g.students) - @raises_exception(ExprEvalError, "a raises NameError: name 'a' is not defined") + students = select(s for s in g.students) + self.assertEqual(set(g.students), set(students)) + @raises_exception(ExprEvalError, 
"`a` raises NameError: global name 'a' is not defined" if PYPY2 else + "`a` raises NameError: name 'a' is not defined") def test4(self): select(a for s in Student) - @raises_exception(TypeError, "Incomparable types '%s' and 'list' in expression: s.name == x" % unicode.__name__) + @raises_exception(TypeError, "Incomparable types '%s' and 'StrArray' in expression: s.name == x" % unicode.__name__) def test5(self): x = ['A'] select(s for s in Student if s.name == x) - @raises_exception(TypeError, "f1(s.gpa)") def test6(self): def f1(x): - return x + 1 - select(s for s in Student if f1(s.gpa) > 3) - @raises_exception(NotImplementedError, "m1(s.gpa, 1) > 3") + return float(x) + 1 + students = select(s for s in Student if f1(s.gpa) > 4.25)[:] + self.assertEqual({s.id for s in students}, {3}) + @raises_exception(NotImplementedError, "m1") def test7(self): class C1(object): def method1(self, a, b): @@ -117,26 +128,21 @@ def test22(self): def test23(self): r = max(s.dob.year for s in Student) self.assertEqual(r, 2001) - @db_session def test_first1(self): q = select(s for s in Student).order_by(Student.gpa) self.assertEqual(q.first(), Student[1]) - @db_session def test_first2(self): q = select((s.name, s.group) for s in Student) self.assertEqual(q.first(), ('S1', Group[1])) - @db_session def test_first3(self): q = select(s for s in Student) self.assertEqual(q.first(), Student[1]) - @db_session def test_closures_1(self): def find_by_gpa(gpa): return lambda s: s.gpa > gpa fn = find_by_gpa(Decimal('3.1')) students = list(Student.select(fn)) self.assertEqual(students, [ Student[2], Student[3] ]) - @db_session def test_closures_2(self): def find_by_gpa(gpa): return lambda s: s.gpa > gpa @@ -144,6 +150,30 @@ def find_by_gpa(gpa): q = select(s for s in Student) q = q.filter(fn) self.assertEqual(list(q), [ Student[2], Student[3] ]) + @raises_exception(NameError, 'Free variable `gpa` referenced before assignment in enclosing scope') + def test_closures_3(self): + def find_by_gpa(): + if False: + gpa = Decimal('3.1') + return lambda s: s.gpa > gpa + fn = find_by_gpa() + students = list(Student.select(fn)) + def test_pickle(self): + objects = select(s for s in Student if s.scholarship > 0).order_by(desc(Student.id)) + data = pickle.dumps(objects) + rollback() + objects = pickle.loads(data) + self.assertEqual([obj.id for obj in objects], [3, 2]) + def test_bulk_delete_clear_query_cache(self): + students1 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:] + self.assertEqual([s.id for s in students1], [2, 3]) + Student.select(lambda s: s.id < 3).delete(bulk=True) + students2 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:] + self.assertEqual([s.id for s in students2], [3]) + rollback() + students1 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:] + self.assertEqual([s.id for s in students1], [2, 3]) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_random.py b/pony/orm/tests/test_random.py index 9a47efaa1..e0bcef9c6 100644 --- a/pony/orm/tests/test_random.py +++ b/pony/orm/tests/test_random.py @@ -2,23 +2,31 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): id = PrimaryKey(int) name = Required(unicode) -db.generate_mapping(create_tables=True) - -with db_session: - Person(id=1, name='John') - Person(id=2, name='Mary') - Person(id=3, name='Bob') - Person(id=4, name='Mike') - Person(id=5, name='Ann') 
class TestRandom(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Person(id=1, name='John') + Person(id=2, name='Mary') + Person(id=3, name='Bob') + Person(id=4, name='Mike') + Person(id=5, name='Ann') + + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test_1(self): persons = Person.select().random(2) diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index 56dce042e..aa12eee7e 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -1,12 +1,15 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import PYPY2 import unittest from datetime import date from pony.orm import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): id = PrimaryKey(int) @@ -14,32 +17,39 @@ class Person(db.Entity): age = Required(int) dob = Required(date) -db.generate_mapping(create_tables=True) - -with db_session: - Person(id=1, name='John', age=30, dob=date(1985, 1, 1)) - Person(id=2, name='Mike', age=32, dob=date(1983, 5, 20)) - Person(id=3, name='Mary', age=20, dob=date(1995, 2, 15)) +@only_for('sqlite') class TestRawSQL(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Person(id=1, name='John', age=30, dob=date(1985, 1, 1)) + Person(id=2, name='Mike', age=32, dob=date(1983, 5, 20)) + Person(id=3, name='Mary', age=20, dob=date(1995, 2, 15)) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test_1(self): # raw_sql result can be treated as a logical expression persons = select(p for p in Person if raw_sql('abs("p"."age") > 25'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_2(self): # raw_sql result can be used for comparison persons = select(p for p in Person if raw_sql('abs("p"."age")') > 25)[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_3(self): # raw_sql can accept $parameters x = 25 persons = select(p for p in Person if raw_sql('abs("p"."age") > $x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_4(self): @@ -47,7 +57,7 @@ def test_4(self): x = 1 s = 'p.id > $x' persons = select(p for p in Person if raw_sql(s))[:] - self.assertEqual(set(persons), set([Person[2], Person[3]])) + self.assertEqual(set(persons), {Person[2], Person[3]}) @db_session def test_5(self): @@ -55,14 +65,14 @@ def test_5(self): x = 1 cond = raw_sql('p.id > $x') persons = select(p for p in Person if cond)[:] - self.assertEqual(set(persons), set([Person[2], Person[3]])) + self.assertEqual(set(persons), {Person[2], Person[3]}) @db_session def test_6(self): # correct converter should be applied to raw_sql parameter type x = date(1990, 1, 1) persons = select(p for p in Person if raw_sql('p.dob < $x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_7(self): @@ -70,19 +80,19 @@ def test_7(self): x = 10 y = 15 persons = select(p for p in Person if raw_sql('p.age > $(x + y)'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + 
self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_8(self): # raw_sql argument may be complex expression (2) persons = select(p for p in Person if raw_sql('p.dob < $date.today()'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2], Person[3]])) + self.assertEqual(set(persons), {Person[1], Person[2], Person[3]}) @db_session def test_9(self): # using raw_sql in the expression part of the generator names = select(raw_sql('UPPER(p.name)') for p in Person)[:] - self.assertEqual(set(names), set(['JOHN', 'MIKE', 'MARY'])) + self.assertEqual(set(names), {'JOHN', 'MIKE', 'MARY'}) @db_session def test_10(self): @@ -101,21 +111,21 @@ def test_12(self): # raw_sql can be used in lambdas x = 25 persons = Person.select(lambda p: p.age > raw_sql('$x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_13(self): # raw_sql in filter() x = 25 persons = select(p for p in Person).filter(lambda p: p.age > raw_sql('$x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_14(self): # raw_sql in filter() without using lambda x = 25 persons = Person.select().filter(raw_sql('p.age > $x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_15(self): @@ -123,7 +133,7 @@ def test_15(self): x = '123' y = 'John' persons = Person.select(lambda p: raw_sql("UPPER(p.name) || $x") == raw_sql("UPPER($y || '123')"))[:] - self.assertEqual(set(persons), set([Person[1]])) + self.assertEqual(set(persons), {Person[1]}) @db_session def test_16(self): @@ -135,7 +145,7 @@ def test_16(self): y = 'j' q = q.filter(lambda p: p.dob > x and p.name.startswith(raw_sql('UPPER($y)'))) persons = q[:] - self.assertEqual(set(persons), set([Person[1]])) + self.assertEqual(set(persons), {Person[1]}) @db_session def test_17(self): @@ -152,13 +162,16 @@ def test_18(self): self.assertEqual(persons, [Person[1], Person[3], Person[2]]) @db_session - @raises_exception(TypeError, "raw_sql(p.name)") + @raises_exception(TranslationError, "Expression `raw_sql(p.name)` cannot be translated into SQL " + "because raw SQL fragment will be different for each row") def test_19(self): # raw_sql argument cannot depend on iterator variables select(p for p in Person if raw_sql(p.name))[:] @db_session - @raises_exception(ExprEvalError, "raw_sql('p.dob < $x') raises NameError: name 'x' is not defined") + @raises_exception(ExprEvalError, + "`raw_sql('p.dob < $x')` raises NameError: global name 'x' is not defined" if PYPY2 else + "`raw_sql('p.dob < $x')` raises NameError: name 'x' is not defined") def test_20(self): # testing for situation where parameter variable is missing select(p for p in Person if raw_sql('p.dob < $x'))[:] diff --git a/pony/orm/tests/test_relations_m2m.py b/pony/orm/tests/test_relations_m2m.py index c4fd15224..dc0552653 100644 --- a/pony/orm/tests/test_relations_m2m.py +++ b/pony/orm/tests/test_relations_m2m.py @@ -2,26 +2,30 @@ import unittest from pony.orm.core import * +from pony.orm.tests import db_params, teardown_database -class TestManyToManyNonComposite(unittest.TestCase): +db = Database() - def setUp(self): - db = Database('sqlite', ':memory:') - class Group(db.Entity): - number = PrimaryKey(int) - subjects = Set("Subject") +class Group(db.Entity): + number = PrimaryKey(int) + subjects = Set("Subject") - class Subject(db.Entity): - name 
= PrimaryKey(str) - groups = Set(Group) - self.db = db - self.Group = Group - self.Subject = Subject - - self.db.generate_mapping(create_tables=True) +class Subject(db.Entity): + name = PrimaryKey(str) + groups = Set(Group) + +class TestManyToManyNonComposite(unittest.TestCase): + @classmethod + def setUpClass(cls): + db.bind(**db_params) + db.generate_mapping(check_tables=False) + db.drop_all_tables(with_all_data=True) + + def setUp(self): + db.create_tables() with db_session: g1 = Group(number=101) g2 = Group(number=102) @@ -31,9 +35,25 @@ class Subject(db.Entity): s4 = Subject(name='Subj4') g1.subjects = [ s1, s2 ] + def tearDown(self): + teardown_database(db) + def test_1(self): - db, Group, Subject = self.db, self.Group, self.Subject + schema = db.schema + m2m_table_name = 'Group_Subject' + if not (db.provider.dialect == 'SQLite' and pony.__version__ < '0.9'): + m2m_table_name = m2m_table_name.lower() + self.assertIn(m2m_table_name, schema.tables) + m2m_table = schema.tables[m2m_table_name] + if pony.__version__ >= '0.9': + fkeys = m2m_table.foreign_keys + else: + fkeys = set(m2m_table.foreign_keys.values()) + self.assertEqual(len(fkeys), 2) + for fk in fkeys: + self.assertEqual(fk.on_delete, 'CASCADE') + def test_2(self): with db_session: g = Group.get(number=101) s = Subject.get(name='Subj1') @@ -41,11 +61,9 @@ def test_1(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - - def test_2(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) + def test_3(self): with db_session: g = Group.get(number=101) s = Subject.get(name='Subj3') @@ -53,11 +71,9 @@ def test_2(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2', 'Subj3']) - - def test_3(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2', 'Subj3'}) + def test_4(self): with db_session: g = Group.get(number=101) s = Subject.get(name='Subj3') @@ -65,11 +81,9 @@ def test_3(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - - def test_4(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) + def test_5(self): with db_session: g = Group.get(number=101) s = Subject.get(name='Subj2') @@ -77,11 +91,9 @@ def test_4(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1']) - - def test_5(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1'}) + def test_6(self): with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -90,12 +102,10 @@ def test_5(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj3', 'Subj4']) - self.assertEqual(Group[101].subjects, set([Subject['Subj3'], Subject['Subj4']])) - - def test_6(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj3', 'Subj4'}) + self.assertEqual(Group[101].subjects, {Subject['Subj3'], Subject['Subj4']}) + def test_7(self): with db_session: g = Group.get(number=101) s = Subject.get(name='Subj3') @@ -106,11 
+116,9 @@ def test_6(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no DELETE statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - - def test_7(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) + def test_8(self): with db_session: g = Group.get(number=101) s = Subject.get(name='Subj1') @@ -121,11 +129,9 @@ def test_7(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no INSERT statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - - def test_8(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) + def test_9(self): with db_session: g = Group.get(number=101) s1 = Subject.get(name='Subj1') @@ -137,11 +143,9 @@ def test_8(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no INSERT statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - - def test_9(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) + def test_10(self): with db_session: g2 = Group.get(number=102) s1 = Subject.get(name='Subj1') @@ -154,9 +158,7 @@ def test_9(self): db_subjects = db.select('subject from Group_Subject where "group" = 102') self.assertEqual(db_subjects , []) - def test_10(self): - db, Group, Subject = self.db, self.Group, self.Subject - + def test_11(self): with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -164,11 +166,9 @@ def test_10(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj2', 'Subj3']) - - def test_11(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj2', 'Subj3'}) + def test_12(self): with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -179,11 +179,9 @@ def test_11(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no INSERT statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - - def test_12(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) + def test_13(self): with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -194,12 +192,10 @@ def test_12(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no DELETE statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) @db_session - def test_13(self): - db, Group, Subject = self.db, self.Group, self.Subject - + def test_14(self): g1 = Group[101] s1 = Subject['Subj1'] self.assertTrue(s1 in g1.subjects) @@ -208,7 +204,7 @@ def test_13(self): self.assertTrue(s1 in group_setdata) self.assertEqual(group_setdata.added, None) self.assertEqual(group_setdata.removed, None) - + subj_setdata = s1._vals_[Subject.groups] self.assertTrue(g1 in subj_setdata) self.assertEqual(subj_setdata.added, None) @@ -217,11 +213,11 @@ def test_13(self): 
g1.subjects.remove(s1) self.assertTrue(s1 not in group_setdata) self.assertEqual(group_setdata.added, None) - self.assertEqual(group_setdata.removed, set([ s1 ])) + self.assertEqual(group_setdata.removed, {s1}) self.assertTrue(g1 not in subj_setdata) self.assertEqual(subj_setdata.added, None) - self.assertEqual(subj_setdata.removed, set([ g1 ])) - + self.assertEqual(subj_setdata.removed, {g1}) + g1.subjects.add(s1) self.assertTrue(s1 in group_setdata) self.assertEqual(group_setdata.added, set()) @@ -231,9 +227,7 @@ def test_13(self): self.assertEqual(subj_setdata.removed, set()) @db_session - def test_14(self): - db, Group, Subject = self.db, self.Group, self.Subject - + def test_15(self): g = Group[101] e = g.subjects.is_empty() self.assertEqual(e, False) @@ -253,9 +247,7 @@ def test_14(self): self.assertEqual(db.last_sql, None) @db_session - def test_15(self): - db, Group = self.db, self.Group - + def test_16(self): g = Group[101] c = len(g.subjects) self.assertEqual(c, 2) @@ -263,7 +255,7 @@ def test_15(self): e = g.subjects.is_empty() # should take result from the cache self.assertEqual(e, False) self.assertEqual(db.last_sql, None) - + g = Group[102] c = len(g.subjects) self.assertEqual(c, 0) @@ -273,9 +265,7 @@ def test_15(self): self.assertEqual(db.last_sql, None) @db_session - def test_16(self): - db, Group, Subject = self.db, self.Group, self.Subject - + def test_17(self): g = Group[101] s1 = Subject['Subj1'] s3 = Subject['Subj3'] diff --git a/pony/orm/tests/test_relations_one2many.py b/pony/orm/tests/test_relations_one2many.py index 204ee6423..9b1efe800 100644 --- a/pony/orm/tests/test_relations_one2many.py +++ b/pony/orm/tests/test_relations_one2many.py @@ -4,11 +4,12 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -class TestOneToManyRequired(unittest.TestCase): +class TestOneToManyRequired(unittest.TestCase): def setUp(self): - db = Database('sqlite', ':memory:', create_db=True) + db = Database() class Student(db.Entity): id = PrimaryKey(int) @@ -23,7 +24,7 @@ class Group(db.Entity): self.Group = Group self.Student = Student - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: g101 = Group(number=101) @@ -39,6 +40,7 @@ class Group(db.Entity): def tearDown(self): rollback() db_session.__exit__() + teardown_database(self.db) @raises_exception(ValueError, 'Attribute Student[1].group is required') def test_1(self): @@ -65,7 +67,7 @@ def test_4(self): s1, s2, s3, s4 = Student.select().order_by(Student.id) g1, g2 = Group[101], Group[102] g1.students = g2.students - self.assertEqual(set(g1.students), set([s3, s4])) + self.assertEqual(set(g1.students), {s3, s4}) self.assertEqual(s1._status_, 'marked_to_delete') self.assertEqual(s2._status_, 'marked_to_delete') @@ -74,7 +76,7 @@ def test_5(self): Group, Student = self.Group, self.Student g = Group[101] g.students.add(None) - + @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. 
Got: None') def test_6(self): Group, Student = self.Group, self.Student @@ -127,7 +129,7 @@ def test_10(self): e = g.students.is_empty() # should take result from the cache self.assertEqual(e, False) self.assertEqual(db.last_sql, None) - + g = Group[102] c = g.students.count() self.assertEqual(c, 2) @@ -223,7 +225,7 @@ def tearDown(self): def test_1(self): self.Student[1].group = None - self.assertEqual(set(self.Group[101].students), set([self.Student[2]])) + self.assertEqual(set(self.Group[101].students), {self.Student[2]}) def test_2(self): Student, Group = self.Student, self.Group @@ -246,7 +248,7 @@ def test_4(self): s1, s2, s3, s4 = Student.select().order_by(Student.id) g1, g2 = Group[101], Group[102] g1.students = g2.students - self.assertEqual(set(g1.students), set([s3, s4])) + self.assertEqual(set(g1.students), {s3, s4}) self.assertEqual(s1.group, None) self.assertEqual(s2.group, None) @@ -255,7 +257,7 @@ def test_5(self): Group, Student = self.Group, self.Student g = Group[101] g.students.add(None) - + @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None') def test_6(self): Group, Student = self.Group, self.Student diff --git a/pony/orm/tests/test_relations_one2one1.py b/pony/orm/tests/test_relations_one2one1.py index 5971ac48e..046555a8b 100644 --- a/pony/orm/tests/test_relations_one2one1.py +++ b/pony/orm/tests/test_relations_one2one1.py @@ -2,20 +2,30 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Male(db.Entity): name = Required(unicode) wife = Optional('Female', column='wife') + class Female(db.Entity): name = Required(unicode) husband = Optional('Male') -db.generate_mapping(create_tables=True) class TestOneToOne(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('delete from male') @@ -115,16 +125,16 @@ def test_8(self): @db_session def test_9(self): - f4 = Female(name='F4') - m4 = Male(name='M4', wife=f4) + f4 = Female(id=4, name='F4') + m4 = Male(id=4, name='M4', wife=f4) flush() self.assertEqual(f4._status_, 'inserted') self.assertEqual(m4._status_, 'inserted') @db_session def test_10(self): - m4 = Male(name='M4') - f4 = Female(name='F4', husband=m4) + m4 = Male(id=4, name='M4') + f4 = Female(id=4, name='F4', husband=m4) flush() self.assertEqual(f4._status_, 'inserted') self.assertEqual(m4._status_, 'inserted') diff --git a/pony/orm/tests/test_relations_one2one2.py b/pony/orm/tests/test_relations_one2one2.py index 4add12bb8..7a7d9809e 100644 --- a/pony/orm/tests/test_relations_one2one2.py +++ b/pony/orm/tests/test_relations_one2one2.py @@ -4,20 +4,30 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import teardown_database, setup_database + +db = Database() -db = Database('sqlite', ':memory:') class Male(db.Entity): name = Required(unicode) wife = Optional('Female', column='wife') + class Female(db.Entity): name = Required(unicode) husband = Optional('Male', column='husband') -db.generate_mapping(create_tables=True) class TestOneToOne2(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('update female set husband=null') @@ -121,7 +131,8 @@ def test_8(self): self.assertEqual([2, None, 
None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 1, None], husbands) - @raises_exception(UnrepeatableReadError, 'Value of Male.wife for Male[1] was updated outside of current transaction') + @raises_exception(UnrepeatableReadError, 'Multiple Male objects linked with the same Female[1] object. ' + 'Maybe Female.husband attribute should be Set instead of Optional') def test_9(self): db.execute('update female set husband = 3 where id = 1') m1 = Male[1] diff --git a/pony/orm/tests/test_relations_one2one3.py b/pony/orm/tests/test_relations_one2one3.py index 269ccd1e9..09aa1d2c8 100644 --- a/pony/orm/tests/test_relations_one2one3.py +++ b/pony/orm/tests/test_relations_one2one3.py @@ -4,10 +4,13 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database, only_for + +@only_for('sqlite') class TestOneToOne3(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + self.db = Database() class Person(self.db.Entity): name = Required(unicode) @@ -17,14 +20,14 @@ class Passport(self.db.Entity): code = Required(unicode) person = Required("Person") - self.db.generate_mapping(create_tables=True) + setup_database(self.db) with db_session: p1 = Person(name='John') Passport(code='123', person=p1) def tearDown(self): - self.db = None + teardown_database(self.db) @db_session def test_1(self): @@ -38,9 +41,9 @@ def test_2(self): sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" - LEFT JOIN "Passport" "passport-1" - ON "p"."id" = "passport-1"."person" -WHERE "passport-1"."id" IS NULL''') + LEFT JOIN "Passport" "passport" + ON "p"."id" = "passport"."person" +WHERE "passport"."id" IS NULL''') @db_session def test_3(self): @@ -48,9 +51,9 @@ def test_3(self): sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" - LEFT JOIN "Passport" "passport-1" - ON "p"."id" = "passport-1"."person" -WHERE "passport-1"."id" IS NULL''') + LEFT JOIN "Passport" "passport" + ON "p"."id" = "passport"."person" +WHERE "passport"."id" IS NULL''') @db_session def test_4(self): @@ -58,9 +61,9 @@ def test_4(self): sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" - LEFT JOIN "Passport" "passport-1" - ON "p"."id" = "passport-1"."person" -WHERE "passport-1"."id" IS NOT NULL''') + LEFT JOIN "Passport" "passport" + ON "p"."id" = "passport"."person" +WHERE "passport"."id" IS NOT NULL''') @db_session def test_5(self): @@ -70,7 +73,7 @@ def test_5(self): sql = self.db.last_sql self.assertEqual(sql, '''DELETE FROM "Person" WHERE "id" = ? 
- AND "name" = ?''') + AND "name" = ?''') @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session diff --git a/pony/orm/tests/test_relations_one2one4.py b/pony/orm/tests/test_relations_one2one4.py index 0bbb8306f..f9813bea3 100644 --- a/pony/orm/tests/test_relations_one2one4.py +++ b/pony/orm/tests/test_relations_one2one4.py @@ -4,35 +4,35 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -class TestOneToOne4(unittest.TestCase): - def setUp(self): - self.db = Database('sqlite', ':memory:') - - class Person(self.db.Entity): - name = Required(unicode) - passport = Optional("Passport") +db = Database() - class Passport(self.db.Entity): - code = Required(unicode) - person = Required("Person") +class Person(db.Entity): + name = Required(unicode) + passport = Optional("Passport") - self.db.generate_mapping(create_tables=True) +class Passport(db.Entity): + code = Required(unicode) + person = Required("Person") +class TestOneToOne4(unittest.TestCase): + def setUp(self): + setup_database(db) with db_session: - p1 = Person(name='John') - Passport(code='123', person=p1) + p1 = Person(id=1, name='John') + Passport(id=1, code='123', person=p1) def tearDown(self): - self.db = None + teardown_database(db) @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session def test1(self): - p2 = self.db.Person(name='Mike') - pas2 = self.db.Passport(code='456', person=p2) + p2 = Person(id=2, name='Mike') + pas2 = Passport(id=2, code='456', person=p2) commit() - p1 = self.db.Person.get(name='John') + p1 = Person.get(name='John') pas2.person = p1 if __name__ == '__main__': diff --git a/pony/orm/tests/test_relations_symmetric_m2m.py b/pony/orm/tests/test_relations_symmetric_m2m.py index c09d40742..79f49b699 100644 --- a/pony/orm/tests/test_relations_symmetric_m2m.py +++ b/pony/orm/tests/test_relations_symmetric_m2m.py @@ -2,15 +2,25 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) friends = Set('Person', reverse='friends') -db.generate_mapping(create_tables=True) + class TestSymmetricM2M(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: for p in Person.select(): p.delete() @@ -32,12 +42,12 @@ def test1a(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) - self.assertEqual(set(p4.friends), set([p1])) + self.assertEqual(set(p4.friends), {p1}) def test1b(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) - self.assertEqual(set(p1.friends), set([Person[2], Person[3], p4])) + self.assertEqual(set(p1.friends), {Person[2], Person[3], p4}) def test1c(self): p1 = Person[1] p4 = Person[4] @@ -49,12 +59,12 @@ def test2a(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) - self.assertEqual(set(p1.friends), set([Person[3]])) + self.assertEqual(set(p1.friends), {Person[3]}) def test2b(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) - self.assertEqual(set(Person[3].friends), set([p1])) + self.assertEqual(set(Person[3].friends), {p1}) def test2c(self): p1 = Person[1] p2 = Person[2] @@ -84,7 +94,7 @@ def test3b(self): p1 = Person[1] p2 = 
Person[2] p1_friends = set(p1.friends) - self.assertEqual(p1_friends, set([p2])) + self.assertEqual(p1_friends, {p2}) try: p2_friends = set(p2.friends) except UnrepeatableReadError as e: self.assertEqual(e.args[0], "Phantom object Person[1] disappeared from collection Person[2].friends") diff --git a/pony/orm/tests/test_relations_symmetric_one2one.py b/pony/orm/tests/test_relations_symmetric_one2one.py index 232b3190a..9dc72f20f 100644 --- a/pony/orm/tests/test_relations_symmetric_one2one.py +++ b/pony/orm/tests/test_relations_symmetric_one2one.py @@ -4,16 +4,25 @@ from pony.orm.core import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) spouse = Optional('Person', reverse='spouse') -db.generate_mapping(create_tables=True) class TestSymmetricOne2One(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('update person set spouse=null') @@ -63,8 +72,9 @@ def test3(self): self.assertEqual([3, None, 1, None, None], data) def test4(self): persons = set(select(p for p in Person if p.spouse.name in ('B', 'D'))) - self.assertEqual(persons, set([Person[1], Person[3]])) - @raises_exception(UnrepeatableReadError, 'Value of Person.spouse for Person[1] was updated outside of current transaction') + self.assertEqual(persons, {Person[1], Person[3]}) + @raises_exception(UnrepeatableReadError, 'Multiple Person objects linked with the same Person[2] object. ' + 'Maybe Person.spouse attribute should be Set instead of Optional') def test5(self): db.execute('update person set spouse = 3 where id = 2') p1 = Person[1] diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py new file mode 100644 index 000000000..5a2bee399 --- /dev/null +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -0,0 +1,404 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * +from pony.py23compat import PYPY2 +from pony.orm.tests import setup_database, teardown_database + +db = Database() + + +class Group(db.Entity): + number = PrimaryKey(int) + major = Required(str) + students = Set('Student') + + +class Student(db.Entity): + first_name = Required(unicode) + last_name = Required(unicode) + age = Required(int) + group = Required('Group') + scholarship = Required(int, default=0) + courses = Set('Course') + + @property + def full_name(self): + return self.first_name + ' ' + self.last_name + + +class Course(db.Entity): + name = Required(unicode) + semester = Required(int) + credits = Required(int) + PrimaryKey(name, semester) + students = Set('Student') + + +class TestSelectFromSelect(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=123, major='Computer Science') + g2 = Group(number=456, major='Graphic Design') + s1 = Student(id=1, first_name='John', last_name='Smith', age=20, group=g1, scholarship=0) + s2 = Student(id=2, first_name='Alex', last_name='Green', age=24, group=g1, scholarship=100) + s3 = Student(id=3, first_name='Mary', last_name='White', age=23, group=g1, scholarship=500) + s4 = Student(id=4, first_name='John', last_name='Brown', age=20, group=g2, scholarship=400) + s5 = Student(id=5, first_name='Bruce', last_name='Lee', age=22, group=g2, 
scholarship=300)
+            c1 = Course(name='Math', semester=1, credits=10, students=[s1, s2, s4])
+            c2 = Course(name='Computer Science', semester=1, credits=20, students=[s2, s3])
+            c3 = Course(name='3D Modeling', semester=2, credits=15, students=[s3, s5])
+
+    @classmethod
+    def tearDownClass(cls):
+        teardown_database(db)
+
+    @db_session
+    def test_1(self):  # basic select from another query
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = select(s for s in q if s.scholarship < 500)
+        self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)  # single SELECT...FROM expression
+
+    @db_session
+    def test_2(self):  # different variable name in the second query
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = select(x for x in q if x.scholarship < 500)
+        self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_3(self):  # selecting single column instead of entity in the second query
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = select(x.first_name for x in q if x.scholarship < 500)
+        self.assertEqual(set(q2), {'Alex', 'Bruce', 'John'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_4(self):  # selecting single column instead of entity in the first query
+        q = select(s.first_name for s in Student if s.scholarship > 0)
+        q2 = select(name for name in q if 'r' in name)
+        self.assertEqual(set(q2), {'Bruce', 'Mary'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_5(self):  # selecting hybrid property in the second query
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = select(x.full_name for x in q if x.scholarship < 500)
+        self.assertEqual(set(q2), {'Alex Green', 'Bruce Lee', 'John Brown'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_6(self):  # selecting hybrid property in the first query
+        q = select(s.full_name for s in Student if s.scholarship < 500)
+        q2 = select(x for x in q if x.startswith('J'))
+        self.assertEqual(set(q2), {'John Smith', 'John Brown'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    @raises_exception(ExprEvalError, "`s.scholarship > 0` raises NameError: name 's' is not defined" if not PYPY2
+                      else "`s.scholarship > 0` raises NameError: global name 's' is not defined")
+    def test_7(self):  # test access to original query var name from the new query
+        q = select(s.first_name for s in Student if s.scholarship < 500)
+        q2 = select(x for x in q if s.scholarship > 0)
+
+    @db_session
+    def test_8(self):  # test using external name which is equal to original query var name
+        class Dummy(object):
+            scholarship = 1
+        s = Dummy()
+        q = select(s.first_name for s in Student if s.scholarship < 500)
+        q2 = select(x for x in q if s.scholarship > 0)
+        self.assertEqual(set(q2), {'John', 'Alex', 'Bruce'})
+
+    @db_session
+    def test_9(self):  # test reusing variable name from the original query
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = select(x for x in q for s in Student if x.scholarship < s.scholarship)
+        self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_10(self):  # test .filter()
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = q.filter(lambda a: a.scholarship < 500)
+        q3 = select(x for x in q2 if x.age > 20)
+        q4 = q3.filter(lambda b: b.age < 24)
+        self.assertEqual(set(s.first_name for s in q4), {'Bruce'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_11(self):  # test .where()
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = q.where(lambda s: s.scholarship < 500)
+        q3 = select(x for x in q2 if x.age > 20)
+        q4 = q3.where(lambda x: x.age < 24)  # the name should be accessible in previous generator
+        self.assertEqual(set(s.first_name for s in q4), {'Bruce'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    @raises_exception(TypeError, 'Lambda argument `s` does not correspond to any variable in original query')
+    def test_12(self):  # test .where()
+        q = select(s for s in Student if s.scholarship > 0)
+        q2 = q.where(lambda s: s.scholarship < 500)
+        q3 = select(x for x in q2 if x.age > 20)
+        q4 = q3.where(lambda s: s.age < 24)
+
+    @db_session
+    def test_13(self):  # select several expressions from the first query
+        q = select((s.full_name, s.age) for s in Student if s.scholarship > 0)
+        q2 = select(name for name, age in q if age < 24 and 'e' in name)
+        self.assertEqual(set(q2), {'Mary White', 'Bruce Lee'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_14(self):  # select from entity with composite key
+        q = select(c for c in Course if c.semester == 1)
+        q2 = select(x.name for x in q if x.name.startswith('M'))
+        self.assertEqual(set(q2), {'Math'})
+        self.assertEqual(db.last_sql.count('SELECT'), 1)
+
+    @db_session
+    def test_15(self):  # SELECT ... FROM (SELECT alias.* FROM ...
+        q = left_join(s for g in Group for s in g.students if g.number == 123 and s.scholarship > 0)
+        q2 = select(x.full_name for x in q if x.scholarship > 100)
+        self.assertEqual(set(q2), {'Mary White'})
+        self.assertEqual(db.last_sql.count('SELECT'), 2)
+        self.assertEqual(db.last_sql.count('LEFT JOIN'), 1)
+        self.assertTrue('*' in db.last_sql)
+
+    @db_session
+    def test_16(self):  # SELECT ... FROM (grouped-query)
+        q = select(g for g in Group if count(g.students) > 2)
+        q2 = select(x.number for x in q)
+
+        self.assertEqual(set(q2), {123})
+        self.assertEqual(db.last_sql.count('SELECT'), 2)
+        self.assertEqual(db.last_sql.count('LEFT JOIN'), 1)
+        self.assertEqual(db.last_sql.count('GROUP BY'), 1)
+        self.assertEqual(db.last_sql.count('HAVING'), 1)
+        self.assertTrue('WHERE' not in db.last_sql)
+
+    @db_session
+    def test_17(self):  # SELECT ... FROM (grouped-query), t1 WHERE ...
+        q = select(g for g in Group if count(g.students) > 2)
+        q2 = select(x.major for x in q)
+
+        self.assertEqual(set(q2), {'Computer Science'})
+        self.assertEqual(db.last_sql.count('SELECT'), 2)
+        self.assertEqual(db.last_sql.count('LEFT JOIN'), 1)
+        self.assertEqual(db.last_sql.count('GROUP BY'), 1)
+        self.assertEqual(db.last_sql.count('HAVING'), 1)
+
+    @db_session
+    def test_18(self):  # SELECT ... FROM (grouped-query returns composite keys), t1 WHERE ...
+        q = select((c, count(c.students)) for c in Course if c.semester == 1 and count(c.students) > 1)
+        q2 = select((x.name, x.credits, y) for x, y in q if x.credits > 10 and y < 3)
+
+        self.assertEqual(set(q2), {('Computer Science', 20, 2)})
+        self.assertEqual(db.last_sql.count('SELECT'), 2)
+        self.assertEqual(db.last_sql.count('LEFT JOIN'), 1)
+        self.assertEqual(db.last_sql.count('GROUP BY'), 1)
+        self.assertEqual(db.last_sql.count('HAVING'), 1)
+        self.assertEqual(db.last_sql.count('WHERE'), 2)
+
+    @db_session
+    def test_19(self):  # multiple for loops in the inner query
+        q = select((g, s.first_name.lower()) for g in Group for s in g.students)
+        q2 = select((g.major, n) for g, n in q if g.number == 123 and n[0] == 'j')
+        self.assertEqual(set(q2), {('Computer Science', 'john')})
+
+    @db_session
+    def test_20(self):  # additional for loop with inlined subquery
+        q = select((g, x.first_name.upper())
+                   for g in Group
+                   for x in select(s for s in Student if s.age < 22)
+                   if x.group == g and g.number == 123 and x.first_name[0] == 'J')
+        q2 = select(name for g, name in q if g.number == 123)
+        self.assertEqual(set(q2), {'JOHN'})
+
+    @db_session
+    def test_21(self):
+        objects = select(s for s in Student if s.scholarship > 200)[:]  # not query, but query result
+        q = select(s.first_name for s in Student if s not in objects)
+        self.assertEqual(set(q), {'John', 'Alex'})
+
+    @db_session
+    @raises_exception(TypeError, 'Query can only iterate over entity or another query (not a list of objects)')
+    def test_22(self):
+        objects = select(s for s in Student if s.scholarship > 200)[:]  # not query, but query result
+        q = select(s.first_name for s in objects)
+
+    @db_session
+    def test_23(self):
+        q = select(s for s in Student)
+        q2 = q.filter(lambda x: x.scholarship > 450)
+        q3 = q2.where(lambda s: s.scholarship < 520)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_24(self):
+        q = select(s for s in Student)
+        q2 = q.where(lambda s: s.scholarship > 450)
+        q3 = q2.filter(lambda x: x.scholarship < 520)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_25(self):
+        q = Student.select().filter(lambda x: x.scholarship > 450)
+        q2 = select(s for s in q)
+        q3 = q2.where(lambda s: s.scholarship < 520)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_26(self):
+        q = Student.select().filter(lambda x: x.scholarship > 450)
+        q2 = q.where(lambda s: s.scholarship < 520)
+        q3 = select(s for s in q2)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_27(self):
+        q = Student.select().where(lambda s: s.scholarship > 450)
+        q2 = select(s for s in q)
+        q3 = q2.filter(lambda x: x.scholarship < 520)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_28(self):
+        q = Student.select().where(lambda s: s.scholarship > 450)
+        q2 = q.filter(lambda x: x.scholarship < 520)
+        q3 = select(s for s in q2)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_29(self):
+        q = select(s for s in Student)
+        q2 = q.where(lambda s: s.scholarship > 450)
+        q3 = q2.where(lambda s: s.scholarship < 520)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_30(self):
+        q = select(s for s in Student)
+        q2 = q.filter(lambda x: x.scholarship > 450)
+        q3 = q2.filter(lambda z: z.scholarship < 520)
+        self.assertEqual(set(q3), {Student[3]})
+
+    @db_session
+    def test_31(self):
+        q = select(s for s in Student).order_by(lambda s: s.scholarship)
+        q2 = q.where(lambda s: s.scholarship > 450)
+        self.assertEqual(set(q2), {Student[3]})
+
+    @db_session
+ def test_32(self): + q = select(s for s in Student).order_by(lambda s: s.scholarship) + q2 = q.filter(lambda z: z.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + + @db_session + def test_33(self): + q = select(s for s in Student).sort_by(lambda x: x.scholarship) + q2 = q.where(lambda s: s.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + + @db_session + def test_34(self): + q = select(s for s in Student).sort_by(lambda x: x.scholarship) + q2 = q.filter(lambda s: s.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + + @db_session + def test_35(self): + q = select(s for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s not in q) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + + @db_session + def test_36(self): + q = select(s for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s not in q[:]) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_37(self): + q = select(s.last_name for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s.last_name not in q) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + + @db_session + def test_38(self): + q = select(s.last_name for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s.last_name not in q[:]) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_39(self): + q = select((s.first_name, s.last_name) for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if (s.first_name, s.last_name) not in q) + self.assertEqual(set(q2), {1}) + self.assertTrue(db.last_sql.count('SELECT') > 1) + + # @db_session + # def test_40(self): # TODO + # q = select((s.first_name, s.last_name) for s in Student if s.scholarship > 0) + # q2 = select(s.id for s in Student if (s.first_name, s.last_name) not in q[:]) + # self.assertEqual(set(q2), {1}) + # self.assertTrue(db.last_sql.count('SELECT'), 1) + + @db_session + def test_41(self): + def f1(): + x = 21 + return select(s for s in Student if s.age > x) + + def f2(q): + x = 23 + return select(s.last_name for s in Student if s.age < x and s in q) + + q = f1() + q2 = f2(q) + self.assertEqual(set(q2), {'Lee'}) + + @db_session + def test_42(self): + q = select(s for s in Student if s.scholarship > 0) + q2 = select(g for g in Group if g.major == 'Computer Science')[:] + q3 = select(s.first_name for s in q if s.group in q2) + self.assertEqual(set(q3), {'Alex', 'Mary'}) + + @db_session + def test_43(self): + q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) + q2 = select(s.first_name for s in Student if s in q) + self.assertEqual(set(q2), {'John', 'Bruce'}) + + @db_session + def test_44(self): + q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) + q2 = select(s.first_name for s in q) + self.assertEqual(list(q2), ['Bruce', 'John', 'John']) + + @db_session + def test_45(self): + q = select(s for s in Student).order_by(Student.first_name, Student.id).limit(3, offset=1) + q2 = select(s for s in q if s.age > 18).limit(2, offset=1) + q3 = select(s.last_name for s in q2).limit(2, offset=1) + self.assertEqual(set(q3), {'Brown'}) + + @db_session + def test_46(self): + q = select((c, count(c.students)) for c in Course).order_by(-2, 1).limit(2) + q2 = select((c.name, c.credits, m) for c, m in q).limit(1, offset=1) + self.assertEqual(set(q2), {('3D 
Modeling', 15, 2)}) + + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_show.py b/pony/orm/tests/test_show.py index dc12832ac..e5846b800 100644 --- a/pony/orm/tests/test_show.py +++ b/pony/orm/tests/test_show.py @@ -6,8 +6,10 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) @@ -26,21 +28,29 @@ class Course(db.Entity): name = Required(unicode, unique=True) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - g2 = Group(number=2) - c1 = Course(name='Math') - c2 = Course(name='Physics') - c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) normal_stdout = sys.stdout + class TestShow(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + c1 = Course(name='Math') + c2 = Course(name='Physics') + c3 = Course(name='Computer Science') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -70,5 +80,6 @@ def test_2(self): 2~~~~~ ''') + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_sqlbuilding_formatstyles.py b/pony/orm/tests/test_sqlbuilding_formatstyles.py index 2a975959f..7f2a482dd 100644 --- a/pony/orm/tests/test_sqlbuilding_formatstyles.py +++ b/pony/orm/tests/test_sqlbuilding_formatstyles.py @@ -6,10 +6,11 @@ from pony.orm.dbapiprovider import DBAPIProvider from pony.orm.tests.testutils import TestPool + class TestFormatStyles(unittest.TestCase): def setUp(self): - self.key1 = object() - self.key2 = object() + self.key1 = 'KEY1' + self.key2 = 'KEY2' self.provider = DBAPIProvider(pony_pool_mockup=TestPool(None)) self.ast = [ SELECT, [ ALL, [COLUMN, None, 'A']], [ FROM, [None, TABLE, 'T1']], [ WHERE, [ EQ, [COLUMN, None, 'B'], [ PARAM, self.key1 ] ], @@ -24,35 +25,35 @@ def test_qmark(self): self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = ?\n AND "C" = ?\n AND "D" = ?\n AND "E" = ?') - self.assertEqual(b.layout, (self.key1, self.key2, self.key2, self.key1)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_numeric(self): self.provider.paramstyle = 'numeric' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = :1\n AND "C" = :2\n AND "D" = :2\n AND "E" = :1') - self.assertEqual(b.layout, (self.key1, self.key2)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_named(self): self.provider.paramstyle = 'named' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = :p1\n AND "C" = :p2\n AND "D" = :p2\n AND "E" = :p1') - self.assertEqual(b.layout, (self.key1, self.key2)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, 
self.key1]) def test_format(self): self.provider.paramstyle = 'format' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = %s\n AND "C" = %s\n AND "D" = %s\n AND "E" = %s') - self.assertEqual(b.layout, (self.key1, self.key2, self.key2, self.key1)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_pyformat(self): self.provider.paramstyle = 'pyformat' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = %(p1)s\n AND "C" = %(p2)s\n AND "D" = %(p2)s\n AND "E" = %(p1)s') - self.assertEqual(b.layout, (self.key1, self.key2)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) if __name__ == "__main__": diff --git a/pony/orm/tests/test_sqlbuilding_sqlast.py b/pony/orm/tests/test_sqlbuilding_sqlast.py index 3e4cae9c9..9d4493863 100644 --- a/pony/orm/tests/test_sqlbuilding_sqlast.py +++ b/pony/orm/tests/test_sqlbuilding_sqlast.py @@ -3,10 +3,14 @@ import unittest from pony.orm.core import Database, db_session from pony.orm.sqlsymbols import * +from pony.orm.tests import setup_database, only_for + +@only_for('sqlite') class TestSQLAST(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + self.db = Database() + setup_database(self.db) with db_session: conn = self.db.get_connection() conn.executescript(""" @@ -16,6 +20,13 @@ def setUp(self): ); insert or ignore into T1 values(1, 'abc'); """) + + def tearDown(self): + with db_session: + conn = self.db.get_connection() + conn.executescript("""drop table T1 + """) + @db_session def test_alias(self): sql_ast = [SELECT, [ALL, [COLUMN, "Group", "a"]], @@ -29,5 +40,6 @@ def test_alias2(self): sql, adapter = self.db._ast2sql(sql_ast) cursor = self.db._exec_sql(sql) + if __name__ == "__main__": unittest.main() diff --git a/pony/orm/tests/test_sqlite_str_functions.py b/pony/orm/tests/test_sqlite_str_functions.py index 2592d57d3..656004472 100644 --- a/pony/orm/tests/test_sqlite_str_functions.py +++ b/pony/orm/tests/test_sqlite_str_functions.py @@ -7,14 +7,17 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import only_for db = Database('sqlite', ':memory:') + class Person(db.Entity): name = Required(unicode) age = Optional(int) image = Optional(buffer) + db.generate_mapping(create_tables=True) with db_session: @@ -22,6 +25,7 @@ class Person(db.Entity): p2 = Person(name=u'Иван') # u'\u0418\u0432\u0430\u043d' +@only_for('sqlite') class TestUnicode(unittest.TestCase): @db_session def test1(self): @@ -58,5 +62,6 @@ def test7(self): ages = db.select('select py_lower(image) from person') self.assertEqual(ages, [u'abcdef', None]) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_time_parsing.py b/pony/orm/tests/test_time_parsing.py index a9307f828..064b43191 100644 --- a/pony/orm/tests/test_time_parsing.py +++ b/pony/orm/tests/test_time_parsing.py @@ -6,6 +6,7 @@ from pony.orm.tests.testutils import raises_exception from pony.converting import str2time + class TestTimeParsing(unittest.TestCase): def test_time_1(self): self.assertEqual(str2time('1:2'), time(1, 2)) diff --git a/pony/orm/tests/test_to_dict.py b/pony/orm/tests/test_to_dict.py index 8e2da08fb..29a11b225 100644 --- a/pony/orm/tests/test_to_dict.py +++ b/pony/orm/tests/test_to_dict.py @@ -5,40 +5,48 @@ from pony.orm import * from pony.orm.serialization import to_dict from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, 
teardown_database -db = Database('sqlite', ':memory:') - -class Student(db.Entity): - name = Required(unicode) - scholarship = Optional(int) - gpa = Optional(Decimal, 3, 1) - dob = Optional(date) - group = Optional('Group') - courses = Set('Course') - biography = Optional(LongUnicode) - -class Group(db.Entity): - number = PrimaryKey(int) - students = Set(Student) - -class Course(db.Entity): - name = Required(unicode, unique=True) - students = Set(Student) - -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - g2 = Group(number=2) - c1 = Course(name='Math') - c2 = Course(name='Physics') - c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) - Student(id=4, name='S4') class TestObjectToDict(unittest.TestCase): + @classmethod + def setUpClass(cls): + db = cls.db = Database() + + class Student(db.Entity): + name = Required(unicode) + scholarship = Optional(int) + gpa = Optional(Decimal, 3, 1) + dob = Optional(date) + group = Optional('Group') + courses = Set('Course') + biography = Optional(LongUnicode) + + class Group(db.Entity): + number = PrimaryKey(int) + students = Set(Student) + + class Course(db.Entity): + name = Required(unicode, unique=True) + students = Set(Student) + + setup_database(db) + + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + c1 = Course(id=1, name='Math') + c2 = Course(id=2, name='Physics') + c3 = Course(id=3, name='Computer Science') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4') + + @classmethod + def tearDownClass(cls): + teardown_database(cls.db) + def setUp(self): rollback() db_session.__enter__() @@ -48,133 +56,172 @@ def tearDown(self): db_session.__exit__() def test1(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict() self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1)) def test2(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(related_objects=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, - group=Group[1])) + group=self.db.Group[1])) def test3(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(with_collections=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1, courses=[1, 2])) def test4(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(with_collections=True, related_objects=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, - group=Group[1], courses=[Course[1], Course[2]])) + group=self.db.Group[1], courses=[self.db.Course[1], self.db.Course[2]])) def test5(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1, biography='some text')) def test6(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only=['id', 'name', 'group']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test7(self): - s1 = Student[1] + s1 
= self.db.Student[1] d = s1.to_dict(['id', 'name', 'group']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test8(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id, name, group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test9(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id name group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test10(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict('id name group') self.assertEqual(d, dict(id=1, name='S1', group=1)) @raises_exception(AttributeError, 'Entity Student does not have attriute x') def test11(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict('id name x group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test12(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict('id name group', related_objects=True) - self.assertEqual(d, dict(id=1, name='S1', group=Group[1])) + self.assertEqual(d, dict(id=1, name='S1', group=self.db.Group[1])) def test13(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude=['dob', 'gpa', 'scholarship']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test14(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob, gpa, scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test15(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) @raises_exception(AttributeError, 'Entity Student does not have attriute x') def test16(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa x scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test17(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship', related_objects=True) - self.assertEqual(d, dict(id=1, name='S1', group=Group[1])) + self.assertEqual(d, dict(id=1, name='S1', group=self.db.Group[1])) def test18(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship', with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', group=1, biography='some text')) def test19(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship biography', with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test20(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship', with_collections=True) self.assertEqual(d, dict(id=1, name='S1', group=1, courses=[1, 2])) def test21(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship courses', with_collections=True) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test22(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id name group', exclude='dob group') self.assertEqual(d, dict(id=1, name='S1')) def test23(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id name group', exclude='dob group', with_collections=True, with_lazy=True) self.assertEqual(d, dict(id=1, name='S1')) def test24(self): - c = Course(name='New Course') + c = self.db.Course(id=4, name='New Course') d = c.to_dict() # should do flush and get c.id from the database self.assertEqual(d, dict(id=4, name='New Course')) + class TestSerializationToDict(unittest.TestCase): + @classmethod + def setUpClass(cls): + db = cls.db = Database() + + class Student(db.Entity): + name = Required(unicode) + 
scholarship = Optional(int) + gpa = Optional(Decimal, 3, 1) + dob = Optional(date) + group = Optional('Group') + courses = Set('Course') + biography = Optional(LongUnicode) + + class Group(db.Entity): + number = PrimaryKey(int) + students = Set(Student) + + class Course(db.Entity): + name = Required(unicode, unique=True) + students = Set(Student) + + setup_database(db) + + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + c1 = Course(name='Math') + c2 = Course(name='Physics') + c3 = Course(name='Computer Science') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4') + + @classmethod + def tearDownClass(cls): + teardown_database(cls.db) + def setUp(self): rollback() db_session.__enter__() @@ -184,7 +231,7 @@ def tearDown(self): db_session.__exit__() def test1(self): - s4 = Student[4] + s4 = self.db.Student[4] self.assertEqual(s4.group, None) d = to_dict(s4) self.assertEqual(d, dict(Student={ diff --git a/pony/orm/tests/test_tracked_value.py b/pony/orm/tests/test_tracked_value.py new file mode 100644 index 000000000..f24ce5a7c --- /dev/null +++ b/pony/orm/tests/test_tracked_value.py @@ -0,0 +1,56 @@ +import unittest + +from pony.orm.ormtypes import TrackedList, TrackedDict, TrackedValue + +class Object(object): + def __init__(self): + self.on_attr_changed = None + def _attr_changed_(self, attr): + if self.on_attr_changed is not None: + self.on_attr_changed(attr) + + +class Attr(object): + pass + + +class TestTrackedValue(unittest.TestCase): + + def test_make(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + self.assertEqual(type(tracked_value), TrackedDict) + self.assertEqual(type(tracked_value['items']), TrackedList) + + def test_dict_setitem(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + log = [] + obj.on_attr_changed = lambda x: log.append(x) + tracked_value['items'] = [1, 2, 3] + self.assertEqual(log, [attr]) + + def test_list_append(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + log = [] + obj.on_attr_changed = lambda x: log.append(x) + tracked_value['items'].append('four') + self.assertEqual(log, [attr]) + + def test_list_setslice(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + log = [] + obj.on_attr_changed = lambda x: log.append(x) + tracked_value['items'][1:2] = ['a', 'b', 'c'] + self.assertEqual(log, [attr]) + self.assertEqual(tracked_value['items'], ['one', 'a', 'b', 'c', 'three']) diff --git a/pony/orm/tests/test_transaction_lock.py b/pony/orm/tests/test_transaction_lock.py index 002557664..07a5ca440 100644 --- a/pony/orm/tests/test_transaction_lock.py +++ b/pony/orm/tests/test_transaction_lock.py @@ -1,27 +1,31 @@ - - import unittest from pony.orm import * +from pony.orm.tests import setup_database, teardown_database db = Database() + class TestPost(db.Entity): category = Optional('TestCategory') name = Optional(str, default='Noname') + class TestCategory(db.Entity): posts = Set(TestPost) -db.bind('sqlite', ':memory:') 
-db.generate_mapping(create_tables=True)
-
-with db_session:
-    post = TestPost()
-
 class TransactionLockTestCase(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        setup_database(db)
+        with db_session:
+            cls.post = TestPost(id=1)
+
+    @classmethod
+    def tearDownClass(cls):
+        teardown_database(db)
 
     __call__ = db_session(unittest.TestCase.__call__)
 
@@ -29,14 +33,14 @@ def tearDown(self):
         rollback()
 
     def test_create(self):
-        p = TestPost()
+        p = TestPost(id=2)
         p.flush()
         cache = db._get_cache()
         self.assertEqual(cache.immediate, True)
         self.assertEqual(cache.in_transaction, True)
 
     def test_update(self):
-        p = TestPost[post.id]
+        p = TestPost[self.post.id]
         p.name = 'Trash'
         p.flush()
         cache = db._get_cache()
@@ -44,7 +48,7 @@
         self.assertEqual(cache.in_transaction, True)
 
     def test_delete(self):
-        p = TestPost[post.id]
+        p = TestPost[self.post.id]
         p.delete()
         flush()
         cache = db._get_cache()
diff --git a/pony/orm/tests/test_validate.py b/pony/orm/tests/test_validate.py
new file mode 100644
index 000000000..813f88a31
--- /dev/null
+++ b/pony/orm/tests/test_validate.py
@@ -0,0 +1,83 @@
+import unittest, warnings
+
+from pony.orm import *
+from pony.orm import core
+from pony.orm.tests.testutils import raises_exception
+from pony.orm.tests import db_params, teardown_database
+
+db = Database()
+
+class Person(db.Entity):
+    id = PrimaryKey(int)
+    name = Required(str)
+    tel = Optional(str)
+
+
+table_name = 'person'
+
+class TestValidate(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        db.bind(**db_params)
+        db.generate_mapping(check_tables=False)
+        db.drop_all_tables(with_all_data=True)
+        with db_session(ddl=True):
+            db.execute("""
+                create table "%s"(
+                    id int primary key,
+                    name text,
+                    tel text
+                )
+            """ % table_name)
+
+    @classmethod
+    def tearDownClass(cls):
+        teardown_database(db)
+
+    @db_session
+    def setUp(self):
+        db.execute('delete from "%s"' % table_name)
+        registry = getattr(core, '__warningregistry__', {})
+        for key in list(registry):
+            if type(key) is not tuple: continue
+            text, category, lineno = key
+            if category is DatabaseContainsIncorrectEmptyValue:
+                del registry[key]
+
+    @db_session
+    def test_1a(self):
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue)
+            db.insert(table_name, id=1, name='', tel='111')
+            p = Person.get(id=1)
+            self.assertEqual(p.name, '')
+
+    @raises_exception(DatabaseContainsIncorrectEmptyValue,
+                      'Database contains empty string for required attribute Person.name')
+    @db_session
+    def test_1b(self):
+        with warnings.catch_warnings():
+            warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue)
+            db.insert(table_name, id=1, name='', tel='111')
+            p = Person.get(id=1)
+
+    @db_session
+    def test_2a(self):
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue)
+            db.insert(table_name, id=1, name=None, tel='111')
+            p = Person.get(id=1)
+            self.assertEqual(p.name, None)
+
+    @raises_exception(DatabaseContainsIncorrectEmptyValue,
+                      'Database contains NULL for required attribute Person.name')
+    @db_session
+    def test_2b(self):
+        with warnings.catch_warnings():
+            warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue)
+            db.insert(table_name, id=1, name=None, tel='111')
+            p = Person.get(id=1)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/pony/orm/tests/test_virtuals.py b/pony/orm/tests/test_virtuals.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pony/orm/tests/test_volatile.py
b/pony/orm/tests/test_volatile.py
new file mode 100644
index 000000000..e534ee01f
--- /dev/null
+++ b/pony/orm/tests/test_volatile.py
@@ -0,0 +1,52 @@
+import sys, unittest
+
+from pony.orm import *
+from pony.orm.tests.testutils import *
+from pony.orm.tests import setup_database, teardown_database
+
+
+class TestVolatile(unittest.TestCase):
+    def setUp(self):
+        db = self.db = Database()
+
+        class Item(self.db.Entity):
+            name = Required(str)
+            index = Required(int, volatile=True)
+
+        setup_database(db)
+
+        with db_session:
+            Item(id=1, name='A', index=1)
+            Item(id=2, name='B', index=2)
+            Item(id=3, name='C', index=3)
+
+    def tearDown(self):
+        teardown_database(self.db)
+
+    @db_session
+    def test_1(self):
+        db = self.db
+        Item = db.Item
+        db.execute('update "item" set "index" = "index" + 1')
+        items = Item.select(lambda item: item.index > 0).order_by(Item.id)[:]
+        a, b, c = items
+        self.assertEqual(a.index, 2)
+        self.assertEqual(b.index, 3)
+        self.assertEqual(c.index, 4)
+        c.index = 1
+        items = Item.select()[:] # force re-read from the database
+        self.assertEqual(c.index, 1)
+        self.assertEqual(a.index, 2)
+        self.assertEqual(b.index, 3)
+
+
+    @db_session
+    def test_2(self):
+        Item = self.db.Item
+        item = Item[1]
+        item.name = 'X'
+        item.flush()
+        self.assertEqual(item.index, 1)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/pony/orm/tests/testutils.py b/pony/orm/tests/testutils.py
index 87f9a91de..5b28c644d 100644
--- a/pony/orm/tests/testutils.py
+++ b/pony/orm/tests/testutils.py
@@ -1,23 +1,51 @@
 from __future__ import absolute_import, print_function
 from pony.py23compat import basestring
 
+import re
+from contextlib import contextmanager
+
 from pony.orm.core import Database
 from pony.utils import import_module
 
-def raises_exception(exc_class, msg=None):
+def test_exception_msg(test_case, exc_msg, test_msg=None):
+    if test_msg is None: return
+    error_template = "incorrect exception message. expected '%s', got '%s'"
+    error_msg = error_template % (test_msg, exc_msg)
+    assert test_msg not in ('...', '....', '.....', '......')
+    if '...' not in test_msg:
+        test_case.assertEqual(test_msg, exc_msg, error_msg)
+    else:
+        pattern = ''.join(
+            '[%s]' % char for char in test_msg.replace('\\', '\\\\')
+                                              .replace('[', '\\[')
+        ).replace('[.][.][.]', '.*')
+        regex = re.compile(pattern)
+        if not regex.match(exc_msg):
+            test_case.fail(error_template % (test_msg, exc_msg))
+
+def raises_exception(exc_class, test_msg=None):
     def decorator(func):
-        def wrapper(self, *args, **kwargs):
+        def wrapper(test_case, *args, **kwargs):
             try:
-                func(self, *args, **kwargs)
-                self.fail("expected exception %s wasn't raised" % exc_class.__name__)
+                func(test_case, *args, **kwargs)
+                test_case.fail("Expected exception %s wasn't raised" % exc_class.__name__)
             except exc_class as e:
-                if not e.args: self.assertEqual(msg, None)
-                elif msg is not None:
-                    self.assertEqual(e.args[0], msg, "incorrect exception message.
expected '%s', got '%s'" % (msg, e.args[0])) + if not e.args: test_case.assertEqual(test_msg, None) + else: test_exception_msg(test_case, str(e), test_msg) wrapper.__name__ = func.__name__ return wrapper return decorator +@contextmanager +def raises_if(test_case, cond, exc_class, test_msg=None): + try: + yield + except exc_class as e: + test_case.assertTrue(cond) + test_exception_msg(test_case, str(e), test_msg) + else: + test_case.assertFalse(cond, "Expected exception %s wasn't raised" % exc_class.__name__) + def flatten(x): result = [] for el in x: @@ -58,7 +86,7 @@ class TestPool(object): def __init__(pool, database): pool.database = database def connect(pool): - return TestConnection(pool.database) + return TestConnection(pool.database), True def release(pool, con): pass def drop(pool, con): @@ -70,7 +98,9 @@ class TestDatabase(Database): real_provider_name = None raw_server_version = None sql = None - def bind(self, provider_name, *args, **kwargs): + def bind(self, provider, *args, **kwargs): + provider_name = provider + assert isinstance(provider_name, basestring) if self.real_provider_name is not None: provider_name = self.real_provider_name self.provider_name = provider_name @@ -92,6 +122,7 @@ def bind(self, provider_name, *args, **kwargs): server_version = int('%d%02d%02d' % server_version) class TestProvider(provider_cls): + json1_available = False # for SQLite def inspect_connection(provider, connection): pass TestProvider.server_version = server_version diff --git a/pony/py23compat.py b/pony/py23compat.py index 51e059da0..7fe218ee9 100644 --- a/pony/py23compat.py +++ b/pony/py23compat.py @@ -1,6 +1,9 @@ -import sys +import sys, platform PY2 = sys.version_info[0] == 2 +PYPY = platform.python_implementation() == 'PyPy' +PYPY2 = PYPY and PY2 +PY37 = sys.version_info[:2] >= (3, 7) if PY2: from future_builtins import zip as izip, map as imap diff --git a/pony/thirdparty/compiler/ast.py b/pony/thirdparty/compiler/ast.py index 7268a2573..e5596d4b6 100644 --- a/pony/thirdparty/compiler/ast.py +++ b/pony/thirdparty/compiler/ast.py @@ -530,6 +530,20 @@ def getChildNodes(self): def __repr__(self): return "For(%s, %s, %s, %s)" % (repr(self.assign), repr(self.list), repr(self.body), repr(self.else_)) +class FormattedValue(Node): + def __init__(self, value, fmt_spec): + self.value = value + self.fmt_spec = fmt_spec + + def getChildren(self): + return self.value, self.fmt_spec + + def getChildNodes(self): + return self.value, self.fmt_spec + + def __repr__(self): + return "FormattedValue(%s, %s)" % (self.value, self.fmt_spec) + class From(Node): def __init__(self, modname, names, level, lineno=None): self.modname = modname @@ -1231,6 +1245,33 @@ def getChildNodes(self): def __repr__(self): return "Stmt(%s)" % (repr(self.nodes),) +class Str(Node): + def __init__(self, value, flags): + self.value = value + self.flags = flags + + def getChildren(self): + return self.value, self.flags + + def getChildNodes(self): + return self.value, + + def __repr__(self): + return "Str(%s, %d)" % (self.value, self.flags) + +class JoinedStr(Node): + def __init__(self, values): + self.values = values + + def getChildren(self): + return self.values + + def getChildNodes(self): + return self.values + + def __repr__(self): + return "JoinedStr(%s)" % (', '.join(repr(value) for value in self.values)) + class Sub(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] diff --git a/pony/thirdparty/compiler/pycodegen.py b/pony/thirdparty/compiler/pycodegen.py index 0181b4fa8..fd24132c6 100644 --- 
a/pony/thirdparty/compiler/pycodegen.py +++ b/pony/thirdparty/compiler/pycodegen.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, print_function from pony.py23compat import izip -import imp import os import marshal import struct @@ -123,7 +122,12 @@ def dump(self, f): f.write(self.getPycHeader()) marshal.dump(self.code, f) - MAGIC = imp.get_magic() + if VERSION < 3: + import imp + MAGIC = imp.get_magic() + else: + import importlib.util + MAGIC = importlib.util.MAGIC_NUMBER def getPycHeader(self): # compile.c uses marshal to write a long directly, with diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index 8e08361a1..18f5c0cfe 100644 --- a/pony/thirdparty/compiler/transformer.py +++ b/pony/thirdparty/compiler/transformer.py @@ -31,6 +31,7 @@ from .ast import * import parser import symbol +import sys import token # Python 2.6 compatibility fix @@ -40,6 +41,7 @@ if not hasattr(symbol, 'comp_if'): symbol.comp_if = symbol.gen_if atom_expr = getattr(symbol, 'atom_expr', None) +namedexpr_test = getattr(symbol, 'namedexpr_test', None) class WalkerError(Exception): pass @@ -120,6 +122,9 @@ def atom_expr(self, nodelist): node = self.com_apply_trailer(node, elt) return node + def namedexpr_test(self, nodelist): + return self.test(nodelist[0][1:]) + def __init__(self): self._dispatch = {} for value, name in symbol.sym_name.items(): @@ -136,12 +141,23 @@ def __init__(self): if PY2: self._atom_dispatch.update({ token.BACKQUOTE: self.atom_backquote, }) + if not PY2: self._atom_dispatch.update({ + token.ELLIPSIS: self.atom_ellipsis + }) self.encoding = None + def print_tree(self, tree, indent=''): + for item in tree: + if isinstance(item, tuple): + self.print_tree(item, indent+' ') + else: + print(indent, symbol.sym_name.get(item, item)) + def transform(self, tree): """Transform an AST into a modified parse tree.""" if not (isinstance(tree, tuple) or isinstance(tree, list)): tree = parser.st2tuple(tree, line_info=1) + # self.print_tree(tree) return self.compile_node(tree) def parsesuite(self, text): @@ -610,7 +626,10 @@ def star_expr(self, *args): def testlist_comp(self, nodelist): # test ( comp_for | (',' test)* [','] ) - assert nodelist[0][0] == symbol.test + PY38 = sys.version_info >= (3, 8) + code = nodelist[0][0] + if code not in (symbol.test, namedexpr_test): + assert False, symbol.sym_name.get(code, code) if len(nodelist) == 2 and nodelist[1][0] == symbol.comp_for: test = self.com_node(nodelist[0]) return self.com_generator_expression(test, nodelist[1]) @@ -786,6 +805,9 @@ def atom_lbrace(self, nodelist): def atom_backquote(self, nodelist): return Backquote(self.com_node(nodelist[1])) + def atom_ellipsis(self, nodelist): + return Ellipsis() + def atom_number(self, nodelist): ### need to verify this matches compile.c k = eval(nodelist[0][1]) @@ -1147,7 +1169,7 @@ def com_list_constructor(self, nodelist): # listmaker: test ( list_for | (',' test)* [','] ) values = [] for i in range(1, len(nodelist)): - if nodelist[i][0] == symbol.list_for: + if PY2 and nodelist[i][0] == symbol.list_for: assert len(nodelist[i:]) == 1 return self.com_list_comprehension(values[0], nodelist[i]) @@ -1219,9 +1241,17 @@ def com_generator_expression(self, expr, node): # comp_for: 'for' exprlist 'in' test [comp_iter] # comp_if: 'if' test [comp_iter] - lineno = node[1][2] + PY37 = sys.version_info >= (3, 7) + fors = [] while node: + if PY37 and node[0] == symbol.comp_for: + node = node[1] + assert node[0] == symbol.sync_comp_for + + lineno = node[1][2] + assert lineno is 
None or isinstance(lineno, int) + t = node[1][1] if t == 'for': assignNode = self.com_assign(node[2], OP_ASSIGN) @@ -1244,7 +1274,7 @@ def com_generator_expression(self, expr, node): else: raise SyntaxError("unexpected generator expression element: %s %d" % (node, lineno)) fors[0].is_outmost = True - return GenExpr(GenExprInner(expr, fors), lineno=lineno) + return GenExpr(GenExprInner(expr, fors), lineno=expr.lineno) def com_dictorsetmaker(self, nodelist): # dictorsetmaker: ( (test ':' test (comp_for | (',' test ':' test)* [','])) | diff --git a/pony/utils/__init__.py b/pony/utils/__init__.py new file mode 100644 index 000000000..83f8d4248 --- /dev/null +++ b/pony/utils/__init__.py @@ -0,0 +1,4 @@ + + +from .utils import * +from .properties import * \ No newline at end of file diff --git a/pony/utils/properties.py b/pony/utils/properties.py new file mode 100644 index 000000000..eedccd7d5 --- /dev/null +++ b/pony/utils/properties.py @@ -0,0 +1,40 @@ + + +class cached_property(object): + """ + A property that is only computed once per instance and then replaces itself + with an ordinary attribute. Deleting the attribute resets the property. + Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 + """ # noqa + + def __init__(self, func): + self.__doc__ = getattr(func, '__doc__') + self.func = func + + def __get__(self, obj, cls): + if obj is None: + return self + value = obj.__dict__[self.func.__name__] = self.func(obj) + return value + + +class class_property(object): + """ + Read-only class property + """ + + def __init__(self, func): + self.func = func + + def __get__(self, instance, cls): + return self.func(cls) + +class class_cached_property(object): + + def __init__(self, func): + self.func = func + + def __get__(self, obj, cls): + value = self.func(cls) + setattr(cls, self.func.__name__, value) + return value \ No newline at end of file diff --git a/pony/utils.py b/pony/utils/utils.py similarity index 74% rename from pony/utils.py rename to pony/utils/utils.py index 1868effa8..e8c6606b9 100644 --- a/pony/utils.py +++ b/pony/utils/utils.py @@ -1,344 +1,452 @@ -from __future__ import absolute_import, print_function -from pony.py23compat import PY2, imap, basestring, unicode - -import re, os.path, sys, inspect, types, warnings - -from datetime import datetime -from itertools import count as _count -from inspect import isfunction -from time import strptime -from collections import defaultdict -from copy import deepcopy, _deepcopy_dispatch -from functools import update_wrapper -from xml.etree import cElementTree - -# deepcopy instance method patch for Python < 2.7: -if types.MethodType not in _deepcopy_dispatch: - assert PY2 - def _deepcopy_method(x, memo): - return type(x)(x.im_func, deepcopy(x.im_self, memo), x.im_class) - _deepcopy_dispatch[types.MethodType] = _deepcopy_method - -import pony -from pony import options - -from pony.thirdparty.compiler import ast -from pony.thirdparty.decorator import decorator as _decorator - -if pony.MODE.startswith('GAE-'): localbase = object -else: from threading import local as localbase - -class PonyDeprecationWarning(DeprecationWarning): - pass - -def deprecated(stacklevel, message): - warnings.warn(message, PonyDeprecationWarning, stacklevel) - -warnings.simplefilter('once', PonyDeprecationWarning) - -def _improved_decorator(caller, func): - if isfunction(func): - return _decorator(caller, func) - def pony_wrapper(*args, **kwargs): - return caller(func, *args, **kwargs) - return pony_wrapper - -def 
decorator(caller, func=None): - if func is not None: - return _improved_decorator(caller, func) - def new_decorator(func): - return _improved_decorator(caller, func) - if isfunction(caller): - update_wrapper(new_decorator, caller) - return new_decorator - -def decorator_with_params(dec): - def parameterized_decorator(*args, **kwargs): - if len(args) == 1 and isfunction(args[0]) and not kwargs: - return decorator(dec(), args[0]) - return decorator(dec(*args, **kwargs)) - return parameterized_decorator - -@decorator -def cut_traceback(func, *args, **kwargs): - if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): - return func(*args, **kwargs) - - try: return func(*args, **kwargs) - except AssertionError: raise - except Exception: - exc_type, exc, tb = sys.exc_info() - last_pony_tb = None - try: - while tb.tb_next: - module_name = tb.tb_frame.f_globals['__name__'] - if module_name == 'pony' or (module_name is not None # may be None during import - and module_name.startswith('pony.')): - last_pony_tb = tb - tb = tb.tb_next - if last_pony_tb is None: raise - if tb.tb_frame.f_globals.get('__name__') == 'pony.utils' and tb.tb_frame.f_code.co_name == 'throw': - reraise(exc_type, exc, last_pony_tb) - raise exc # Set "pony.options.CUT_TRACEBACK = False" to see full traceback - finally: - del exc, tb, last_pony_tb - -if PY2: - exec('''def reraise(exc_type, exc, tb): - try: raise exc_type, exc, tb - finally: del tb''') -else: - def reraise(exc_type, exc, tb): - try: raise exc.with_traceback(tb) - finally: del exc, tb - -def throw(exc_type, *args, **kwargs): - if isinstance(exc_type, Exception): - assert not args and not kwargs - exc = exc_type - else: exc = exc_type(*args, **kwargs) - exc.__cause__ = None - try: - if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): - raise exc - else: - raise exc # Set "pony.options.CUT_TRACEBACK = False" to see full traceback - finally: del exc - -def truncate_repr(s, max_len=100): - s = repr(s) - return s if len(s) <= max_len else s[:max_len-3] + '...' 
- -lambda_args_cache = {} - -def get_lambda_args(func): - names = lambda_args_cache.get(func) - if names is not None: return names - if type(func) is types.FunctionType: - if hasattr(inspect, 'signature'): - names, argsname, kwname, defaults = [], None, None, None - for p in inspect.signature(func).parameters.values(): - if p.default is not p.empty: - defaults.append(p.default) - - if p.kind == p.POSITIONAL_OR_KEYWORD: - names.append(p.name) - elif p.kind == p.VAR_POSITIONAL: - argsname = p.name - elif p.kind == p.VAR_KEYWORD: - kwname = p.name - elif p.kind == p.POSITIONAL_ONLY: - throw(TypeError, 'Positional-only arguments like %s are not supported' % p.name) - elif p.kind == p.KEYWORD_ONLY: - throw(TypeError, 'Keyword-only arguments like %s are not supported' % p.name) - else: assert False - else: - names, argsname, kwname, defaults = inspect.getargspec(func) - elif isinstance(func, ast.Lambda): - names = func.argnames - if func.kwargs: names, kwname = names[:-1], names[-1] - else: kwname = None - if func.varargs: names, argsname = names[:-1], names[-1] - else: argsname = None - defaults = func.defaults - else: assert False # pragma: no cover - if argsname: throw(TypeError, '*%s is not supported' % argsname) - if kwname: throw(TypeError, '**%s is not supported' % kwname) - if defaults: throw(TypeError, 'Defaults are not supported') - lambda_args_cache[func] = names - return names - -def error_method(*args, **kwargs): - raise TypeError() - -_ident_re = re.compile(r'^[A-Za-z_]\w*\Z') - -# is_ident = ident_re.match -def is_ident(string): - 'is_ident(string) -> bool' - return bool(_ident_re.match(string)) - -_name_parts_re = re.compile(r''' - [A-Z][A-Z0-9]+(?![a-z]) # ACRONYM - | [A-Z][a-z]* # Capitalized or single capital - | [a-z]+ # all-lowercase - | [0-9]+ # numbers - | _+ # underscores - ''', re.VERBOSE) - -def split_name(name): - "split_name('Some_FUNNYName') -> ['Some', 'FUNNY', 'Name']" - if not _ident_re.match(name): - raise ValueError('Name is not correct Python identifier') - list = _name_parts_re.findall(name) - if not (list[0].strip('_') and list[-1].strip('_')): - raise ValueError('Name must not starting or ending with underscores') - return [ s for s in list if s.strip('_') ] - -def uppercase_name(name): - "uppercase_name('Some_FUNNYName') -> 'SOME_FUNNY_NAME'" - return '_'.join(s.upper() for s in split_name(name)) - -def lowercase_name(name): - "uppercase_name('Some_FUNNYName') -> 'some_funny_name'" - return '_'.join(s.lower() for s in split_name(name)) - -def camelcase_name(name): - "uppercase_name('Some_FUNNYName') -> 'SomeFunnyName'" - return ''.join(s.capitalize() for s in split_name(name)) - -def mixedcase_name(name): - "mixedcase_name('Some_FUNNYName') -> 'someFunnyName'" - list = split_name(name) - return list[0].lower() + ''.join(s.capitalize() for s in list[1:]) - -def import_module(name): - "import_module('a.b.c') -> " - mod = sys.modules.get(name) - if mod is not None: return mod - mod = __import__(name) - components = name.split('.') - for comp in components[1:]: mod = getattr(mod, comp) - return mod - -if sys.platform == 'win32': - _absolute_re = re.compile(r'^(?:[A-Za-z]:)?[\\/]') -else: _absolute_re = re.compile(r'^/') - -def is_absolute_path(filename): - return bool(_absolute_re.match(filename)) - -def absolutize_path(filename, frame_depth): - if is_absolute_path(filename): return filename - code_filename = sys._getframe(frame_depth+1).f_code.co_filename - if not is_absolute_path(code_filename): - if code_filename.startswith('<') and 
code_filename.endswith('>'): - if pony.MODE == 'INTERACTIVE': raise ValueError( - 'When in interactive mode, please provide absolute file path. Got: %r' % filename) - raise EnvironmentError('Unexpected module filename, which is not absolute file path: %r' % code_filename) - code_path = os.path.dirname(code_filename) - return os.path.join(code_path, filename) - -def current_timestamp(): - return datetime2timestamp(datetime.now()) - -def datetime2timestamp(d): - result = d.isoformat(' ') - if len(result) == 19: return result + '.000000' - return result - -def timestamp2datetime(t): - time_tuple = strptime(t[:19], '%Y-%m-%d %H:%M:%S') - microseconds = int((t[20:26] + '000000')[:6]) - return datetime(*(time_tuple[:6] + (microseconds,))) - -expr1_re = re.compile(r''' - ([A-Za-z_]\w*) # identifier (group 1) - | ([(]) # open parenthesis (group 2) - ''', re.VERBOSE) - -expr2_re = re.compile(r''' - \s*(?: - (;) # semicolon (group 1) - | (\.\s*[A-Za-z_]\w*) # dot + identifier (group 2) - | ([([]) # open parenthesis or braces (group 3) - ) - ''', re.VERBOSE) - -expr3_re = re.compile(r""" - [()[\]] # parenthesis or braces (group 1) - | '''(?:[^\\]|\\.)*?''' # '''triple-quoted string''' - | \"""(?:[^\\]|\\.)*?\""" # \"""triple-quoted string\""" - | '(?:[^'\\]|\\.)*?' # 'string' - | "(?:[^"\\]|\\.)*?" # "string" - """, re.VERBOSE) - -def parse_expr(s, pos=0): - z = 0 - match = expr1_re.match(s, pos) - if match is None: raise ValueError() - start = pos - i = match.lastindex - if i == 1: pos = match.end() # identifier - elif i == 2: z = 2 # "(" - else: assert False # pragma: no cover - while True: - match = expr2_re.match(s, pos) - if match is None: return s[start:pos], z==1 - pos = match.end() - i = match.lastindex - if i == 1: return s[start:pos], False # ";" - explicit end of expression - elif i == 2: z = 2 # .identifier - elif i == 3: # "(" or "[" - pos = match.end() - counter = 1 - open = match.group(i) - if open == '(': close = ')' - elif open == '[': close = ']'; z = 2 - else: assert False # pragma: no cover - while True: - match = expr3_re.search(s, pos) - if match is None: raise ValueError() - pos = match.end() - x = match.group() - if x == open: counter += 1 - elif x == close: - counter -= 1 - if not counter: z += 1; break - else: assert False # pragma: no cover - -def tostring(x): - if isinstance(x, basestring): return x - if hasattr(x, '__unicode__'): - try: return unicode(x) - except: pass - if hasattr(x, 'makeelement'): return cElementTree.tostring(x) - try: return str(x) - except: pass - try: return repr(x) - except: pass - if type(x) == types.InstanceType: return '<%s instance at 0x%X>' % (x.__class__.__name__) - return '<%s object at 0x%X>' % (x.__class__.__name__) - -def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None): - "Can join mix of unicode and byte strings in different encodings" - strings = list(strings) - try: return sep.join(strings) - except UnicodeDecodeError: pass - for i, s in enumerate(strings): - if isinstance(s, str): - strings[i] = s.decode(source_encoding, 'replace').replace(u'\ufffd', '?') - result = sep.join(strings) - if dest_encoding is None: return result - return result.encode(dest_encoding, replace) - -def count(*args, **kwargs): - if kwargs: return _count(*args, **kwargs) - if len(args) != 1: return _count(*args) - arg = args[0] - if hasattr(arg, 'count'): return arg.count() - try: it = iter(arg) - except TypeError: return _count(arg) - return len(set(it)) - -def avg(iter): - count = 0 - sum = 0.0 - for elem in iter: - if elem is None: 
continue - sum += elem - count += 1 - if not count: return None - return sum / count - -def distinct(iter): - d = defaultdict(int) - for item in iter: - d[item] = d[item] + 1 - return d - -def concat(*args): - return ''.join(tostring(arg) for arg in args) - -def is_utf8(encoding): - return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') +from __future__ import absolute_import, print_function +from pony.py23compat import PY2, imap, basestring, unicode, pickle, iteritems + +import io, re, os.path, sys, inspect, types, warnings + +from datetime import datetime +from itertools import count as _count +from inspect import isfunction +from time import strptime +from collections import defaultdict +from functools import update_wrapper, wraps +from xml.etree import cElementTree +from copy import deepcopy + +import pony +from pony import options + +from pony.thirdparty.compiler import ast +from pony.thirdparty.decorator import decorator as _decorator + +if pony.MODE.startswith('GAE-'): localbase = object +else: from threading import local as localbase + + +class PonyDeprecationWarning(DeprecationWarning): + pass + +def deprecated(stacklevel, message): + warnings.warn(message, PonyDeprecationWarning, stacklevel) + +warnings.simplefilter('once', PonyDeprecationWarning) + +def _improved_decorator(caller, func): + if isfunction(func): + return _decorator(caller, func) + def pony_wrapper(*args, **kwargs): + return caller(func, *args, **kwargs) + return pony_wrapper + +def decorator(caller, func=None): + if func is not None: + return _improved_decorator(caller, func) + def new_decorator(func): + return _improved_decorator(caller, func) + if isfunction(caller): + update_wrapper(new_decorator, caller) + return new_decorator + +def decorator_with_params(dec): + def parameterized_decorator(*args, **kwargs): + if len(args) == 1 and isfunction(args[0]) and not kwargs: + return decorator(dec(), args[0]) + return decorator(dec(*args, **kwargs)) + return parameterized_decorator + +@decorator +def cut_traceback(func, *args, **kwargs): + if not options.CUT_TRACEBACK: + return func(*args, **kwargs) + + try: return func(*args, **kwargs) + except AssertionError: raise + except Exception: + exc_type, exc, tb = sys.exc_info() + full_tb = tb + last_pony_tb = None + try: + while tb.tb_next: + module_name = tb.tb_frame.f_globals['__name__'] + if module_name == 'pony' or (module_name is not None # may be None during import + and module_name.startswith('pony.')): + last_pony_tb = tb + tb = tb.tb_next + if last_pony_tb is None: raise + module_name = tb.tb_frame.f_globals.get('__name__') or '' + if module_name.startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': + reraise(exc_type, exc, last_pony_tb) + reraise(exc_type, exc, full_tb) + finally: + del exc, full_tb, tb, last_pony_tb + +cut_traceback_depth = 2 + +if pony.MODE != 'INTERACTIVE': + cut_traceback_depth = 0 + def cut_traceback(func): + return func + +if PY2: + exec('''def reraise(exc_type, exc, tb): + try: raise exc_type, exc, tb + finally: del tb''') +else: + def reraise(exc_type, exc, tb): + try: raise exc.with_traceback(tb) + finally: del exc, tb + +def throw(exc_type, *args, **kwargs): + if isinstance(exc_type, Exception): + assert not args and not kwargs + exc = exc_type + else: exc = exc_type(*args, **kwargs) + exc.__cause__ = None + try: + if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): + raise exc + else: + raise exc # Set "pony.options.CUT_TRACEBACK = False" to see full traceback + finally: del exc + 
+def truncate_repr(s, max_len=100):
+    s = repr(s)
+    return s if len(s) <= max_len else s[:max_len-3] + '...'
+
+codeobjects = {}
+
+def get_codeobject_id(codeobject):
+    codeobject_id = id(codeobject)
+    if codeobject_id not in codeobjects:
+        codeobjects[codeobject_id] = codeobject
+    return codeobject_id
+
+lambda_args_cache = {}
+
+def get_lambda_args(func):
+    if type(func) is types.FunctionType:
+        codeobject = func.func_code if PY2 else func.__code__
+        cache_key = get_codeobject_id(codeobject)
+    elif isinstance(func, ast.Lambda):
+        cache_key = func
+    else: assert False  # pragma: no cover
+
+    names = lambda_args_cache.get(cache_key)
+    if names is not None: return names
+
+    if type(func) is types.FunctionType:
+        if hasattr(inspect, 'signature'):
+            names, argsname, kwname, defaults = [], None, None, []
+            for p in inspect.signature(func).parameters.values():
+                if p.default is not p.empty:
+                    defaults.append(p.default)
+
+                if p.kind == p.POSITIONAL_OR_KEYWORD:
+                    names.append(p.name)
+                elif p.kind == p.VAR_POSITIONAL:
+                    argsname = p.name
+                elif p.kind == p.VAR_KEYWORD:
+                    kwname = p.name
+                elif p.kind == p.POSITIONAL_ONLY:
+                    throw(TypeError, 'Positional-only arguments like %s are not supported' % p.name)
+                elif p.kind == p.KEYWORD_ONLY:
+                    throw(TypeError, 'Keyword-only arguments like %s are not supported' % p.name)
+                else: assert False
+        else:
+            names, argsname, kwname, defaults = inspect.getargspec(func)
+    elif isinstance(func, ast.Lambda):
+        names = func.argnames
+        if func.kwargs: names, kwname = names[:-1], names[-1]
+        else: kwname = None
+        if func.varargs: names, argsname = names[:-1], names[-1]
+        else: argsname = None
+        defaults = func.defaults
+    else: assert False  # pragma: no cover
+    if argsname: throw(TypeError, '*%s is not supported' % argsname)
+    if kwname: throw(TypeError, '**%s is not supported' % kwname)
+    if defaults: throw(TypeError, 'Defaults are not supported')
+
+    lambda_args_cache[cache_key] = names
+    return names
+
+def error_method(*args, **kwargs):
+    raise TypeError()
+
+_ident_re = re.compile(r'^[A-Za-z_]\w*\Z')
+
+# is_ident = ident_re.match
+def is_ident(string):
+    'is_ident(string) -> bool'
+    return bool(_ident_re.match(string))
+
+_name_parts_re = re.compile(r'''
+        [A-Z][A-Z0-9]+(?![a-z])  # ACRONYM
+    |   [A-Z][a-z]*              # Capitalized or single capital
+    |   [a-z]+                   # all-lowercase
+    |   [0-9]+                   # numbers
+    |   _+                       # underscores
+    ''', re.VERBOSE)
+
+def split_name(name):
+    "split_name('Some_FUNNYName') -> ['Some', 'FUNNY', 'Name']"
+    if not _ident_re.match(name):
+        raise ValueError('Name is not a correct Python identifier')
+    list = _name_parts_re.findall(name)
+    if not (list[0].strip('_') and list[-1].strip('_')):
+        raise ValueError('Name must not start or end with underscores')
+    return [ s for s in list if s.strip('_') ]
+
+def uppercase_name(name):
+    "uppercase_name('Some_FUNNYName') -> 'SOME_FUNNY_NAME'"
+    return '_'.join(s.upper() for s in split_name(name))
+
+def lowercase_name(name):
+    "lowercase_name('Some_FUNNYName') -> 'some_funny_name'"
+    return '_'.join(s.lower() for s in split_name(name))
+
+def camelcase_name(name):
+    "camelcase_name('Some_FUNNYName') -> 'SomeFunnyName'"
+    return ''.join(s.capitalize() for s in split_name(name))
+
+def mixedcase_name(name):
+    "mixedcase_name('Some_FUNNYName') -> 'someFunnyName'"
+    list = split_name(name)
+    return list[0].lower() + ''.join(s.capitalize() for s in list[1:])
+
+def import_module(name):
+    "import_module('a.b.c') -> <module a.b.c>"
+    mod = sys.modules.get(name)
+    if mod is not None: return mod
+    mod = __import__(name)
+    components = name.split('.')
+    for comp in components[1:]: mod = getattr(mod, comp)
+    return mod
+
+if sys.platform == 'win32':
+    _absolute_re = re.compile(r'^(?:[A-Za-z]:)?[\\/]')
+else: _absolute_re = re.compile(r'^/')
+
+def is_absolute_path(filename):
+    return bool(_absolute_re.match(filename))
+
+def absolutize_path(filename, frame_depth):
+    if is_absolute_path(filename): return filename
+    code_filename = sys._getframe(frame_depth+1).f_code.co_filename
+    if not is_absolute_path(code_filename):
+        if code_filename.startswith('<') and code_filename.endswith('>'):
+            if pony.MODE == 'INTERACTIVE': raise ValueError(
+                'When in interactive mode, please provide absolute file path. Got: %r' % filename)
+        raise EnvironmentError('Unexpected module filename, which is not absolute file path: %r' % code_filename)
+    code_path = os.path.dirname(code_filename)
+    return os.path.join(code_path, filename)
+
+def current_timestamp():
+    return datetime2timestamp(datetime.now())
+
+def datetime2timestamp(d):
+    result = d.isoformat(' ')
+    if len(result) == 19: return result + '.000000'
+    return result
+
+def timestamp2datetime(t):
+    time_tuple = strptime(t[:19], '%Y-%m-%d %H:%M:%S')
+    microseconds = int((t[20:26] + '000000')[:6])
+    return datetime(*(time_tuple[:6] + (microseconds,)))
+
+expr1_re = re.compile(r'''
+        ([A-Za-z_]\w*)  # identifier (group 1)
+    |   ([(])           # open parenthesis (group 2)
+    ''', re.VERBOSE)
+
+expr2_re = re.compile(r'''
+    \s*(?:
+            (;)                    # semicolon (group 1)
+        |   (\.\s*[A-Za-z_]\w*)    # dot + identifier (group 2)
+        |   ([([])                 # open parenthesis or bracket (group 3)
+    )
+    ''', re.VERBOSE)
+
+expr3_re = re.compile(r"""
+        [()[\]]                  # parenthesis or bracket
+    |   '''(?:[^\\]|\\.)*?'''    # '''triple-quoted string'''
+    |   \"""(?:[^\\]|\\.)*?\"""  # \"""triple-quoted string\"""
+    |   '(?:[^'\\]|\\.)*?'       # 'string'
+    |   "(?:[^"\\]|\\.)*?"       # "string"
+    """, re.VERBOSE)
+
+def parse_expr(s, pos=0):
+    z = 0
+    match = expr1_re.match(s, pos)
+    if match is None: raise ValueError()
+    start = pos
+    i = match.lastindex
+    if i == 1: pos = match.end()  # identifier
+    elif i == 2: z = 2  # "("
+    else: assert False  # pragma: no cover
+    while True:
+        match = expr2_re.match(s, pos)
+        if match is None: return s[start:pos], z==1
+        pos = match.end()
+        i = match.lastindex
+        if i == 1: return s[start:pos], False  # ";" - explicit end of expression
+        elif i == 2: z = 2  # .identifier
+        elif i == 3:  # "(" or "["
+            pos = match.end()
+            counter = 1
+            open = match.group(i)
+            if open == '(': close = ')'
+            elif open == '[': close = ']'; z = 2
+            else: assert False  # pragma: no cover
+            while True:
+                match = expr3_re.search(s, pos)
+                if match is None: raise ValueError()
+                pos = match.end()
+                x = match.group()
+                if x == open: counter += 1
+                elif x == close:
+                    counter -= 1
+                    if not counter: z += 1; break
+        else: assert False  # pragma: no cover
+
+def tostring(x):
+    if isinstance(x, basestring): return x
+    if hasattr(x, '__unicode__'):
+        try: return unicode(x)
+        except: pass
+    if hasattr(x, 'makeelement'): return cElementTree.tostring(x)
+    try: return str(x)
+    except: pass
+    try: return repr(x)
+    except: pass
+    if type(x) == types.InstanceType: return '<%s instance at 0x%X>' % (x.__class__.__name__, id(x))
+    return '<%s object at 0x%X>' % (x.__class__.__name__, id(x))
+
+def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None):
+    "Can join a mix of unicode and byte strings in different encodings"
+    strings = list(strings)
+    try: return sep.join(strings)
+    except UnicodeDecodeError: pass
+    for i, s in enumerate(strings):
+        if isinstance(s, str):
+            strings[i] = s.decode(source_encoding, 'replace').replace(u'\ufffd', '?')
+    result = sep.join(strings)
+    if dest_encoding is None: return result
+    return result.encode(dest_encoding, 'replace')
+
+def count(*args, **kwargs):
+    if kwargs: return _count(*args, **kwargs)
+    if len(args) != 1: return _count(*args)
+    arg = args[0]
+    if hasattr(arg, 'count'): return arg.count()
+    try: it = iter(arg)
+    except TypeError: return _count(arg)
+    return len(set(it))
+
+def avg(iter):
+    count = 0
+    sum = 0.0
+    for elem in iter:
+        if elem is None: continue
+        sum += elem
+        count += 1
+    if not count: return None
+    return sum / count
+
+def group_concat(items, sep=','):
+    if items is None:
+        return None
+    return str(sep).join(str(item) for item in items)
+
+def coalesce(*args):
+    for arg in args:
+        if arg is not None:
+            return arg
+    return None
+
+def distinct(iter):
+    d = defaultdict(int)
+    for item in iter:
+        d[item] = d[item] + 1
+    return d
+
+def concat(*args):
+    return ''.join(tostring(arg) for arg in args)
+
+def between(x, a, b):
+    return a <= x <= b
+
+def is_utf8(encoding):
+    return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8')
+
+def _persistent_id(obj):
+    if obj is Ellipsis:
+        return "Ellipsis"
+
+def _persistent_load(persid):
+    if persid == "Ellipsis":
+        return Ellipsis
+    raise pickle.UnpicklingError("unsupported persistent object")
+
+def pickle_ast(val):
+    pickled = io.BytesIO()
+    pickler = pickle.Pickler(pickled)
+    pickler.persistent_id = _persistent_id
+    pickler.dump(val)
+    return pickled
+
+def unpickle_ast(pickled):
+    pickled.seek(0)
+    unpickler = pickle.Unpickler(pickled)
+    unpickler.persistent_load = _persistent_load
+    return unpickler.load()
+
+def copy_ast(tree):
+    return unpickle_ast(pickle_ast(tree))
+
+def _hashable_wrap(func):
+    @wraps(func, assigned=('__name__', '__doc__'))
+    def new_func(self, *args, **kwargs):
+        if getattr(self, '_hash', None) is not None:
+            assert False, 'Cannot mutate HashableDict instance after the hash value is calculated'
+        return func(self, *args, **kwargs)
+    return new_func
+
+class HashableDict(dict):
+    def __hash__(self):
+        result = getattr(self, '_hash', None)
+        if result is None:
+            result = 0
+            for key, value in self.items():
+                result ^= hash(key)
+                result ^= hash(value)
+            self._hash = result
+        return result
+    def __deepcopy__(self, memo):
+        if getattr(self, '_hash', None) is not None:
+            return self
+        return HashableDict({deepcopy(key, memo): deepcopy(value, memo)
+                             for key, value in iteritems(self)})
+    __setitem__ = _hashable_wrap(dict.__setitem__)
+    __delitem__ = _hashable_wrap(dict.__delitem__)
+    clear = _hashable_wrap(dict.clear)
+    pop = _hashable_wrap(dict.pop)
+    popitem = _hashable_wrap(dict.popitem)
+    setdefault = _hashable_wrap(dict.setdefault)
+    update = _hashable_wrap(dict.update)
+
+def deref_proxy(value):
+    t = type(value)
+    if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__:
+        # Flask local proxy
+        value = value._get_current_object()
+    elif t.__name__ == 'EntityProxy':
+        # Pony proxy
+        value = value._get_object()
+
+    return value
+
+def deduplicate(value, deduplication_cache):
+    t = type(value)
+    try:
+        return deduplication_cache[t].setdefault(value, value)
+    except:
+        return value
diff --git a/setup.py b/setup.py
index e8f6bdbfe..56709cc5e 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,15 @@
 from __future__ import print_function
-from distutils.core import setup
+from setuptools import setup
 import sys
+import unittest
+
+def test_suite():
+    test_loader = unittest.TestLoader()
+    test_suite = test_loader.discover('pony.orm.tests', pattern='test_*.py')
+    return test_suite
+
 name = "pony"
 version = __import__('pony').__version__
 description = "Pony Object-Relational Mapper"
@@ -43,57 +50,67 @@
 Pony ORM Links:
 =================
 
-- Main site: http://ponyorm.com
-- Documentation: http://doc.ponyorm.com
+- Main site: https://ponyorm.com
+- Documentation: https://docs.ponyorm.com
 - GitHub: https://github.com/ponyorm/pony
 - Mailing list: http://ponyorm-list.ponyorm.com
 - ER Diagram Editor: https://editor.ponyorm.com
-- Blog: http://blog.ponyorm.com
+- Blog: https://blog.ponyorm.com
 """
 
 classifiers = [
     'Development Status :: 4 - Beta',
     'Intended Audience :: Developers',
-    'License :: Free for non-commercial use',
-    'License :: OSI Approved :: GNU Affero General Public License v3',
-    'License :: Other/Proprietary License',
-    'License :: Free For Educational Use',
-    'License :: Free for non-commercial use',
+    'License :: OSI Approved :: Apache Software License',
     'Operating System :: OS Independent',
     'Programming Language :: Python',
     'Programming Language :: Python :: 2',
-    'Programming Language :: Python :: 2.6',
     'Programming Language :: Python :: 2.7',
     'Programming Language :: Python :: 3',
     'Programming Language :: Python :: 3.3',
    'Programming Language :: Python :: 3.4',
     'Programming Language :: Python :: 3.5',
+    'Programming Language :: Python :: 3.6',
+    'Programming Language :: Python :: 3.7',
+    'Programming Language :: Python :: 3.8',
+    'Programming Language :: Python :: 3.9',
+    'Programming Language :: Python :: 3.10',
+    'Programming Language :: Python :: 3.11',
+    'Programming Language :: Python :: Implementation :: PyPy',
     'Topic :: Software Development :: Libraries',
     'Topic :: Database'
 ]
 
 author = "Alexander Kozlovsky, Alexey Malashkevich"
 author_email = "team@ponyorm.com"
-url = "http://ponyorm.com"
-lic = "AGPL, Commercial, Free for educational and non-commercial use"
+url = "https://ponyorm.com"
+licence = "Apache License Version 2.0"
 
 packages = [
     "pony",
+    "pony.flask",
+    "pony.flask.example",
     "pony.orm",
     "pony.orm.dbproviders",
     "pony.orm.examples",
     "pony.orm.integration",
     "pony.orm.tests",
     "pony.thirdparty",
-    "pony.thirdparty.compiler"
+    "pony.thirdparty.compiler",
+    "pony.utils"
 ]
 
+package_data = {
+    'pony.flask.example': ['templates/*.html'],
+    'pony.orm.tests': ['queries.txt']
+}
+
 download_url = "http://pypi.python.org/pypi/pony/"
 
 if __name__ == "__main__":
     pv = sys.version_info[:2]
-    if pv not in ((2, 6), (2, 7), (3, 3), (3, 4), (3, 5)):
-        s = "Sorry, but %s %s requires Python of one of the following versions: 2.6, 2.7, 3.3, 3.4 and 3.5." \
+    if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11)):
+        s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.11." \
             " You have version %s"
         print(s % (name, version, sys.version.split(' ', 1)[0]))
         sys.exit(1)
@@ -107,7 +124,9 @@
         author=author,
         author_email=author_email,
         url=url,
-        license=lic,
+        license=licence,
         packages=packages,
-        download_url=download_url
+        package_data=package_data,
+        download_url=download_url,
+        test_suite='setup.test_suite'
     )