19.0 vanilla

This commit is contained in:
Ernad Husremovic 2025-10-03 18:07:25 +02:00
parent 0a7ae8db93
commit 991d2234ca
416 changed files with 646602 additions and 300844 deletions

View file

@ -15,6 +15,7 @@ import threading
import time
import typing
import uuid
import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
from inspect import currentframe
@ -28,22 +29,31 @@ from psycopg2.sql import Composable
from werkzeug import urls
import odoo
from . import tools
from .tools import SQL
from .release import MIN_PG_VERSION
from .tools import config, SQL
from .tools.func import frame_codeinfo, locked
from .tools.misc import Callbacks
from .tools.misc import Callbacks, real_time
if typing.TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from odoo.orm.environments import Transaction
T = typing.TypeVar('T')
# when type checking, the BaseCursor exposes methods of the psycopg cursor
_CursorProtocol = psycopg2.extensions.cursor
else:
_CursorProtocol = object
def undecimalize(value, cr):
def undecimalize(value, cr) -> float | None:
if value is None:
return None
return float(value)
DECIMAL_TO_FLOAT_TYPE = psycopg2.extensions.new_type((1700,), 'float', undecimalize)
psycopg2.extensions.register_type(DECIMAL_TO_FLOAT_TYPE)
psycopg2.extensions.register_type(psycopg2.extensions.new_array_type((1231,), 'float[]', DECIMAL_TO_FLOAT_TYPE))
@ -51,13 +61,11 @@ psycopg2.extensions.register_type(psycopg2.extensions.new_array_type((1231,), 'f
_logger = logging.getLogger(__name__)
_logger_conn = _logger.getChild("connection")
real_time = time.time.__call__ # ensure we have a non patched time for query times when using freezegun
re_from = re.compile(r'\bfrom\s+"?([a-zA-Z_0-9]+)\b', re.IGNORECASE)
re_into = re.compile(r'\binto\s+"?([a-zA-Z_0-9]+)\b', re.IGNORECASE)
def categorize_query(decoded_query):
def categorize_query(decoded_query: str) -> tuple[typing.Literal['from', 'into'], str] | tuple[typing.Literal['other'], None]:
res_into = re_into.search(decoded_query)
# prioritize `insert` over `select` so `select` subqueries are not
# considered when inside a `insert`
@ -71,7 +79,7 @@ def categorize_query(decoded_query):
return 'other', None
sql_counter = 0
sql_counter: int = 0
MAX_IDLE_TIMEOUT = 60 * 10
@ -94,10 +102,11 @@ class Savepoint:
:param BaseCursor cr: the cursor to execute the `SAVEPOINT` queries on
"""
def __init__(self, cr):
def __init__(self, cr: _CursorProtocol):
    """Create and immediately emit a uniquely-named SAVEPOINT on *cr*."""
    self._cr = cr
    self.closed: bool = False
    # uuid1 gives a name unique enough to nest savepoints safely
    self.name = str(uuid.uuid1())
    cr.execute('SAVEPOINT "%s"' % self.name)
def __enter__(self):
@ -106,14 +115,14 @@ class Savepoint:
def __exit__(self, exc_type, exc_val, exc_tb):
# Roll the savepoint back when the with-block raised, release it otherwise.
self.close(rollback=exc_type is not None)
def close(self, *, rollback=True):
def close(self, *, rollback: bool = True):
    """Release the savepoint (rolling back first when *rollback* is true).

    Idempotent: closing an already-closed savepoint is a no-op.
    """
    if self.closed:
        return
    self._close(rollback)
def rollback(self):
    """Roll the transaction back to this savepoint without releasing it."""
    statement = 'ROLLBACK TO SAVEPOINT "%s"' % self.name
    self._cr.execute(statement)
def _close(self, rollback):
def _close(self, rollback: bool):
if rollback:
self.rollback()
self._cr.execute('RELEASE SAVEPOINT "%s"' % self.name)
@ -121,15 +130,17 @@ class Savepoint:
class _FlushingSavepoint(Savepoint):
def __init__(self, cr):
def __init__(self, cr: BaseCursor):
# Flush the cursor first (run transaction flush + precommit hooks) so the
# SAVEPOINT is taken after all pending operations reached the database.
cr.flush()
super().__init__(cr)
def rollback(self):
assert isinstance(self._cr, BaseCursor)
# Clear the cursor's in-memory transaction state before the SQL rollback
# so it does not get out of sync with the rolled-back database state.
self._cr.clear()
super().rollback()
def _close(self, rollback):
def _close(self, rollback: bool):
assert isinstance(self._cr, BaseCursor)
try:
if not rollback:
self._cr.flush()
@ -140,39 +151,63 @@ class _FlushingSavepoint(Savepoint):
super()._close(rollback)
class BaseCursor:
# _CursorProtocol declares the available methods and type information,
# at runtime, it is just an `object`
class BaseCursor(_CursorProtocol):
""" Base class for cursors that manage pre/post commit hooks. """
IN_MAX = 1000 # decent limit on size of IN queries - guideline = Oracle limit
def __init__(self):
transaction: Transaction | None
cache: dict[typing.Any, typing.Any]
dbname: str
def __init__(self) -> None:
    """Set up the hook queues and per-transaction caches."""
    # One callback queue per commit/rollback phase.
    self.precommit = Callbacks()
    self.postcommit = Callbacks()
    self.prerollback = Callbacks()
    self.postrollback = Callbacks()
    self.cache = {}
    self._now: datetime | None = None  # lazily filled by now()
    # By default a cursor has no transaction object. A transaction object
    # for managing environments is instantiated by registry.cursor(). It
    # is not done here in order to avoid cyclic module dependencies.
    self.transaction = None
def flush(self):
def flush(self) -> None:
    """ Flush the current transaction, and run precommit hooks. """
    tx = self.transaction
    if tx is not None:
        tx.flush()
    self.precommit.run()
def clear(self):
def clear(self) -> None:
    """ Clear the current transaction, and clear precommit hooks. """
    tx = self.transaction
    if tx is not None:
        tx.clear()
    self.precommit.clear()
def reset(self):
def reset(self) -> None:
    """ Reset the current transaction (this invalidates more than clear()).

    This method should be called only right after commit() or rollback().
    """
    tx = self.transaction
    if tx is not None:
        tx.reset()
def savepoint(self, flush=True) -> Savepoint:
def execute(self, query, params=None, log_exceptions: bool = True) -> None:
""" Execute a query inside the current transaction.

Abstract here; concrete cursors (e.g. ``Cursor``) provide the
implementation.
"""
raise NotImplementedError
def commit(self) -> None:
""" Commit the current transaction.

Abstract here; subclasses implement the actual commit behaviour.
"""
raise NotImplementedError
def rollback(self) -> None:
""" Rollback the current transaction.

Abstract here; subclasses implement the actual rollback behaviour.
"""
raise NotImplementedError
def savepoint(self, flush: bool = True) -> Savepoint:
"""context manager entering in a new savepoint
With ``flush`` (the default), will automatically run (or clear) the
@ -202,6 +237,39 @@ class BaseCursor:
finally:
self.close()
def dictfetchone(self) -> dict[str, typing.Any] | None:
""" Return the first row as a dict (column_name -> value) or None if no rows are available. """
# Abstract here; dictfetchmany()/dictfetchall() below are built on top of it.
raise NotImplementedError
def dictfetchmany(self, size: int) -> list[dict[str, typing.Any]]:
    """Return at most *size* rows, each as a column_name -> value dict."""
    rows: list[dict[str, typing.Any]] = []
    for _ in range(size):
        row = self.dictfetchone()
        if row is None:
            break
        rows.append(row)
    return rows
def dictfetchall(self) -> list[dict[str, typing.Any]]:
    """ Return all rows as dicts (column_name -> value). """
    # iter() with a sentinel: keep calling dictfetchone() until it yields None
    return list(iter(self.dictfetchone, None))
def split_for_in_conditions(self, ids: Iterable[T], size: int = 0) -> Iterator[tuple[T, ...]]:
"""Split a list of identifiers into one or more smaller tuples
safe for IN conditions, after uniquifying them."""
# Deprecated shim: warn, then delegate to tools.misc.split_every.
# NOTE(review): the docstring claims uniquification, but this delegates
# straight to split_every — confirm whether that helper deduplicates.
warnings.warn("Deprecated since 19.0, use split_every(cr.IN_MAX, ids)", DeprecationWarning)
return tools.misc.split_every(size or self.IN_MAX, ids)
def now(self) -> datetime:
""" Return the transaction's timestamp ``NOW() AT TIME ZONE 'UTC'``. """
if self._now is None:
# Fetch once and cache; Cursor.commit()/rollback() reset _now so a
# new transaction gets a fresh timestamp.
self.execute("SELECT (now() AT TIME ZONE 'UTC')")
row = self.fetchone()
assert row
self._now = row[0]
return self._now
class Cursor(BaseCursor):
"""Represents an open transaction to the PostgreSQL DB backend,
@ -267,9 +335,11 @@ class Cursor(BaseCursor):
*any* data which may be modified during the life of the cursor.
"""
IN_MAX = 1000 # decent limit on size of IN queries - guideline = Oracle limit
sql_from_log: dict[str, tuple[int, float]]
sql_into_log: dict[str, tuple[int, float]]
sql_log_count: int
def __init__(self, pool, dbname, dsn):
def __init__(self, pool: ConnectionPool, dbname: str, dsn: dict):
super().__init__()
self.sql_from_log = {}
self.sql_into_log = {}
@ -280,13 +350,13 @@ class Cursor(BaseCursor):
# avoid the call of close() (by __del__) if an exception
# is raised by any of the following initializations
self._closed = True
self._closed: bool = True
self.__pool = pool
self.__pool: ConnectionPool = pool
self.dbname = dbname
self._cnx = pool.borrow(dsn)
self._obj = self._cnx.cursor()
self._cnx: PsycoConnection = pool.borrow(dsn)
self._obj: psycopg2.extensions.cursor = self._cnx.cursor()
if _logger.isEnabledFor(logging.DEBUG):
self.__caller = frame_codeinfo(currentframe(), 2)
else:
@ -296,23 +366,23 @@ class Cursor(BaseCursor):
self.connection.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
self.connection.set_session(readonly=pool.readonly)
self.cache = {}
self._now = None
if os.getenv('ODOO_FAKETIME_TEST_MODE') and self.dbname in tools.config['db_name'].split(','):
if os.getenv('ODOO_FAKETIME_TEST_MODE') and self.dbname in tools.config['db_name']:
self.execute("SET search_path = public, pg_catalog;")
self.commit() # ensure that the search_path remains after a rollback
def __build_dict(self, row):
return {d.name: row[i] for i, d in enumerate(self._obj.description)}
def __build_dict(self, row: tuple) -> dict[str, typing.Any]:
    """Map a raw result *row* onto the column names of the last query."""
    description = self._obj.description
    assert description, "Query does not have results"
    result: dict[str, typing.Any] = {}
    for index, column in enumerate(description):
        result[column.name] = row[index]
    return result
def dictfetchone(self):
def dictfetchone(self) -> dict[str, typing.Any] | None:
    """Fetch the next row as a column_name -> value dict, or None."""
    raw = self._obj.fetchone()
    if not raw:
        return None
    return self.__build_dict(raw)
def dictfetchmany(self, size):
def dictfetchmany(self, size) -> list[dict[str, typing.Any]]:
    """Fetch up to *size* rows, each converted to a dict."""
    rows = self._obj.fetchmany(size)
    return [self.__build_dict(row) for row in rows]
def dictfetchall(self):
def dictfetchall(self) -> list[dict[str, typing.Any]]:
    """Fetch every remaining row, each converted to a dict."""
    make = self.__build_dict
    return [make(row) for row in self._obj.fetchall()]
def __del__(self):
@ -330,17 +400,17 @@ class Cursor(BaseCursor):
_logger.warning(msg)
self._close(True)
def _format(self, query, params=None):
def _format(self, query, params=None) -> str:
"""Return the query with parameters bound, decoded to str."""
# Decode with the connection's client encoding; 'replace' keeps the
# result printable even if some bytes cannot be decoded.
encoding = psycopg2.extensions.encodings[self.connection.encoding]
return self.mogrify(query, params).decode(encoding, 'replace')
def mogrify(self, query, params=None):
def mogrify(self, query, params=None) -> bytes:
"""Return the query bytes after parameter binding (psycopg2 mogrify).

``SQL`` objects carry their own parameters, so passing separate
*params* alongside one is a programming error.
"""
if isinstance(query, SQL):
assert params is None, "Unexpected parameters for SQL query object"
query, params = query.code, query.params
return self._obj.mogrify(query, params)
def execute(self, query, params=None, log_exceptions=True):
def execute(self, query, params=None, log_exceptions: bool = True) -> None:
global sql_counter
if isinstance(query, SQL):
@ -353,8 +423,7 @@ class Cursor(BaseCursor):
start = real_time()
try:
params = params or None
res = self._obj.execute(query, params)
self._obj.execute(query, params)
except Exception as e:
if log_exceptions:
_logger.error("bad query: %s\nERROR: %s", self._obj.query or query, e)
@ -371,6 +440,7 @@ class Cursor(BaseCursor):
current_thread = threading.current_thread()
if hasattr(current_thread, 'query_count'):
current_thread.query_count += 1
if hasattr(current_thread, 'query_time'):
current_thread.query_time += delay
# optional hooks for performance and tracing analysis
@ -379,17 +449,18 @@ class Cursor(BaseCursor):
# advanced stats
if _logger.isEnabledFor(logging.DEBUG):
query_type, table = categorize_query(self._obj.query.decode())
if obj_query := self._obj.query:
query = obj_query.decode()
query_type, table = categorize_query(query)
log_target = None
if query_type == 'into':
log_target = self.sql_into_log
elif query_type == 'from':
log_target = self.sql_from_log
if log_target:
stats = log_target.setdefault(table, [0, 0])
stats[0] += 1
stats[1] += delay * 1E6
return res
stat_count, stat_time = log_target.get(table or '', (0, 0))
log_target[table or ''] = (stat_count + 1, stat_time + delay * 1E6)
return None
def execute_values(self, query, argslist, template=None, page_size=100, fetch=False):
"""
@ -402,30 +473,26 @@ class Cursor(BaseCursor):
query = query.as_string(self._obj)
return psycopg2.extras.execute_values(self, query, argslist, template=template, page_size=page_size, fetch=fetch)
def split_for_in_conditions(self, ids: Iterable[T], size: int = 0) -> Iterator[tuple[T, ...]]:
"""Split a list of identifiers into one or more smaller tuples
safe for IN conditions, after uniquifying them."""
return tools.misc.split_every(size or self.IN_MAX, ids)
def print_log(self):
def print_log(self) -> None:
global sql_counter
if not _logger.isEnabledFor(logging.DEBUG):
return
def process(type):
def process(log_type: str):
sqllogs = {'from': self.sql_from_log, 'into': self.sql_into_log}
sum = 0
if sqllogs[type]:
sqllogitems = sqllogs[type].items()
_logger.debug("SQL LOG %s:", type)
for r in sorted(sqllogitems, key=lambda k: k[1]):
delay = timedelta(microseconds=r[1][1])
_logger.debug("table: %s: %s/%s", r[0], delay, r[1][0])
sum += r[1][1]
sqllogs[type].clear()
sum = timedelta(microseconds=sum)
_logger.debug("SUM %s:%s/%d [%d]", type, sum, self.sql_log_count, sql_counter)
sqllogs[type].clear()
sqllog = sqllogs[log_type]
total = 0.0
if sqllog:
_logger.debug("SQL LOG %s:", log_type)
for table, (stat_count, stat_time) in sorted(sqllog.items(), key=lambda k: k[1]):
delay = timedelta(microseconds=stat_time)
_logger.debug("table: %s: %s/%s", table, delay, stat_count)
total += stat_time
sqllog.clear()
total_delay = timedelta(microseconds=total)
_logger.debug("SUM %s:%s/%d [%d]", log_type, total_delay, self.sql_log_count, sql_counter)
process('from')
process('into')
self.sql_log_count = 0
@ -443,15 +510,15 @@ class Cursor(BaseCursor):
finally:
_logger.setLevel(level)
def close(self):
def close(self) -> None:
# Idempotent: a second close() is a no-op. leak=False marks a normal
# (non-leaked) release of the underlying pooled connection.
if not self.closed:
return self._close(False)
def _close(self, leak=False):
def _close(self, leak: bool = False) -> None:
if not self._obj:
return
del self.cache
self.cache.clear()
# advanced stats only at logging.DEBUG level
self.print_log()
@ -471,32 +538,30 @@ class Cursor(BaseCursor):
self._closed = True
if leak:
self._cnx.leaked = True
self._cnx.leaked = True # type: ignore
else:
chosen_template = tools.config['db_template']
keep_in_pool = self.dbname not in ('template0', 'template1', 'postgres', chosen_template)
self.__pool.give_back(self._cnx, keep_in_pool=keep_in_pool)
def commit(self):
def commit(self) -> None:
""" Perform an SQL `COMMIT` """
self.flush()
result = self._cnx.commit()
self._cnx.commit()
self.clear()
self._now = None
self.prerollback.clear()
self.postrollback.clear()
self.postcommit.run()
return result
def rollback(self):
def rollback(self) -> None:
""" Perform an SQL `ROLLBACK` """
self.clear()
self.postcommit.clear()
self.prerollback.run()
result = self._cnx.rollback()
self._cnx.rollback()
self._now = None
self.postrollback.run()
return result
def __getattr__(self, name):
if self._closed and name == '_obj':
@ -504,131 +569,18 @@ class Cursor(BaseCursor):
return getattr(self._obj, name)
@property
def closed(self):
return self._closed or self._cnx.closed
def closed(self) -> bool:
return self._closed or bool(self._cnx.closed)
@property
def readonly(self):
def readonly(self) -> bool:
# Reflects the read-only flag of the underlying psycopg2 connection.
return bool(self._cnx.readonly)
def now(self):
""" Return the transaction's timestamp ``NOW() AT TIME ZONE 'UTC'``. """
if self._now is None:
self.execute("SELECT (now() AT TIME ZONE 'UTC')")
self._now = self.fetchone()[0]
return self._now
class TestCursor(BaseCursor):
""" A pseudo-cursor to be used for tests, on top of a real cursor. It keeps
the transaction open across requests, and simulates committing, rolling
back, and closing:
+------------------------+---------------------------------------------------+
| test cursor | queries on actual cursor |
+========================+===================================================+
|``cr = TestCursor(...)``| |
+------------------------+---------------------------------------------------+
| ``cr.execute(query)`` | SAVEPOINT test_cursor_N (if not savepoint) |
| | query |
+------------------------+---------------------------------------------------+
| ``cr.commit()`` | RELEASE SAVEPOINT test_cursor_N (if savepoint) |
+------------------------+---------------------------------------------------+
| ``cr.rollback()`` | ROLLBACK TO SAVEPOINT test_cursor_N (if savepoint)|
+------------------------+---------------------------------------------------+
| ``cr.close()`` | ROLLBACK TO SAVEPOINT test_cursor_N (if savepoint)|
| | RELEASE SAVEPOINT test_cursor_N (if savepoint) |
+------------------------+---------------------------------------------------+
"""
# Class-wide stack of open test cursors; used to detect when a read/write
# cursor is opened below a readonly one, and to verify LIFO closing order.
_cursors_stack = []
def __init__(self, cursor, lock, readonly, current_test=None):
assert isinstance(cursor, BaseCursor)
self.current_test = current_test
self._check('__init__')
super().__init__()
self._now = None
self._closed = False
self._cursor = cursor
self.readonly = readonly
# we use a lock to serialize concurrent requests
self._lock = lock
self._lock.acquire()
last_cursor = self._cursors_stack and self._cursors_stack[-1]
if last_cursor and last_cursor.readonly and not readonly and last_cursor._savepoint:
raise Exception('Opening a read/write test cursor from a readonly one')
self._cursors_stack.append(self)
# in order to simulate commit and rollback, the cursor maintains a
# savepoint at its last commit, the savepoint is created lazily
self._savepoint = None
def _check_savepoint(self):
# Lazily create the savepoint that backs this simulated transaction.
if not self._savepoint:
# we use self._cursor._obj for the savepoint to avoid having the
# savepoint queries in the query counts, profiler, ...
# Those queries are tests artefacts and should be invisible.
self._savepoint = Savepoint(self._cursor._obj)
if self.readonly:
# this will simulate a readonly connection
self._cursor._obj.execute('SET TRANSACTION READ ONLY') # use _obj to avoid impacting query count and profiler.
def _check(self, operation):
# Give the running test a chance to veto the operation.
if self.current_test:
self.current_test.check_test_cursor(operation)
def execute(self, *args, **kwargs):
# Ensure the backing savepoint exists before delegating to the real cursor.
assert not self._closed, "Cannot use a closed cursor"
self._check_savepoint()
return self._cursor.execute(*args, **kwargs)
def close(self):
# Simulated close: roll back and release the savepoint, then pop this
# cursor off the stack and release the serialization lock.
if not self._closed:
try:
self.rollback()
if self._savepoint:
self._savepoint.close(rollback=False)
finally:
self._closed = True
tos = self._cursors_stack.pop()
if tos is not self:
_logger.warning("Found different un-closed cursor when trying to close %s: %s", self, tos)
self._lock.release()
def commit(self):
""" Perform an SQL `COMMIT` """
self._check('commit')
self.flush()
if self._savepoint:
self._savepoint.close(rollback=self.readonly)
self._savepoint = None
self.clear()
self.prerollback.clear()
self.postrollback.clear()
self.postcommit.clear() # TestCursor ignores post-commit hooks by default
def rollback(self):
""" Perform an SQL `ROLLBACK` """
self._check('rollback')
self.clear()
self.postcommit.clear()
self.prerollback.run()
if self._savepoint:
self._savepoint.close(rollback=True)
self._savepoint = None
self.postrollback.run()
def __getattr__(self, name):
# Anything not overridden here is delegated to the wrapped real cursor.
self._check(name)
return getattr(self._cursor, name)
def now(self):
""" Return the transaction's timestamp ``datetime.now()``. """
if self._now is None:
self._now = datetime.now()
return self._now
class PsycoConnection(psycopg2.extensions.connection):
_pool_in_use: bool = False
_pool_last_used: float = 0
def lobject(*args, **kwargs):
pass
@ -642,7 +594,7 @@ class PsycoConnection(psycopg2.extensions.connection):
return PsycoConnectionInfo(self)
class ConnectionPool(object):
class ConnectionPool:
""" The pool of connections to database(s)
Keep a set of connections to pg databases open, and reuse them
@ -651,27 +603,29 @@ class ConnectionPool(object):
The connections are *not* automatically closed. Only a close_db()
can trigger that.
"""
def __init__(self, maxconn=64, readonly=False):
_connections: list[PsycoConnection]
def __init__(self, maxconn: int = 64, readonly: bool = False):
    """Initialise an empty pool holding at most *maxconn* connections."""
    self._connections = []
    self._maxconn = max(1, maxconn)  # always allow at least one connection
    self._readonly = readonly
    self._lock = threading.Lock()  # guards the pool's shared state
def __repr__(self):
used = len([1 for c, u, _ in self._connections[:] if u])
used = sum(1 for c in self._connections if c._pool_in_use)
count = len(self._connections)
mode = 'read-only' if self._readonly else 'read/write'
return f"ConnectionPool({mode};used={used}/count={count}/max={self._maxconn})"
@property
def readonly(self):
def readonly(self) -> bool:
# Whether this pool hands out read-only connections.
return self._readonly
def _debug(self, msg, *args):
def _debug(self, msg: str, *args):
# Prefix every pool debug message with this pool's repr for traceability.
_logger_conn.debug(('%r ' + msg), self, *args)
@locked
def borrow(self, connection_info):
def borrow(self, connection_info: dict) -> PsycoConnection:
"""
Borrow a PsycoConnection from the pool. If no connection is available, create a new one
as long as there are still slots available. Perform some garbage-collection in the pool:
@ -681,8 +635,8 @@ class ConnectionPool(object):
:rtype: PsycoConnection
"""
# free idle, dead and leaked connections
for i, (cnx, used, last_used) in tools.reverse_enumerate(self._connections):
if not used and not cnx.closed and time.time() - last_used > MAX_IDLE_TIMEOUT:
for i, cnx in tools.reverse_enumerate(self._connections):
if not cnx._pool_in_use and not cnx.closed and time.time() - cnx._pool_last_used > MAX_IDLE_TIMEOUT:
self._debug('Close connection at index %d: %r', i, cnx.dsn)
cnx.close()
if cnx.closed:
@ -691,11 +645,11 @@ class ConnectionPool(object):
continue
if getattr(cnx, 'leaked', False):
delattr(cnx, 'leaked')
self._connections[i][1] = False
cnx._pool_in_use = False
_logger.info('%r: Free leaked connection to %r', self, cnx.dsn)
for i, (cnx, used, _) in enumerate(self._connections):
if not used and self._dsn_equals(cnx.dsn, connection_info):
for i, cnx in enumerate(self._connections):
if not cnx._pool_in_use and self._dsn_equals(cnx.dsn, connection_info):
try:
cnx.reset()
except psycopg2.OperationalError:
@ -704,15 +658,15 @@ class ConnectionPool(object):
if not cnx.closed:
cnx.close()
continue
self._connections[i][1] = True
cnx._pool_in_use = True
self._debug('Borrow existing connection to %r at index %d', cnx.dsn, i)
return cnx
if len(self._connections) >= self._maxconn:
# try to remove the oldest connection not used
for i, (cnx, used, _) in enumerate(self._connections):
if not used:
for i, cnx in enumerate(self._connections):
if not cnx._pool_in_use:
self._connections.pop(i)
if not cnx.closed:
cnx.close()
@ -729,43 +683,46 @@ class ConnectionPool(object):
except psycopg2.Error:
_logger.info('Connection to the database failed')
raise
self._connections.append([result, True, 0])
if result.server_version < MIN_PG_VERSION * 10000:
warnings.warn(f"Postgres version is {result.server_version}, lower than minimum required {MIN_PG_VERSION * 10000}")
result._pool_in_use = True
self._connections.append(result)
self._debug('Create new connection backend PID %d', result.get_backend_pid())
return result
@locked
def give_back(self, connection, keep_in_pool=True):
def give_back(self, connection: PsycoConnection, keep_in_pool: bool = True):
self._debug('Give back connection to %r', connection.dsn)
for i, (cnx, _, _) in enumerate(self._connections):
if cnx is connection:
if keep_in_pool:
# Release the connection and record the last time used
self._connections[i][1] = False
self._connections[i][2] = time.time()
self._debug('Put connection to %r in pool', cnx.dsn)
else:
self._connections.pop(i)
self._debug('Forgot connection to %r', cnx.dsn)
cnx.close()
break
else:
try:
index = self._connections.index(connection)
except ValueError:
raise PoolError('This connection does not belong to the pool')
if keep_in_pool:
# Release the connection and record the last time used
connection._pool_in_use = False
connection._pool_last_used = time.time()
self._debug('Put connection to %r in pool', connection.dsn)
else:
cnx = self._connections.pop(index)
self._debug('Forgot connection to %r', cnx.dsn)
cnx.close()
@locked
def close_all(self, dsn=None):
def close_all(self, dsn: dict | str | None = None):
count = 0
last = None
for i, (cnx, _, _) in tools.reverse_enumerate(self._connections):
for i, cnx in tools.reverse_enumerate(self._connections):
if dsn is None or self._dsn_equals(cnx.dsn, dsn):
cnx.close()
last = self._connections.pop(i)[0]
last = self._connections.pop(i)
count += 1
if count:
_logger.info('%r: Closed %d connections %s', self, count,
(dsn and last and 'to %r' % last.dsn) or '')
def _dsn_equals(self, dsn1, dsn2):
def _dsn_equals(self, dsn1: dict | str, dsn2: dict | str) -> bool:
alias_keys = {'dbname': 'database'}
ignore_keys = ['password']
dsn1, dsn2 = ({
@ -776,32 +733,33 @@ class ConnectionPool(object):
return dsn1 == dsn2
class Connection(object):
class Connection:
""" A lightweight instance of a connection to postgres
"""
def __init__(self, pool, dbname, dsn):
def __init__(self, pool: ConnectionPool, dbname: str, dsn: dict):
# A Connection is just a lightweight (pool, dbname, dsn) triple used
# as a cursor factory; all attributes are private (name-mangled).
self.__dbname = dbname
self.__dsn = dsn
self.__pool = pool
@property
def dsn(self):
def dsn(self) -> dict:
    """Connection parameters, with the password scrubbed out."""
    sanitized = dict(self.__dsn)
    sanitized.pop('password', None)
    return sanitized
@property
def dbname(self):
def dbname(self) -> str:
# Name of the database this lightweight connection targets.
return self.__dbname
def cursor(self):
def cursor(self) -> Cursor:
"""Open a new Cursor on this connection's pool, database and dsn."""
_logger.debug('create cursor to %r', self.dsn)
return Cursor(self.__pool, self.__dbname, self.__dsn)
def __bool__(self):
# NOTE(review): truth-testing a Connection always raises; presumably to
# catch accidental `if connection:` checks — confirm the intent.
raise NotImplementedError()
def connection_info_for(db_or_uri, readonly=False):
def connection_info_for(db_or_uri: str, readonly=False) -> tuple[str, dict]:
""" parse the given `db_or_uri` and return a 2-tuple (dbname, connection_params)
Connection params are either a dictionary with a single key ``dsn``
@ -814,14 +772,15 @@ def connection_info_for(db_or_uri, readonly=False):
the default configuration from ``db_`` or ``db_replica_``.
:rtype: (str, dict)
"""
app_name = config['db_app_name']
if 'ODOO_PGAPPNAME' in os.environ:
# Using manual string interpolation for security reason and trimming at default NAMEDATALEN=63
app_name = os.environ['ODOO_PGAPPNAME'].replace('{pid}', str(os.getpid()))[0:63]
else:
app_name = "odoo-%d" % os.getpid()
warnings.warn("Since 19.0, use PGAPPNAME instead of ODOO_PGAPPNAME", DeprecationWarning)
app_name = os.environ['ODOO_PGAPPNAME']
# Using manual string interpolation for security reason and trimming at default NAMEDATALEN=63
app_name = app_name.replace('{pid}', str(os.getpid()))[:63]
if db_or_uri.startswith(('postgresql://', 'postgres://')):
# extract db from uri
us = urls.url_parse(db_or_uri)
us = urls.url_parse(db_or_uri) # type: ignore
if len(us.path) > 1:
db_name = us.path[1:]
elif us.username:
@ -840,31 +799,40 @@ def connection_info_for(db_or_uri, readonly=False):
return db_or_uri, connection_info
_Pool = None
_Pool_readonly = None
def db_connect(to, allow_uri=False, readonly=False):
_Pool: ConnectionPool | None = None
_Pool_readonly: ConnectionPool | None = None
def db_connect(to: str, allow_uri=False, readonly=False) -> Connection:
global _Pool, _Pool_readonly # noqa: PLW0603 (global-statement)
maxconn = odoo.evented and tools.config['db_maxconn_gevent'] or tools.config['db_maxconn']
if _Pool is None and not readonly:
_Pool = ConnectionPool(int(maxconn), readonly=False)
if _Pool_readonly is None and readonly:
_Pool_readonly = ConnectionPool(int(maxconn), readonly=True)
maxconn = (tools.config['db_maxconn_gevent'] if hasattr(odoo, 'evented') and odoo.evented else 0) or tools.config['db_maxconn']
_Pool_readonly if readonly else _Pool
if readonly:
if _Pool_readonly is None:
_Pool_readonly = ConnectionPool(int(maxconn), readonly=True)
pool = _Pool_readonly
else:
if _Pool is None:
_Pool = ConnectionPool(int(maxconn), readonly=False)
pool = _Pool
db, info = connection_info_for(to, readonly)
if not allow_uri and db != to:
raise ValueError('URI connections not allowed')
return Connection(_Pool_readonly if readonly else _Pool, db, info)
return Connection(pool, db, info)
def close_db(db_name):
def close_db(db_name: str) -> None:
    """ You might want to call odoo.modules.registry.Registry.delete(db_name) along this function."""
    # Close matching connections in both the read/write and readonly pools;
    # connection info is resolved per pool only when that pool exists.
    for pool in (_Pool, _Pool_readonly):
        if pool:
            pool.close_all(connection_info_for(db_name)[1])
def close_all():
def close_all() -> None:
if _Pool:
_Pool.close_all()
if _Pool_readonly: