17.0 vanilla

This commit is contained in:
Ernad Husremovic 2025-10-03 18:05:14 +02:00
parent 2e65bf056a
commit df627a6bba
328 changed files with 578149 additions and 759311 deletions

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.
SUPPORTED_DEBUGGER = {'pdb', 'ipdb', 'wdb', 'pudb'}
from . import _monkeypatches
from . import _monkeypatches_pytz
@ -12,6 +11,7 @@ if not hasattr(urls, 'url_join'):
from . import appdirs
from . import cloc
from . import constants
from . import pdf
from . import pycompat
from . import win32

View file

@ -1039,6 +1039,7 @@ urls.url_encode = url_encode
urls.url_join = url_join
urls.url_parse = url_parse
urls.url_quote = url_quote
urls.url_unquote = url_unquote
urls.url_quote_plus = url_quote_plus
urls.url_unquote_plus = url_unquote_plus
urls.url_unparse = url_unparse

View file

@ -4,8 +4,9 @@
# decorator makes wrappers that have the same API as their wrapped function
from collections import Counter, defaultdict
from decorator import decorator
from inspect import signature
from inspect import signature, Parameter
import logging
import warnings
unsafe_eval = eval
@ -52,19 +53,32 @@ class ormcache(object):
def __init__(self, *args, **kwargs):
self.args = args
self.skiparg = kwargs.get('skiparg')
self.cache_name = kwargs.get('cache', 'default')
def __call__(self, method):
self.method = method
self.determine_key()
lookup = decorator(self.lookup, method)
lookup.clear_cache = self.clear
lookup.__cache__ = self
return lookup
def add_value(self, *args, cache_value=None, **kwargs):
model = args[0]
d, key0, _ = self.lru(model)
key = key0 + self.key(*args, **kwargs)
d[key] = cache_value
def determine_key(self):
""" Determine the function that computes a cache key from arguments. """
if self.skiparg is None:
# build a string that represents function code and evaluate it
args = str(signature(self.method))[1:-1]
args = ', '.join(
# remove annotations because lambdas can't be type-annotated,
# and defaults because they are redundant (defaults are present
# in the wrapper function itself)
str(params.replace(annotation=Parameter.empty, default=Parameter.empty))
for params in signature(self.method).parameters.values()
)
if self.args:
code = "lambda %s: (%s,)" % (args, ", ".join(self.args))
else:
@ -76,7 +90,7 @@ class ormcache(object):
def lru(self, model):
counter = STAT[(model.pool.db_name, model._name, self.method)]
return model.pool._Registry__cache, (model._name, self.method), counter
return model.pool._Registry__caches[self.cache_name], (model._name, self.method), counter
def lookup(self, method, *args, **kwargs):
d, key0, counter = self.lru(args[0])
@ -96,7 +110,8 @@ class ormcache(object):
def clear(self, model, *args):
""" Clear the registry cache """
model.pool._clear_cache()
warnings.warn('Deprecated method ormcache.clear(model, *args), use registry.clear_cache() instead')
model.pool.clear_all_caches()
class ormcache_context(ormcache):
@ -114,7 +129,10 @@ class ormcache_context(ormcache):
assert self.skiparg is None, "ormcache_context() no longer supports skiparg"
# build a string that represents function code and evaluate it
sign = signature(self.method)
args = str(sign)[1:-1]
args = ', '.join(
str(params.replace(annotation=Parameter.empty, default=Parameter.empty))
for params in sign.parameters.values()
)
cont_expr = "(context or {})" if 'context' in sign.parameters else "self._context"
keys_expr = "tuple(%s.get(k) for k in %r)" % (cont_expr, self.keys)
if self.args:
@ -124,78 +142,6 @@ class ormcache_context(ormcache):
self.key = unsafe_eval(code)
class ormcache_multi(ormcache):
""" This LRU cache decorator is a variant of :class:`ormcache`, with an
extra parameter ``multi`` that gives the name of a parameter. Upon call, the
corresponding argument is iterated on, and every value leads to a cache
entry under its own key.
"""
def __init__(self, *args, **kwargs):
super(ormcache_multi, self).__init__(*args, **kwargs)
self.multi = kwargs['multi']
def determine_key(self):
""" Determine the function that computes a cache key from arguments. """
assert self.skiparg is None, "ormcache_multi() no longer supports skiparg"
assert isinstance(self.multi, str), "ormcache_multi() parameter multi must be an argument name"
super(ormcache_multi, self).determine_key()
# key_multi computes the extra element added to the key
sign = signature(self.method)
args = str(sign)[1:-1]
code_multi = "lambda %s: %s" % (args, self.multi)
self.key_multi = unsafe_eval(code_multi)
# self.multi_pos is the position of self.multi in args
self.multi_pos = list(sign.parameters).index(self.multi)
def lookup(self, method, *args, **kwargs):
d, key0, counter = self.lru(args[0])
base_key = key0 + self.key(*args, **kwargs)
ids = self.key_multi(*args, **kwargs)
result = {}
missed = []
# first take what is available in the cache
for i in ids:
key = base_key + (i,)
try:
result[i] = d[key]
counter.hit += 1
except Exception:
counter.miss += 1
missed.append(i)
if missed:
# call the method for the ids that were not in the cache; note that
# thanks to decorator(), the multi argument will be bound and passed
# positionally in args.
args = list(args)
args[self.multi_pos] = missed
result.update(method(*args, **kwargs))
# store those new results back in the cache
for i in missed:
key = base_key + (i,)
d[key] = result[i]
return result
class dummy_cache(object):
""" Cache decorator replacement to actually do no caching. """
def __init__(self, *l, **kw):
pass
def __call__(self, fn):
fn.clear_cache = self.clear
return fn
def clear(self, *l, **kw):
pass
def log_ormcache_stats(sig=None, frame=None):
""" Log statistics of ormcache usage by database, model, and method. """
from odoo.modules.registry import Registry
@ -204,26 +150,30 @@ def log_ormcache_stats(sig=None, frame=None):
me = threading.current_thread()
me_dbname = getattr(me, 'dbname', 'n/a')
for dbname, reg in sorted(Registry.registries.d.items()):
# set logger prefix to dbname
me.dbname = dbname
entries = Counter(k[:2] for k in reg._Registry__cache.d)
def _log_ormcache_stats(cache_name, cache):
entries = Counter(k[:2] for k in cache.d)
# show entries sorted by model name, method name
for key in sorted(entries, key=lambda key: (key[0], key[1].__name__)):
model, method = key
stat = STAT[(dbname, model, method)]
_logger.info(
"%6d entries, %6d hit, %6d miss, %6d err, %4.1f%% ratio, for %s.%s",
entries[key], stat.hit, stat.miss, stat.err, stat.ratio, model, method.__name__,
"%s, %6d entries, %6d hit, %6d miss, %6d err, %4.1f%% ratio, for %s.%s",
cache_name.rjust(25), entries[key], stat.hit, stat.miss, stat.err, stat.ratio, model, method.__name__,
)
for dbname, reg in sorted(Registry.registries.d.items()):
# set logger prefix to dbname
me.dbname = dbname
for cache_name, cache in reg._Registry__caches.items():
_log_ormcache_stats(cache_name, cache)
me.dbname = me_dbname
def get_cache_key_counter(bound_method, *args, **kwargs):
""" Return the cache, key and stat counter for the given call. """
model = bound_method.__self__
ormcache = bound_method.clear_cache.__self__
ormcache = bound_method.__cache__
cache, key0, counter = ormcache.lru(model)
key = key0 + ormcache.key(model, *args, **kwargs)
return cache, key, counter

View file

@ -266,6 +266,8 @@ class configmanager(object):
help="specify the database ssl connection mode (see PostgreSQL documentation)")
group.add_option("--db_maxconn", dest="db_maxconn", type='int', my_default=64,
help="specify the maximum number of physical connections to PostgreSQL")
group.add_option("--db_maxconn_gevent", dest="db_maxconn_gevent", type='int', my_default=False,
help="specify the maximum number of physical connections to PostgreSQL specifically for the gevent worker")
group.add_option("--db-template", dest="db_template", my_default="template0",
help="specify a custom database template to create a new database")
parser.add_option_group(group)
@ -314,9 +316,6 @@ class configmanager(object):
help="Time limit (decimal value in hours) records created with a "
"TransientModel (mostly wizard) are kept in the database. Default to 1 hour.",
type="float")
group.add_option("--osv-memory-age-limit", dest="osv_memory_age_limit", my_default=False,
help="Deprecated alias to the transient-age-limit option",
type="float")
group.add_option("--max-cron-threads", dest="max_cron_threads", my_default=2,
help="Maximum number of threads processing concurrently cron jobs (default 2).",
type="int")
@ -326,8 +325,10 @@ class configmanager(object):
type="int")
group.add_option("--unaccent", dest="unaccent", my_default=False, action="store_true",
help="Try to enable the unaccent extension when creating new databases.")
group.add_option("--geoip-db", dest="geoip_database", my_default='/usr/share/GeoIP/GeoLite2-City.mmdb',
help="Absolute path to the GeoIP database file.")
group.add_option("--geoip-city-db", "--geoip-db", dest="geoip_city_db", my_default='/usr/share/GeoIP/GeoLite2-City.mmdb',
help="Absolute path to the GeoIP City database file.")
group.add_option("--geoip-country-db", dest="geoip_country_db", my_default='/usr/share/GeoIP/GeoLite2-Country.mmdb',
help="Absolute path to the GeoIP Country database file.")
parser.add_option_group(group)
if os.name == 'posix':
@ -419,10 +420,6 @@ class configmanager(object):
"The config file '%s' selected with -c/--config doesn't exist or is not readable, "\
"use -s/--save if you want to generate it"% opt.config)
die(bool(opt.osv_memory_age_limit) and bool(opt.transient_memory_age_limit),
"the osv-memory-count-limit option cannot be used with the "
"transient-age-limit option, please only use the latter.")
# place/search the config file on Win32 near the server installation
# (../etc from the server)
# if the server is run by an unprivileged user, he has to specify location of a config file where he has the rights to write,
@ -464,11 +461,11 @@ class configmanager(object):
'db_port', 'db_template', 'logfile', 'pidfile', 'smtp_port',
'email_from', 'smtp_server', 'smtp_user', 'smtp_password', 'from_filter',
'smtp_ssl_certificate_filename', 'smtp_ssl_private_key_filename',
'db_maxconn', 'import_partial', 'addons_path', 'upgrade_path', 'pre_upgrade_scripts',
'db_maxconn', 'db_maxconn_gevent', 'import_partial', 'addons_path', 'upgrade_path', 'pre_upgrade_scripts',
'syslog', 'without_demo', 'screencasts', 'screenshots',
'dbfilter', 'log_level', 'log_db',
'log_db_level', 'geoip_database', 'dev_mode', 'shell_interface',
'limit_time_worker_cron',
'log_db_level', 'geoip_city_db', 'geoip_country_db', 'dev_mode',
'shell_interface', 'limit_time_worker_cron',
]
for arg in keys:
@ -491,7 +488,7 @@ class configmanager(object):
'stop_after_init', 'without_demo', 'http_enable', 'syslog',
'list_db', 'proxy_mode',
'test_file', 'test_tags',
'osv_memory_count_limit', 'osv_memory_age_limit', 'transient_age_limit', 'max_cron_threads', 'unaccent',
'osv_memory_count_limit', 'transient_age_limit', 'max_cron_threads', 'unaccent',
'data_dir',
'server_wide_modules',
]
@ -515,6 +512,8 @@ class configmanager(object):
elif isinstance(self.options[arg], str) and self.casts[arg].type in optparse.Option.TYPE_CHECKER:
self.options[arg] = optparse.Option.TYPE_CHECKER[self.casts[arg].type](self.casts[arg], arg, self.options[arg])
ismultidb = ',' in (self.options.get('db_name') or '')
die(ismultidb and (opt.init or opt.update), "Cannot use -i/--init or -u/--update with multiple databases in the -d/--database/db_name")
self.options['root_path'] = self._normalize(os.path.join(os.path.dirname(__file__), '..'))
if not self.options['addons_path'] or self.options['addons_path']=='None':
default_addons = []
@ -565,7 +564,7 @@ class configmanager(object):
self.save()
# normalize path options
for key in ['data_dir', 'logfile', 'pidfile', 'test_file', 'screencasts', 'screenshots', 'pg_path', 'translate_out', 'translate_in', 'geoip_database']:
for key in ['data_dir', 'logfile', 'pidfile', 'test_file', 'screencasts', 'screenshots', 'pg_path', 'translate_out', 'translate_in', 'geoip_city_db', 'geoip_country_db']:
self.options[key] = self._normalize(self.options[key])
conf.addons_paths = self.options['addons_path'].split(',')
@ -576,19 +575,48 @@ class configmanager(object):
return opt
def _warn_deprecated_options(self):
if self.options['osv_memory_age_limit']:
warnings.warn(
"The osv-memory-age-limit is a deprecated alias to "
"the transient-age-limit option, please use the latter.",
DeprecationWarning)
self.options['transient_age_limit'] = self.options.pop('osv_memory_age_limit')
if self.options['longpolling_port']:
if self.options.get('longpolling_port', 0):
warnings.warn(
"The longpolling-port is a deprecated alias to "
"the gevent-port option, please use the latter.",
DeprecationWarning)
self.options['gevent_port'] = self.options.pop('longpolling_port')
for old_option_name, new_option_name in [
('geoip_database', 'geoip_city_db'),
('osv_memory_age_limit', 'transient_age_limit')
]:
deprecated_value = self.options.pop(old_option_name, None)
if deprecated_value:
default_value = self.casts[new_option_name].my_default
current_value = self.options[new_option_name]
if deprecated_value in (current_value, default_value):
# Surely this is from a --save that was run in a
# prior version. There is no point in emitting a
# warning because: (1) it holds the same value as
# the correct option, and (2) it is going to be
# automatically removed on the next --save anyway.
pass
elif current_value == default_value:
# deprecated_value != current_value == default_value
# assume the new option was not set
self.options[new_option_name] = deprecated_value
warnings.warn(
f"The {old_option_name!r} option found in the "
"configuration file is a deprecated alias to "
f"{new_option_name!r}, please use the latter.",
DeprecationWarning)
else:
# deprecated_value != current_value != default_value
self.parser.error(
f"The two options {old_option_name!r} "
"(found in the configuration file but "
f"deprecated) and {new_option_name!r} are set "
"to different values. Please remove the first "
"one and make sure the second is correct."
)
def _is_addons_path(self, path):
from odoo.modules.module import MANIFEST_NAMES
for f in os.listdir(path):

View file

@ -28,7 +28,7 @@ except ImportError:
import odoo
from . import pycompat
from .config import config
from .misc import file_open, unquote, ustr, SKIPPED_ELEMENT_TYPES
from .misc import file_open, file_path, SKIPPED_ELEMENT_TYPES
from .translate import _
from odoo import SUPERUSER_ID, api
from odoo.exceptions import ValidationError
@ -155,9 +155,10 @@ def _eval_xml(self, node, env):
# after that, only text content makes sense
data = pycompat.to_text(data)
if t == 'file':
from ..modules import module
path = data.strip()
if not module.get_module_resource(self.module, path):
try:
file_path(os.path.join(self.module, path))
except FileNotFoundError:
raise IOError("No such file or directory: '%s' in %s" % (
path, self.module))
return '%s,%s' % (self.module, path)
@ -229,7 +230,7 @@ class xml_import(object):
**safe_eval(context, {
'ref': self.id_get,
**(eval_context or {})
})
}),
}
)
return self.env
@ -272,160 +273,12 @@ form: module.record_id""" % (xml_id,)
if records:
records.unlink()
def _tag_report(self, rec):
res = {}
for dest,f in (('name','string'),('model','model'),('report_name','name')):
res[dest] = rec.get(f)
assert res[dest], "Attribute %s of report is empty !" % (f,)
for field, dest in (('attachment', 'attachment'),
('attachment_use', 'attachment_use'),
('usage', 'usage'),
('file', 'report_file'),
('report_type', 'report_type'),
('parser', 'parser'),
('print_report_name', 'print_report_name'),
):
if rec.get(field):
res[dest] = rec.get(field)
if rec.get('auto'):
res['auto'] = safe_eval(rec.get('auto','False'))
if rec.get('header'):
res['header'] = safe_eval(rec.get('header','False'))
res['multi'] = rec.get('multi') and safe_eval(rec.get('multi','False'))
xml_id = rec.get('id','')
self._test_xml_id(xml_id)
warnings.warn(f"The <report> tag is deprecated, use a <record> tag for {xml_id!r}.", DeprecationWarning)
if rec.get('groups'):
g_names = rec.get('groups','').split(',')
groups_value = []
for group in g_names:
if group.startswith('-'):
group_id = self.id_get(group[1:])
groups_value.append(odoo.Command.unlink(group_id))
else:
group_id = self.id_get(group)
groups_value.append(odoo.Command.link(group_id))
res['groups_id'] = groups_value
if rec.get('paperformat'):
pf_name = rec.get('paperformat')
pf_id = self.id_get(pf_name)
res['paperformat_id'] = pf_id
xid = self.make_xml_id(xml_id)
data = dict(xml_id=xid, values=res, noupdate=self.noupdate)
report = self.env['ir.actions.report']._load_records([data], self.mode == 'update')
self.idref[xml_id] = report.id
if not rec.get('menu') or safe_eval(rec.get('menu','False')):
report.create_action()
elif self.mode=='update' and safe_eval(rec.get('menu','False'))==False:
# Special check for report having attribute menu=False on update
report.unlink_action()
return report.id
def _tag_function(self, rec):
if self.noupdate and self.mode != 'init':
return
env = self.get_env(rec)
_eval_xml(self, rec, env)
def _tag_act_window(self, rec):
name = rec.get('name')
xml_id = rec.get('id','')
self._test_xml_id(xml_id)
warnings.warn(f"The <act_window> tag is deprecated, use a <record> for {xml_id!r}.", DeprecationWarning)
view_id = False
if rec.get('view_id'):
view_id = self.id_get(rec.get('view_id'))
domain = rec.get('domain') or '[]'
res_model = rec.get('res_model')
binding_model = rec.get('binding_model')
view_mode = rec.get('view_mode') or 'tree,form'
usage = rec.get('usage')
limit = rec.get('limit')
uid = self.env.user.id
# Act_window's 'domain' and 'context' contain mostly literals
# but they can also refer to the variables provided below
# in eval_context, so we need to eval() them before storing.
# Among the context variables, 'active_id' refers to
# the currently selected items in a list view, and only
# takes meaning at runtime on the client side. For this
# reason it must remain a bare variable in domain and context,
# even after eval() at server-side. We use the special 'unquote'
# class to achieve this effect: a string which has itself, unquoted,
# as representation.
active_id = unquote("active_id")
active_ids = unquote("active_ids")
active_model = unquote("active_model")
# Include all locals() in eval_context, for backwards compatibility
eval_context = {
'name': name,
'xml_id': xml_id,
'type': 'ir.actions.act_window',
'view_id': view_id,
'domain': domain,
'res_model': res_model,
'src_model': binding_model,
'view_mode': view_mode,
'usage': usage,
'limit': limit,
'uid': uid,
'active_id': active_id,
'active_ids': active_ids,
'active_model': active_model,
}
context = self.get_env(rec, eval_context).context
try:
domain = safe_eval(domain, eval_context)
except (ValueError, NameError):
# Some domains contain references that are only valid at runtime at
# client-side, so in that case we keep the original domain string
# as it is. We also log it, just in case.
_logger.debug('Domain value (%s) for element with id "%s" does not parse '\
'at server-side, keeping original string, in case it\'s meant for client side only',
domain, xml_id or 'n/a', exc_info=True)
res = {
'name': name,
'type': 'ir.actions.act_window',
'view_id': view_id,
'domain': domain,
'context': context,
'res_model': res_model,
'view_mode': view_mode,
'usage': usage,
'limit': limit,
}
if rec.get('groups'):
g_names = rec.get('groups','').split(',')
groups_value = []
for group in g_names:
if group.startswith('-'):
group_id = self.id_get(group[1:])
groups_value.append(odoo.Command.unlink(group_id))
else:
group_id = self.id_get(group)
groups_value.append(odoo.Command.link(group_id))
res['groups_id'] = groups_value
if rec.get('target'):
res['target'] = rec.get('target','')
if binding_model:
res['binding_model_id'] = self.env['ir.model']._get(binding_model).id
res['binding_type'] = rec.get('binding_type') or 'action'
views = rec.get('binding_views')
if views is not None:
res['binding_view_types'] = views
xid = self.make_xml_id(xml_id)
data = dict(xml_id=xid, values=res, noupdate=self.noupdate)
self.env['ir.actions.act_window']._load_records([data], self.mode == 'update')
def _tag_menuitem(self, rec, parent=None):
rec_id = rec.attrib["id"]
self._test_xml_id(rec_id)
@ -556,7 +409,7 @@ form: module.record_id""" % (xml_id,)
if f_search:
idref2 = _get_idref(self, env, f_model, self.idref)
q = safe_eval(f_search, idref2)
assert f_model, 'Define an attribute model="..." in your .XML file !'
assert f_model, 'Define an attribute model="..." in your .XML file!'
# browse the objects searched
s = env[f_model].search(q)
# column definitions of the "local" object
@ -602,6 +455,10 @@ form: module.record_id""" % (xml_id,)
res[f_name] = f_val
if extra_vals:
res.update(extra_vals)
if 'sequence' not in res and 'sequence' in model._fields:
sequence = self.next_sequence()
if sequence:
res['sequence'] = sequence
data = dict(xml_id=xid, values=res, noupdate=self.noupdate)
record = model._load_records([data], self.mode == 'update')
@ -702,6 +559,7 @@ form: module.record_id""" % (xml_id,)
self.envs.append(self.get_env(el))
self._noupdate.append(nodeattr2bool(el, 'noupdate', self.noupdate))
self._sequences.append(0 if nodeattr2bool(el, 'auto_sequence', False) else None)
try:
f(rec)
except ParseError:
@ -724,6 +582,7 @@ form: module.record_id""" % (xml_id,)
finally:
self._noupdate.pop()
self.envs.pop()
self._sequences.pop()
@property
def env(self):
@ -733,12 +592,19 @@ form: module.record_id""" % (xml_id,)
def noupdate(self):
return self._noupdate[-1]
def __init__(self, cr, module, idref, mode, noupdate=False, xml_filename=None):
def next_sequence(self):
value = self._sequences[-1]
if value is not None:
value = self._sequences[-1] = value + 10
return value
def __init__(self, env, module, idref, mode, noupdate=False, xml_filename=None):
self.mode = mode
self.module = module
self.envs = [odoo.api.Environment(cr, SUPERUSER_ID, {})]
self.envs = [env(context=dict(env.context, lang=None))]
self.idref = {} if idref is None else idref
self._noupdate = [noupdate]
self._sequences = []
self.xml_filename = xml_filename
self._tags = {
'record': self._tag_record,
@ -746,8 +612,6 @@ form: module.record_id""" % (xml_id,)
'function': self._tag_function,
'menuitem': self._tag_menuitem,
'template': self._tag_template,
'report': self._tag_report,
'act_window': self._tag_act_window,
**dict.fromkeys(self.DATA_ROOTS, self._tag_root)
}
@ -757,32 +621,33 @@ form: module.record_id""" % (xml_id,)
self._tag_root(de)
DATA_ROOTS = ['odoo', 'data', 'openerp']
def convert_file(cr, module, filename, idref, mode='update', noupdate=False, kind=None, pathname=None):
def convert_file(env, module, filename, idref, mode='update', noupdate=False, kind=None, pathname=None):
if pathname is None:
pathname = os.path.join(module, filename)
ext = os.path.splitext(filename)[1].lower()
with file_open(pathname, 'rb') as fp:
if ext == '.csv':
convert_csv_import(cr, module, pathname, fp.read(), idref, mode, noupdate)
convert_csv_import(env, module, pathname, fp.read(), idref, mode, noupdate)
elif ext == '.sql':
convert_sql_import(cr, fp)
convert_sql_import(env, fp)
elif ext == '.xml':
convert_xml_import(cr, module, fp, idref, mode, noupdate)
convert_xml_import(env, module, fp, idref, mode, noupdate)
elif ext == '.js':
pass # .js files are valid but ignored here.
else:
raise ValueError("Can't load unknown file type %s.", filename)
def convert_sql_import(cr, fp):
cr.execute(fp.read()) # pylint: disable=sql-injection
def convert_sql_import(env, fp):
env.cr.execute(fp.read()) # pylint: disable=sql-injection
def convert_csv_import(cr, module, fname, csvcontent, idref=None, mode='init',
def convert_csv_import(env, module, fname, csvcontent, idref=None, mode='init',
noupdate=False):
'''Import csv file :
quote: "
delimiter: ,
encoding: utf-8'''
env = env(context=dict(env.context, lang=None))
filename, _ext = os.path.splitext(os.path.basename(fname))
model = filename.split('-')[0]
reader = pycompat.csv_reader(io.BytesIO(csvcontent), quotechar='"', delimiter=',')
@ -805,21 +670,20 @@ def convert_csv_import(cr, module, fname, csvcontent, idref=None, mode='init',
'install_filename': fname,
'noupdate': noupdate,
}
env = odoo.api.Environment(cr, SUPERUSER_ID, context)
result = env[model].load(fields, datas)
result = env[model].with_context(**context).load(fields, datas)
if any(msg['type'] == 'error' for msg in result['messages']):
# Report failed import and abort module install
warning_msg = "\n".join(msg['message'] for msg in result['messages'])
raise Exception(_('Module loading %s failed: file %s could not be processed:\n %s') % (module, fname, warning_msg))
def convert_xml_import(cr, module, xmlfile, idref=None, mode='init', noupdate=False, report=None):
def convert_xml_import(env, module, xmlfile, idref=None, mode='init', noupdate=False, report=None):
doc = etree.parse(xmlfile)
schema = os.path.join(config['root_path'], 'import_xml.rng')
relaxng = etree.RelaxNG(etree.parse(schema))
try:
relaxng.assert_(doc)
except Exception:
_logger.exception("The XML file '%s' does not fit the required schema !", xmlfile.name)
_logger.exception("The XML file '%s' does not fit the required schema!", xmlfile.name)
if jingtrang:
p = subprocess.run(['pyjing', schema, xmlfile.name], stdout=subprocess.PIPE)
_logger.warning(p.stdout.decode())
@ -833,5 +697,5 @@ def convert_xml_import(cr, module, xmlfile, idref=None, mode='init', noupdate=Fa
xml_filename = xmlfile
else:
xml_filename = xmlfile.name
obj = xml_import(cr, module, idref, mode, noupdate=noupdate, xml_filename=xml_filename)
obj = xml_import(env, module, idref, mode, noupdate=noupdate, xml_filename=xml_filename)
obj.parse(doc.getroot())

View file

@ -223,25 +223,35 @@ def json_default(obj):
def date_range(start, end, step=relativedelta(months=1)):
"""Date range generator with a step interval.
:param datetime start: beginning date of the range.
:param datetime end: ending date of the range.
:param date | datetime start: beginning date of the range.
:param date | datetime end: ending date of the range.
:param relativedelta step: interval of the range.
:return: a range of datetime from start to end.
:rtype: Iterator[datetime]
"""
if isinstance(start, datetime) and isinstance(end, datetime):
are_naive = start.tzinfo is None and end.tzinfo is None
are_utc = start.tzinfo == pytz.utc and end.tzinfo == pytz.utc
are_naive = start.tzinfo is None and end.tzinfo is None
are_utc = start.tzinfo == pytz.utc and end.tzinfo == pytz.utc
# Cases with miscellaneous timezones are more complex because of DST.
are_others = start.tzinfo and end.tzinfo and not are_utc
# Cases with miscellaneous timezones are more complex because of DST.
are_others = start.tzinfo and end.tzinfo and not are_utc
if are_others:
if start.tzinfo.zone != end.tzinfo.zone:
if are_others and start.tzinfo.zone != end.tzinfo.zone:
raise ValueError("Timezones of start argument and end argument seem inconsistent")
if not are_naive and not are_utc and not are_others:
raise ValueError("Timezones of start argument and end argument mismatch")
if not are_naive and not are_utc and not are_others:
raise ValueError("Timezones of start argument and end argument mismatch")
dt = start.replace(tzinfo=None)
end_dt = end.replace(tzinfo=None)
post_process = start.tzinfo.localize if start.tzinfo else lambda dt: dt
elif isinstance(start, date) and isinstance(end, date):
dt, end_dt = start, end
post_process = lambda dt: dt
else:
raise ValueError("start/end should be both date or both datetime type")
if start > end:
raise ValueError("start > end, start date must be before end")
@ -249,15 +259,8 @@ def date_range(start, end, step=relativedelta(months=1)):
if start == start + step:
raise ValueError("Looks like step is null")
if start.tzinfo:
localize = start.tzinfo.localize
else:
localize = lambda dt: dt
dt = start.replace(tzinfo=None)
end = end.replace(tzinfo=None)
while dt <= end:
yield localize(dt)
while dt <= end_dt:
yield post_process(dt)
dt = dt + step

View file

@ -45,10 +45,13 @@ def float_round(value, precision_digits=None, precision_rounding=None, rounding_
:param float precision_rounding: decimal number representing the minimum
non-zero value at the desired precision (for example, 0.01 for a
2-digit precision).
:param rounding_method: the rounding method used: 'HALF-UP', 'UP' or 'DOWN',
the first one rounding up to the closest number with the rule that
number>=0.5 is rounded up to 1, the second always rounding up and the
latest one always rounding down.
:param rounding_method: the rounding method used:
- 'HALF-UP' will round to the closest number with ties going away from zero.
- 'HALF-DOWN' will round to the closest number with ties going towards zero.
- 'HALF_EVEN' will round to the closest number with ties going to the closest
even number.
- 'UP' will always round away from 0.
- 'DOWN' will always round towards 0.
:return: rounded float
"""
rounding_factor = _float_check_precision(precision_digits=precision_digits,
@ -90,6 +93,17 @@ def float_round(value, precision_digits=None, precision_rounding=None, rounding_
normalized_value += sign*epsilon
rounded_value = math.floor(abs(normalized_value)) * sign
# TIE-BREAKING: HALF-EVEN
# We want to apply HALF-EVEN tie-breaking rules, i.e. 0.5 rounds towards closest even number.
elif rounding_method == 'HALF-EVEN':
rounded_value = math.copysign(builtins.round(normalized_value), normalized_value)
# TIE-BREAKING: HALF-DOWN
# We want to apply HALF-DOWN tie-breaking rules, i.e. 0.5 rounds towards 0.
elif rounding_method == 'HALF-DOWN':
normalized_value -= math.copysign(epsilon, normalized_value)
rounded_value = round(normalized_value)
# TIE-BREAKING: HALF-UP (for normal rounding)
# We want to apply HALF-UP tie-breaking rules, i.e. 0.5 rounds away from 0.
else:

View file

@ -1,66 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os.path
try:
import GeoIP # Legacy
except ImportError:
GeoIP = None
try:
import geoip2
import geoip2.database
except ImportError:
geoip2 = None
class GeoIPResolver(object):
def __init__(self, fname):
self.fname = fname
try:
self._db = geoip2.database.Reader(fname)
self.version = 2
except Exception:
try:
self._db = GeoIP.open(fname, GeoIP.GEOIP_STANDARD)
self.version = 1
assert self._db.database_info is not None
except Exception:
raise ValueError('Invalid GeoIP database: %r' % fname)
def __del__(self):
if self.version == 2:
self._db.close()
@classmethod
def open(cls, fname):
if not GeoIP and not geoip2:
return None
if not os.path.exists(fname):
return None
return GeoIPResolver(fname)
def resolve(self, ip):
if self.version == 1:
return self._db.record_by_addr(ip) or {}
elif self.version == 2:
try:
r = self._db.city(ip)
except (ValueError, geoip2.errors.AddressNotFoundError):
return {}
# Compatibility with Legacy database.
# Some ips cannot be located to a specific country. Legacy DB used to locate them in
# continent instead of country. Do the same to not change behavior of existing code.
country, attr = (r.country, 'iso_code') if r.country.geoname_id else (r.continent, 'code')
return {
'city': r.city.name,
'country_code': getattr(country, attr),
'country_name': country.name,
'latitude': r.location.latitude,
'longitude': r.location.longitude,
'region': r.subdivisions[0].iso_code if r.subdivisions else None,
'time_zone': r.location.time_zone,
}
# compat
def record_by_addr(self, addr):
return self.resolve(addr)

View file

@ -15,6 +15,7 @@ except ImportError:
from random import randrange
from odoo.exceptions import UserError
from odoo.tools.misc import DotDict
from odoo.tools.translate import _
@ -29,6 +30,7 @@ FILETYPE_BASE64_MAGICWORD = {
b'R': 'gif',
b'i': 'png',
b'P': 'svg+xml',
b'U': 'webp',
}
EXIF_TAG_ORIENTATION = 0x112
@ -73,8 +75,8 @@ class ImageProcess():
self.source = source or False
self.operationsCount = 0
if not source or source[:1] == b'<':
# don't process empty source or SVG
if not source or source[:1] == b'<' or (source[0:4] == b'RIFF' and source[8:15] == b'WEBPVP8'):
# don't process empty source or SVG or WEBP
self.image = False
else:
try:
@ -149,7 +151,12 @@ class ImageProcess():
if output_image.mode not in ["1", "L", "P", "RGB", "RGBA"] or (output_format == 'JPEG' and output_image.mode == 'RGBA'):
output_image = output_image.convert("RGB")
return image_apply_opt(output_image, **opt)
output_bytes = image_apply_opt(output_image, **opt)
if len(output_bytes) >= len(self.source) and self.original_format == output_format and not self.operationsCount:
# Format has not changed and image content is unchanged but the
# reached binary is bigger: rather use the original.
return self.source
return output_bytes
def resize(self, max_width=0, max_height=0):
"""Resize the image.
@ -435,6 +442,43 @@ def image_to_base64(image, output_format, **params):
return base64.b64encode(stream)
def get_webp_size(source):
    """Return the pixel size of the provided WEBP binary ``source``.

    Handles the VP8 (simple lossy), VP8X (extended) and VP8L (lossless)
    chunk layouts; any other variant yields None.  Truncated payloads also
    yield None instead of raising, matching the "not supported" contract.
    See https://developers.google.com/speed/webp/docs/riff_container.

    :param source: binary source
    :return: (width, height) tuple, or None if not supported
    :raises UserError: if ``source`` is not a WEBP RIFF container at all
    """
    if not (source[0:4] == b'RIFF' and source[8:15] == b'WEBPVP8'):
        raise UserError(_("This file is not a webp file."))
    vp8_type = source[15]
    if vp8_type == 0x20:  # 0x20 = ' ' -> simple lossy VP8
        # Sizes as little-endian 16-bit values at offset 26.
        if len(source) < 30:
            return None
        width_low, width_high, height_low, height_high = source[26:30]
        width = (width_high << 8) + width_low
        height = (height_high << 8) + height_low
        return (width, height)
    elif vp8_type == 0x58:  # 0x58 = 'X' -> extended VP8X
        # Sizes (minus one) as little-endian 24-bit values at offset 24.
        if len(source) < 30:
            return None
        width_low, width_medium, width_high, height_low, height_medium, height_high = source[24:30]
        width = 1 + (width_high << 16) + (width_medium << 8) + width_low
        height = 1 + (height_high << 16) + (height_medium << 8) + height_low
        return (width, height)
    elif vp8_type == 0x4C and len(source) >= 25 and source[20] == 0x2F:  # 0x4C = 'L' -> lossless VP8L
        # Sizes (minus one) packed as 14-bit fields starting at offset 21.
        # E.g. [@20] 2F ab cd ef gh
        # - width  = 1 + ((cd & 0x3F) << 8 | ab): the two high bits of cd
        #   belong to the height field, not the width
        # - height = 1 + ((gh & 0xF) << 10 | ef << 2 | cd >> 6)
        ab, cd, ef, gh = source[21:25]
        width = 1 + ((cd & 0x3F) << 8) + ab
        height = 1 + ((gh & 0xF) << 10) + (ef << 2) + (cd >> 6)
        return (width, height)
    return None
def is_image_size_above(base64_source_1, base64_source_2):
"""Return whether or not the size of the given image `base64_source_1` is
above the size of the given image `base64_source_2`.
@ -444,8 +488,21 @@ def is_image_size_above(base64_source_1, base64_source_2):
if base64_source_1[:1] in (b'P', 'P') or base64_source_2[:1] in (b'P', 'P'):
# False for SVG
return False
image_source = image_fix_orientation(base64_to_image(base64_source_1))
image_target = image_fix_orientation(base64_to_image(base64_source_2))
def get_image_size(base64_source):
source = base64.b64decode(base64_source)
if (source[0:4] == b'RIFF' and source[8:15] == b'WEBPVP8'):
size = get_webp_size(source)
if size:
return DotDict({'width': size[0], 'height': size[0]})
else:
# False for unknown WEBP format
return False
else:
return image_fix_orientation(binary_to_image(source))
image_source = get_image_size(base64_source_1)
image_target = get_image_size(base64_source_2)
return image_source.width > image_target.width or image_source.height > image_target.height

View file

@ -1,6 +1,6 @@
"""
This code is what let us use ES6-style modules in odoo.
Classic Odoo modules are composed of a top-level :samp:`odoo.define({name},{body_function})` call.
Classic Odoo modules are composed of a top-level :samp:`odoo.define({name},{dependencies},{body_function})` call.
This processor will take files starting with an `@odoo-module` annotation (in a comment) and convert them to classic modules.
If any file has the ``/** odoo-module */`` on top of it, it will get processed by this class.
It performs several operations to get from ES6 syntax to the usual odoo one with minimal changes.
@ -15,6 +15,8 @@ import re
import logging
from functools import partial
from odoo.tools.misc import OrderedSet
_logger = logging.getLogger(__name__)
def transpile_javascript(url, content):
@ -27,7 +29,7 @@ def transpile_javascript(url, content):
"""
module_path = url_to_module_path(url)
legacy_odoo_define = get_aliased_odoo_define_content(module_path, content)
dependencies = OrderedSet()
# The order of the operations does sometimes matter.
steps = [
convert_legacy_default_import,
@ -39,14 +41,15 @@ def transpile_javascript(url, content):
convert_unnamed_relative_import,
convert_from_export,
convert_star_from_export,
partial(convert_relative_require, url),
remove_index,
partial(convert_relative_require, url, dependencies),
convert_export_function,
convert_export_class,
convert_variable_export,
convert_object_export,
convert_default_export,
partial(wrap_with_odoo_define, module_path),
partial(wrap_with_qunit_module, url),
partial(wrap_with_odoo_define, module_path, dependencies),
]
for s in steps:
content = s(content)
@ -65,7 +68,7 @@ URL_RE = re.compile(r"""
def url_to_module_path(url):
"""
Odoo modules each have a name. (odoo.define("<the name>", async function (require) {...});
Odoo modules each have a name. (odoo.define("<the name>", [<dependencies>], function (require) {...});
It is used in to be required later. (const { something } = require("<the name>").
The transpiler transforms the url of the file in the project to this name.
It takes the module name and add a @ on the start of it, and map it to be the source of the static/src (or
@ -94,13 +97,23 @@ def url_to_module_path(url):
else:
raise ValueError("The js file %r must be in the folder '/static/src' or '/static/lib' or '/static/test'" % url)
def wrap_with_qunit_module(url, content):
    """Wrap a test file's source in a ``QUnit.module(...)`` block.

    Only files whose url contains ``tests`` and that actually declare
    QUnit tests are wrapped; all other content is returned untouched.
    """
    if "tests" not in url or not re.search(r'QUnit\.(test|debug|only)\(', content):
        return content
    module_name = URL_RE.match(url)["module"]
    return f"""QUnit.module("{module_name}", function() {{{content}}});"""
def wrap_with_odoo_define(module_path, content):
def wrap_with_odoo_define(module_path, dependencies, content):
"""
Wraps the current content (source code) with the odoo.define call.
It adds as a second argument the list of dependencies.
Should logically be called once all other operations have been performed.
"""
return f"""odoo.define({module_path!r}, async function (require) {{
return f"""odoo.define({module_path!r}, {list(dependencies)}, function (require) {{
'use strict';
let __exports = {{}};
{content}
@ -503,14 +516,15 @@ def convert_default_and_named_import(content):
RELATIVE_REQUIRE_RE = re.compile(r"""
require\((?P<quote>["'`])([^@"'`]+)(?P=quote)\) # require("some/path")
""", re.VERBOSE)
^[^/*\n]*require\((?P<quote>[\"'`])([^\"'`]+)(?P=quote)\) # require("some/path")
""", re.MULTILINE | re.VERBOSE)
def convert_relative_require(url, content):
def convert_relative_require(url, dependencies, content):
"""
Convert the relative path contained in a 'require()'
to the new path system (@module/path)
to the new path system (@module/path).
Adds all modules path to dependencies.
.. code-block:: javascript
// Relative path:
@ -527,10 +541,13 @@ def convert_relative_require(url, content):
"""
new_content = content
for quote, path in RELATIVE_REQUIRE_RE.findall(new_content):
module_path = path
if path.startswith(".") and "/" in path:
pattern = rf"require\({quote}{path}{quote}\)"
repl = f'require("{relative_path_to_module_path(url, path)}")'
module_path = relative_path_to_module_path(url, path)
repl = f'require("{module_path}")'
new_content = re.sub(pattern, repl, new_content)
dependencies.add(module_path)
return new_content
@ -677,7 +694,7 @@ def get_aliased_odoo_define_content(module_path, content):
we have a problem when we will have converted to module to ES6: its new name will be more like
"web/chrome/abstract_action". So the require would fail !
So we add a second small modules, an alias, as such:
> odoo.define("web/chrome/abstract_action", async function(require) {
> odoo.define("web/chrome/abstract_action", ['web.AbstractAction'], function (require) {
> return require('web.AbstractAction')[Symbol.for("default")];
> });
@ -707,13 +724,13 @@ def get_aliased_odoo_define_content(module_path, content):
alias = matchobj['alias']
if alias:
if matchobj['default']:
return """\nodoo.define(`%s`, async function(require) {
return """\nodoo.define(`%s`, ['%s'], function (require) {
return require('%s');
});\n""" % (alias, module_path)
});\n""" % (alias, module_path, module_path)
else:
return """\nodoo.define(`%s`, async function(require) {
return """\nodoo.define(`%s`, ['%s'], function (require) {
return require('%s')[Symbol.for("default")];
});\n""" % (alias, module_path)
});\n""" % (alias, module_path, module_path)
def convert_as(val):

View file

@ -45,12 +45,12 @@ safe_attrs = defs.safe_attrs | frozenset(
['style',
'data-o-mail-quote', 'data-o-mail-quote-node', # quote detection
'data-oe-model', 'data-oe-id', 'data-oe-field', 'data-oe-type', 'data-oe-expression', 'data-oe-translation-initial-sha', 'data-oe-nodeid',
'data-last-history-steps', 'data-width', 'data-height', 'data-scale-x', 'data-scale-y', 'data-x', 'data-y',
'data-last-history-steps', 'data-oe-protected', 'data-oe-transient-content', 'data-width', 'data-height', 'data-scale-x', 'data-scale-y', 'data-x', 'data-y',
'data-publish', 'data-id', 'data-res_id', 'data-interval', 'data-member_id', 'data-scroll-background-ratio', 'data-view-id',
'data-class', 'data-mimetype', 'data-original-src', 'data-original-id', 'data-gl-filter', 'data-quality', 'data-resize-width',
'data-shape', 'data-shape-colors', 'data-file-name', 'data-original-mimetype',
'data-oe-protected', # editor
'data-behavior-props', 'data-prop-name', # knowledge commands
'data-mimetype-before-conversion',
])
SANITIZE_TAGS = {
# allow new semantic HTML5 tags
@ -455,7 +455,7 @@ def html2plaintext(html, body_id=None, encoding='utf-8'):
html = html.replace('&gt;', '>')
html = html.replace('&lt;', '<')
html = html.replace('&amp;', '&')
html = html.replace('&nbsp;', u'\N{NO-BREAK SPACE}')
html = html.replace('&nbsp;', '\N{NO-BREAK SPACE}')
# strip all lines
html = '\n'.join([x.strip() for x in html.splitlines()])
@ -874,6 +874,33 @@ def encapsulate_email(old_email, new_email):
new_email_split[0][1],
))
def parse_contact_from_email(text):
    """Extract a (name, email) pair from a free-form contact string, able to
    populate records like partners, leads, ...

    Supported syntax:

      * Raoul <raoul@grosbedon.fr>
      * "Raoul le Grand" <raoul@grosbedon.fr>
      * Raoul raoul@grosbedon.fr (strange fault tolerant support from
        df40926d2a57c101a3e2d221ecfd08fbb4fea30e now supported directly
        in 'email_split_tuples')

    When no address can be extracted, the whole text becomes the name.

    :return: name, email (normalized if possible)
    """
    if not text or not text.strip():
        return '', ''
    pairs = email_split_tuples(text)
    name, email = pairs[0] if pairs else ('', '')
    if not email:
        # Nothing parsable as an address: keep the raw text as the name.
        return text, ''
    # Prefer the normalized form, but keep the raw address if
    # normalization fails.
    return name, email_normalize(email, strict=False) or email
def unfold_references(msg_references):
""" As it declared in [RFC2822] long header bodies can be "folded" using
CRLF+WSP. Some mail clients split References header body which contains

View file

@ -110,6 +110,10 @@ def _check_svg(data):
if b'<svg' in data and b'/svg' in data:
return 'image/svg+xml'
def _check_webp(data):
    """Return 'image/webp' when the RIFF payload carries a WEBP/VP8 chunk."""
    return 'image/webp' if data[8:15] == b'WEBPVP8' else None
# for "master" formats with many subformats, discriminants is a list of
# functions, tried in order and the first non-falsy value returned is the
@ -128,6 +132,9 @@ _mime_mappings = (
_check_svg,
]),
_Entry('image/x-icon', [b'\x00\x00\x01\x00'], []),
_Entry('image/webp', [b'RIFF'], [
_check_webp,
]),
# OLECF files in general (Word, Excel, PPT, default to word because why not?)
_Entry('application/msword', [b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1', b'\x0D\x44\x4F\x43'], [
_check_olecf

View file

@ -13,6 +13,7 @@ import hmac as hmac_lib
import hashlib
import io
import itertools
import logging
import os
import pickle as pickle_
import re
@ -30,7 +31,7 @@ from collections import OrderedDict
from collections.abc import Iterable, Mapping, MutableMapping, MutableSet
from contextlib import ContextDecorator, contextmanager
from difflib import HtmlDiff
from functools import wraps
from functools import reduce, wraps
from itertools import islice, groupby as itergroupby
from operator import itemgetter
@ -47,6 +48,7 @@ import odoo.addons
# get_encodings, ustr and exception_to_unicode were originally from tools.misc.
# There are moved to loglevels until we refactor tools.
from odoo.loglevels import get_encodings, ustr, exception_to_unicode # noqa
from odoo.tools.float_utils import float_round
from . import pycompat
from .cache import *
from .config import config
@ -430,10 +432,9 @@ def scan_languages():
:returns: a list of (lang_code, lang_name) pairs
:rtype: [(str, unicode)]
"""
csvpath = odoo.modules.module.get_resource_path('base', 'data', 'res.lang.csv')
try:
# read (code, name) from languages in base/data/res.lang.csv
with open(csvpath, 'rb') as csvfile:
with file_open('base/data/res.lang.csv', 'rb') as csvfile:
reader = pycompat.csv_reader(csvfile, delimiter=',', quotechar='"')
fields = next(reader)
code_index = fields.index("code")
@ -443,7 +444,7 @@ def scan_languages():
for row in reader
]
except Exception:
_logger.error("Could not read %s", csvpath)
_logger.error("Could not read res.lang.csv")
result = []
return sorted(result or [('en_US', u'English')], key=itemgetter(1))
@ -943,7 +944,7 @@ class ConstantMapping(Mapping):
return self._value
def dumpstacks(sig=None, frame=None, thread_idents=None):
def dumpstacks(sig=None, frame=None, thread_idents=None, log_level=logging.INFO):
""" Signal handler: dump a stack trace for each existing thread or given
thread(s) specified through the ``thread_idents`` sequence.
"""
@ -960,16 +961,29 @@ def dumpstacks(sig=None, frame=None, thread_idents=None):
threads_info = {th.ident: {'repr': repr(th),
'uid': getattr(th, 'uid', 'n/a'),
'dbname': getattr(th, 'dbname', 'n/a'),
'url': getattr(th, 'url', 'n/a')}
'url': getattr(th, 'url', 'n/a'),
'query_count': getattr(th, 'query_count', 'n/a'),
'query_time': getattr(th, 'query_time', None),
'perf_t0': getattr(th, 'perf_t0', None)}
for th in threading.enumerate()}
for threadId, stack in sys._current_frames().items():
if not thread_idents or threadId in thread_idents:
thread_info = threads_info.get(threadId, {})
code.append("\n# Thread: %s (db:%s) (uid:%s) (url:%s)" %
query_time = thread_info.get('query_time')
perf_t0 = thread_info.get('perf_t0')
remaining_time = None
if query_time is not None and perf_t0:
remaining_time = '%.3f' % (time.time() - perf_t0 - query_time)
query_time = '%.3f' % query_time
# qc:query_count qt:query_time pt:python_time (aka remaining time)
code.append("\n# Thread: %s (db:%s) (uid:%s) (url:%s) (qc:%s qt:%s pt:%s)" %
(thread_info.get('repr', threadId),
thread_info.get('dbname', 'n/a'),
thread_info.get('uid', 'n/a'),
thread_info.get('url', 'n/a')))
thread_info.get('url', 'n/a'),
thread_info.get('query_count', 'n/a'),
query_time or 'n/a',
remaining_time or 'n/a'))
for line in extract_stack(stack):
code.append(line)
@ -984,7 +998,7 @@ def dumpstacks(sig=None, frame=None, thread_idents=None):
for line in extract_stack(ob.gr_frame):
code.append(line)
_logger.info("\n".join(code))
_logger.log(log_level, "\n".join(code))
def freehash(arg):
try:
@ -1134,6 +1148,9 @@ class OrderedSet(MutableSet):
def __repr__(self):
return f'{type(self).__name__}({list(self)!r})'
def intersection(self, *others):
return reduce(OrderedSet.__and__, others, self)
class LastOrderedSet(OrderedSet):
""" A set collection that remembers the elements last insertion order. """
@ -1363,36 +1380,66 @@ def babel_locale_parse(lang_code):
except:
return babel.Locale.parse("en_US")
def formatLang(env, value, digits=None, grouping=True, monetary=False, dp=False, currency_obj=False):
def formatLang(env, value, digits=2, grouping=True, monetary=False, dp=None, currency_obj=None, rounding_method='HALF-EVEN', rounding_unit='decimals'):
"""
Assuming 'Account' decimal.precision=3:
formatLang(value) -> digits=2 (default)
formatLang(value, digits=4) -> digits=4
formatLang(value, dp='Account') -> digits=3
formatLang(value, digits=5, dp='Account') -> digits=5
This function will format a number `value` to the appropriate format of the language used.
:param Object env: The environment.
:param float value: The value to be formatted.
:param int digits: The number of decimals digits.
:param bool grouping: Usage of language grouping or not.
:param bool monetary: Usage of thousands separator or not.
.. deprecated:: 13.0
:param str dp: Name of the decimals precision to be used. This will override ``digits``
and ``currency_obj`` precision.
:param Object currency_obj: Currency to be used. This will override ``digits`` precision.
:param str rounding_method: The rounding method to be used:
**'HALF-UP'** will round to the closest number with ties going away from zero,
**'HALF-DOWN'** will round to the closest number with ties going towards zero,
**'HALF_EVEN'** will round to the closest number with ties going to the closest
even number,
**'UP'** will always round away from 0,
**'DOWN'** will always round towards 0.
:param str rounding_unit: The rounding unit to be used:
**decimals** will round to decimals with ``digits`` or ``dp`` precision,
**units** will round to units without any decimals,
**thousands** will round to thousands without any decimals,
**lakhs** will round to lakhs without any decimals,
**millions** will round to millions without any decimals.
:returns: The value formatted.
:rtype: str
"""
if digits is None:
digits = DEFAULT_DIGITS = 2
if dp:
decimal_precision_obj = env['decimal.precision']
digits = decimal_precision_obj.precision_get(dp)
elif currency_obj:
digits = currency_obj.decimal_places
if isinstance(value, str) and not value:
# We don't want to return 0
if value == '':
return ''
lang_obj = get_lang(env)
if rounding_unit == 'decimals':
if dp:
digits = env['decimal.precision'].precision_get(dp)
elif currency_obj:
digits = currency_obj.decimal_places
else:
digits = 0
res = lang_obj.format('%.' + str(digits) + 'f', value, grouping=grouping, monetary=monetary)
rounding_unit_mapping = {
'decimals': 1,
'thousands': 10**3,
'lakhs': 10**5,
'millions': 10**6,
}
value /= rounding_unit_mapping.get(rounding_unit, 1)
rounded_value = float_round(value, precision_digits=digits, rounding_method=rounding_method)
formatted_value = get_lang(env).format(f'%.{digits}f', rounded_value, grouping=grouping, monetary=monetary)
if currency_obj and currency_obj.symbol:
if currency_obj.position == 'after':
res = '%s%s%s' % (res, NON_BREAKING_SPACE, currency_obj.symbol)
elif currency_obj and currency_obj.position == 'before':
res = '%s%s%s' % (currency_obj.symbol, NON_BREAKING_SPACE, res)
return res
arguments = (formatted_value, NON_BREAKING_SPACE, currency_obj.symbol)
return '%s%s%s' % (arguments if currency_obj.position == 'after' else arguments[::-1])
return formatted_value
def format_date(env, value, lang_code=False, date_format=False):
@ -1708,7 +1755,7 @@ def get_diff(data_from, data_to, custom_style=False, dark_color_scheme=False):
For the table to fit the modal width, some custom style is needed.
"""
to_append = {
'diff_header': 'bg-600 text-center align-top px-2',
'diff_header': 'bg-600 text-light text-center align-top px-2',
'diff_next': 'd-none',
}
for old, new in to_append.items():
@ -1802,3 +1849,39 @@ def has_list_types(values, types):
isinstance(values, (list, tuple)) and len(values) == len(types)
and all(isinstance(item, type_) for item, type_ in zip(values, types))
)
def get_flag(country_code: str) -> str:
    """Get the emoji representing the flag linked to the country code.

    The flag is the pair of Unicode regional-indicator symbols matching the
    two letters of the ISO 3166-1 alpha-2 code ('A' maps to U+1F1E6, ...,
    'Z' to U+1F1FF).  Lowercase input is normalized to uppercase first, as
    the letter-to-symbol offset is only defined for 'A'-'Z'.
    """
    # U+1F1E6 is 'REGIONAL INDICATOR SYMBOL LETTER A'; shift each letter by
    # the distance between 'A' and that code point.
    offset = 0x1F1E6 - ord('A')
    return "".join(chr(offset + ord(letter)) for letter in country_code.upper())
def format_frame(frame):
    """Render a stack frame as ``<function> <file>:<line>`` for logging."""
    code_obj = frame.f_code
    return f'{code_obj.co_name} {code_obj.co_filename}:{frame.f_lineno}'
class _PrintfArgs:
    """ Helper object to turn a named printf-style format string into a positional one. """
    __slots__ = ('mapping', 'values')

    def __init__(self, mapping):
        self.mapping = mapping  # original named arguments
        self.values = []        # values, in the order the format string consumes them

    def __getitem__(self, key):
        # Record the value and substitute a positional '%s' placeholder.
        self.values.append(self.mapping[key])
        return "%s"


def named_to_positional_printf(string: str, args: Mapping) -> tuple[str, tuple]:
    """ Convert a named printf-style format string with its arguments to an
    equivalent positional format string with its arguments. This implementation
    does not support escaped ``%`` characters (``"%%"``).
    """
    if '%%' in string:
        raise ValueError(f"Unsupported escaped '%' in format string {string!r}")
    # Formatting against the proxy both rewrites the placeholders and
    # collects the values in consumption order.
    collector = _PrintfArgs(args)
    positional_string = string % collector
    return positional_string, tuple(collector.values)

View file

@ -1,456 +0,0 @@
# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.
import io
import re
from datetime import datetime
from hashlib import md5
from logging import getLogger
from zlib import compress, decompress
from PIL import Image, PdfImagePlugin
from reportlab.lib import colors
from reportlab.lib.units import cm
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
try:
# class were renamed in PyPDF2 > 2.0
# https://pypdf2.readthedocs.io/en/latest/user/migration-1-to-2.html#classes
from PyPDF2 import PdfReader
import PyPDF2
# monkey patch to discard unused arguments as the old arguments were not discarded in the transitional class
# https://pypdf2.readthedocs.io/en/2.0.0/_modules/PyPDF2/_reader.html#PdfReader
class PdfFileReader(PdfReader):
def __init__(self, *args, **kwargs):
if "strict" not in kwargs and len(args) < 2:
kwargs["strict"] = True # maintain the default
kwargs = {k:v for k, v in kwargs.items() if k in ('strict', 'stream')}
super().__init__(*args, **kwargs)
PyPDF2.PdfFileReader = PdfFileReader
from PyPDF2 import PdfFileWriter, PdfFileReader
PdfFileWriter._addObject = PdfFileWriter._add_object
except ImportError:
from PyPDF2 import PdfFileWriter, PdfFileReader
from PyPDF2.generic import DictionaryObject, NameObject, ArrayObject, DecodedStreamObject, NumberObject, createStringObject, ByteStringObject
try:
from fontTools.ttLib import TTFont
except ImportError:
TTFont = None
from odoo.tools.misc import file_open
_logger = getLogger(__name__)
DEFAULT_PDF_DATETIME_FORMAT = "D:%Y%m%d%H%M%S+00'00'"
REGEX_SUBTYPE_UNFORMATED = re.compile(r'^\w+/[\w-]+$')
REGEX_SUBTYPE_FORMATED = re.compile(r'^/\w+#2F[\w-]+$')
# Disable linter warning: this import is needed to make sure a PDF stream can be saved in Image.
PdfImagePlugin.__name__
# make sure values are unwrapped by calling the specialized __getitem__
def _unwrapping_get(self, key, default=None):
    # dict.get() bypasses __getitem__, so PyPDF2 indirect objects would be
    # returned wrapped; route the lookup through __getitem__ instead.
    try:
        return self[key]
    except KeyError:
        return default

# Monkeypatch applied to every DictionaryObject in the process.
DictionaryObject.get = _unwrapping_get
class BrandedFileWriter(PdfFileWriter):
    """PdfFileWriter that stamps Odoo as Creator/Producer on every document."""

    def __init__(self):
        super().__init__()
        branding = {
            '/Creator': "Odoo",
            '/Producer': "Odoo",
        }
        self.addMetadata(branding)


# From here on, every PdfFileWriter produced in this module is branded.
PdfFileWriter = BrandedFileWriter
def merge_pdf(pdf_data):
    """Merge a collection of PDF documents into one.

    Note that the attachments are not merged.

    :param list pdf_data: a list of PDF datastrings
    :return: a unique merged PDF datastring
    """
    merged_writer = PdfFileWriter()
    for datastring in pdf_data:
        source = PdfFileReader(io.BytesIO(datastring), strict=False)
        for index in range(source.getNumPages()):
            merged_writer.addPage(source.getPage(index))
    with io.BytesIO() as stream:
        merged_writer.write(stream)
        return stream.getvalue()
def rotate_pdf(pdf):
    """Rotate a PDF clockwise (90°) into a new PDF.

    Note that the attachments are not copied.

    :param pdf: a PDF datastring to rotate
    :return: the rotated PDF datastring
    """
    rotated_writer = PdfFileWriter()
    source = PdfFileReader(io.BytesIO(pdf), strict=False)
    for index in range(source.getNumPages()):
        current_page = source.getPage(index)
        current_page.rotateClockwise(90)
        rotated_writer.addPage(current_page)
    with io.BytesIO() as stream:
        rotated_writer.write(stream)
        return stream.getvalue()
def to_pdf_stream(attachment) -> io.BytesIO:
    """Return the attachment's content as a PDF byte stream.

    PDF attachments are passed through unchanged; image attachments are
    converted to a single-page PDF.  Any other mimetype is logged and
    yields None.
    """
    raw_stream = io.BytesIO(attachment.raw)
    mimetype = attachment.mimetype
    if mimetype == 'application/pdf':
        return raw_stream
    if mimetype.startswith('image'):
        converted = io.BytesIO()
        # PDF output requires an RGB image (no alpha channel).
        Image.open(raw_stream).convert("RGB").save(converted, format="pdf")
        return converted
    _logger.warning("mimetype (%s) not recognized for %s", mimetype, attachment)
def add_banner(pdf_stream, text=None, logo=False, thickness=2 * cm):
    """ Add a banner on a PDF in the upper right corner, with Odoo's logo (optionally).

    :param pdf_stream (BytesIO): The PDF stream where the banner will be applied.
    :param text (str): The text to be displayed.
    :param logo (bool): Whether to display Odoo's logo in the banner.
    :param thickness (float): The thickness of the banner in pixels.
    :return (BytesIO): The modified PDF stream.
    """
    # NOTE(review): drawRightString is called with `text` unconditionally —
    # presumably callers always pass a text; verify None is never passed.
    old_pdf = PdfFileReader(pdf_stream, strict=False, overwriteWarnings=False)
    packet = io.BytesIO()
    can = canvas.Canvas(packet)
    odoo_logo = Image.open(file_open('base/static/img/main_partner-image.png', mode='rb'))
    odoo_color = colors.Color(113 / 255, 75 / 255, 103 / 255, 0.8)
    for p in range(old_pdf.getNumPages()):
        page = old_pdf.getPage(p)
        width = float(abs(page.mediaBox.getWidth()))
        height = float(abs(page.mediaBox.getHeight()))
        can.setPageSize((width, height))
        # Move the origin to the top-right corner and tilt the canvas so the
        # banner is drawn diagonally across that corner.
        can.translate(width, height)
        can.rotate(-45)
        # Draw banner
        path = can.beginPath()
        path.moveTo(-width, -thickness)
        path.lineTo(-width, -2 * thickness)
        path.lineTo(width, -2 * thickness)
        path.lineTo(width, -thickness)
        can.setFillColor(odoo_color)
        can.drawPath(path, fill=1, stroke=False)
        # Insert text (and logo) inside the banner
        can.setFontSize(10)
        can.setFillColor(colors.white)
        can.drawRightString(0.75 * thickness, -1.45 * thickness, text)
        # `logo and ...` short-circuits: the image is only drawn when requested.
        logo and can.drawImage(
            ImageReader(odoo_logo), 0.25 * thickness, -2.05 * thickness, 40, 40, mask='auto', preserveAspectRatio=True)
        can.showPage()
    can.save()

    # Merge the old pages with the watermark
    watermark_pdf = PdfFileReader(packet, overwriteWarnings=False)
    new_pdf = PdfFileWriter()
    for p in range(old_pdf.getNumPages()):
        new_page = old_pdf.getPage(p)
        # Remove annotations (if any), to prevent errors in PyPDF2
        if '/Annots' in new_page:
            del new_page['/Annots']
        new_page.mergePage(watermark_pdf.getPage(p))
        new_pdf.addPage(new_page)

    # Write the new pdf into a new output stream
    output = io.BytesIO()
    new_pdf.write(output)
    return output
# by default PdfFileReader will overwrite warnings.showwarning which is what
# logging.captureWarnings does, meaning it essentially reverts captureWarnings
# every time it's called which is undesirable
# -> wrap __init__ so overwriteWarnings is always forced to False.
old_init = PdfFileReader.__init__
PdfFileReader.__init__ = lambda self, stream, strict=True, warndest=None, overwriteWarnings=True: \
    old_init(self, stream=stream, strict=strict, warndest=None, overwriteWarnings=False)
class OdooPdfFileReader(PdfFileReader):
    # OVERRIDE of PdfFileReader to add the management of multiple embedded files.

    ''' Returns the files inside the PDF.
    :raises NotImplementedError: if document is encrypted and uses an unsupported encryption method.
    '''
    # NOTE(review): the string above sits at class level, so it never becomes
    # getAttachments.__doc__ even though it describes that method.
    def getAttachments(self):
        # Generator yielding (filename, content) for each file embedded in
        # the PDF's /EmbeddedFiles name tree.
        if self.isEncrypted:
            # If the PDF is owner-encrypted, try to unwrap it by giving it an empty user password.
            self.decrypt('')

        try:
            # /Names under /EmbeddedFiles is a flat array alternating
            # (name, filespec) entries — hence the step-2 iteration below.
            file_path = self.trailer["/Root"].get("/Names", {}).get("/EmbeddedFiles", {}).get("/Names")
            if not file_path:
                return []
            for i in range(0, len(file_path), 2):
                attachment = file_path[i+1].getObject()
                yield (attachment["/F"], attachment["/EF"]["/F"].getObject().getData())
        except Exception:
            # malformed pdf (i.e. invalid xref page)
            return []
class OdooPdfFileWriter(PdfFileWriter):
def __init__(self, *args, **kwargs):
    """
    Override of the init to initialise additional variables.
    :param pdf_content: if given, will initialise the reader with the pdf content.
    """
    super().__init__(*args, **kwargs)
    # Reader this writer was cloned from; set by cloneReaderDocumentRoot().
    self._reader = None
    # True when the cloned source document carried a PDF/A header.
    self.is_pdfa = False
def addAttachment(self, name, data, subtype=None):
    """
    Add an attachment to the pdf. Supports adding multiple attachment, while respecting PDF/A rules.

    :param name: The name of the attachement
    :param data: The data of the attachement
    :param subtype: The mime-type of the attachement. This is required by PDF/A, but not essential otherwise.
        It should take the form of "/xxx#2Fxxx". E.g. for "text/xml": "/text#2Fxml"
    """
    adapted_subtype = subtype
    if subtype:
        # If we receive the subtype in an 'unformated' (mimetype) format, we'll try to convert it to a pdf-valid one
        if REGEX_SUBTYPE_UNFORMATED.match(subtype):
            adapted_subtype = '/' + subtype.replace('/', '#2F')

        if not REGEX_SUBTYPE_FORMATED.match(adapted_subtype):
            # The subtype still does not match the correct format, so we will not add it to the document
            _logger.warning("Attempt to add an attachment with the incorrect subtype '%s'. The subtype will be ignored.", subtype)
            adapted_subtype = ''

    attachment = self._create_attachment_object({
        'filename': name,
        'content': data,
        'subtype': adapted_subtype,
    })
    # Register the filespec in the catalog's /EmbeddedFiles name tree,
    # creating the tree on first use.
    if self._root_object.get('/Names') and self._root_object['/Names'].get('/EmbeddedFiles'):
        names_array = self._root_object["/Names"]["/EmbeddedFiles"]["/Names"]
        names_array.extend([attachment.getObject()['/F'], attachment])
    else:
        names_array = ArrayObject()
        names_array.extend([attachment.getObject()['/F'], attachment])

        embedded_files_names_dictionary = DictionaryObject()
        embedded_files_names_dictionary.update({
            NameObject("/Names"): names_array
        })
        embedded_files_dictionary = DictionaryObject()
        embedded_files_dictionary.update({
            NameObject("/EmbeddedFiles"): embedded_files_names_dictionary
        })
        self._root_object.update({
            NameObject("/Names"): embedded_files_dictionary
        })

    # Also reference the attachment in the /AF (associated files) array,
    # creating it on first use.
    if self._root_object.get('/AF'):
        attachment_array = self._root_object['/AF']
        attachment_array.extend([attachment])
    else:
        # Create a new object containing an array referencing embedded file
        # And reference this array in the root catalogue
        attachment_array = self._addObject(ArrayObject([attachment]))
        self._root_object.update({
            NameObject("/AF"): attachment_array
        })
def embed_odoo_attachment(self, attachment, subtype=None):
    """Embed an Odoo attachment record's payload into the PDF.

    Falls back to the attachment's own mimetype when no subtype is given.
    """
    assert attachment, "embed_odoo_attachment cannot be called without attachment."
    effective_subtype = subtype or attachment.mimetype
    self.addAttachment(attachment.name, attachment.raw, subtype=effective_subtype)
def cloneReaderDocumentRoot(self, reader):
    """Clone the reader's document root, preserving its header and /ID."""
    super().cloneReaderDocumentRoot(reader)
    self._reader = reader
    # Try to read the header coming in, and reuse it in our new PDF
    # This is done in order to allows modifying PDF/A files after creating them (as PyPDF does not read it)
    stream = reader.stream
    stream.seek(0)
    header = stream.readlines(9)
    # Should always be true, the first line of a pdf should have 9 bytes (%PDF-1.x plus a newline)
    if len(header) == 1:
        # If we found a header, set it back to the new pdf
        self._header = header[0]
        # Also check the second line. If it is PDF/A, it should be a line starting by % following by four bytes + \n
        second_line = stream.readlines(1)[0]
        if second_line.decode('latin-1')[0] == '%' and len(second_line) == 6:
            self._header += second_line
            self.is_pdfa = True

    # Look if we have an ID in the incoming stream and use it.
    pdf_id = reader.trailer.get('/ID', None)
    if pdf_id:
        self._ID = pdf_id
def convert_to_pdfa(self):
    """
    Transform the opened PDF file into a PDF/A compliant file.

    This rewrites the header, adds a document /ID, embeds an sRGB output
    intent, normalizes embedded font width arrays (when fonttools is
    available), forces the outline count, and sets Odoo as producer.
    Sets ``self.is_pdfa`` to True when done.
    """
    # Set the PDF version to 1.7 (as PDF/A-3 is based on version 1.7) and make it PDF/A compliant.
    # See https://github.com/veraPDF/veraPDF-validation-profiles/wiki/PDFA-Parts-2-and-3-rules#rule-612-1
    # " The file header shall begin at byte zero and shall consist of "%PDF-1.n" followed by a single EOL marker,
    # where 'n' is a single digit number between 0 (30h) and 7 (37h) "
    # " The aforementioned EOL marker shall be immediately followed by a % (25h) character followed by at least four
    # bytes, each of whose encoded byte values shall have a decimal value greater than 127 "
    self._header = b"%PDF-1.7\n%\xFF\xFF\xFF\xFF"
    # Add a document ID to the trailer. This is only needed when using encryption with regular PDF, but is required
    # when using PDF/A
    pdf_id = ByteStringObject(md5(self._reader.stream.getvalue()).digest())
    # The first string is based on the content at the time of creating the file, while the second is based on the
    # content of the file when it was last updated. When creating a PDF, both are set to the same value.
    self._ID = ArrayObject((pdf_id, pdf_id))
    # Embed an sRGB ICC colour profile as the output intent (required for PDF/A).
    with file_open('tools/data/files/sRGB2014.icc', mode='rb') as icc_profile:
        icc_profile_file_data = compress(icc_profile.read())
    icc_profile_stream_obj = DecodedStreamObject()
    icc_profile_stream_obj.setData(icc_profile_file_data)
    icc_profile_stream_obj.update({
        NameObject("/Filter"): NameObject("/FlateDecode"),
        # /N = 3: number of colour components (RGB) in the ICC profile.
        NameObject("/N"): NumberObject(3),
        NameObject("/Length"): NameObject(str(len(icc_profile_file_data))),
    })
    icc_profile_obj = self._addObject(icc_profile_stream_obj)
    output_intent_dict_obj = DictionaryObject()
    output_intent_dict_obj.update({
        NameObject("/S"): NameObject("/GTS_PDFA1"),
        NameObject("/OutputConditionIdentifier"): createStringObject("sRGB"),
        NameObject("/DestOutputProfile"): icc_profile_obj,
        NameObject("/Type"): NameObject("/OutputIntent"),
    })
    output_intent_obj = self._addObject(output_intent_dict_obj)
    self._root_object.update({
        NameObject("/OutputIntents"): ArrayObject([output_intent_obj]),
    })
    pages = self._root_object['/Pages']['/Kids']
    # PDF/A needs the glyphs width array embedded in the pdf to be consistent with the ones from the font file.
    # But it seems like it is not the case when exporting from wkhtmltopdf.
    if TTFont:
        fonts = {}
        # First browse through all the pages of the pdf file, to get a reference to all the fonts used in the PDF.
        for page in pages:
            for font in page.getObject()['/Resources']['/Font'].values():
                for descendant in font.getObject()['/DescendantFonts']:
                    # Keyed by idnum to deduplicate fonts shared across pages.
                    fonts[descendant.idnum] = descendant.getObject()
        # Then for each font, rewrite the width array with the information taken directly from the font file.
        # The new width are calculated such as width = round(1000 * font_glyph_width / font_units_per_em)
        # See: http://martin.hoppenheit.info/blog/2018/pdfa-validation-and-inconsistent-glyph-width-information/
        for font in fonts.values():
            font_file = font['/FontDescriptor']['/FontFile2']
            stream = io.BytesIO(decompress(font_file._data))
            ttfont = TTFont(stream)
            font_upm = ttfont['head'].unitsPerEm
            glyphs = ttfont.getGlyphSet()._hmtx.metrics
            glyph_widths = []
            for key, values in glyphs.items():
                # Only glyphs with generated names ("glyphNNNNN"); presumably
                # this matches wkhtmltopdf's subsetted-font naming — TODO confirm.
                if key[:5] == 'glyph':
                    glyph_widths.append(NumberObject(round(1000.0 * values[0] / font_upm)))
            font[NameObject('/W')] = ArrayObject([NumberObject(1), ArrayObject(glyph_widths)])
            stream.close()
    else:
        _logger.warning('The fonttools package is not installed. Generated PDF may not be PDF/A compliant.')
    outlines = self._root_object['/Outlines'].getObject()
    outlines[NameObject('/Count')] = NumberObject(1)
    # Set odoo as producer
    self.addMetadata({
        '/Creator': "Odoo",
        '/Producer': "Odoo",
    })
    self.is_pdfa = True
def add_file_metadata(self, metadata_content):
    """
    Set the XMP metadata stream of the PDF, wrapping it with the required
    ``xpacket`` header and trailer. These are mandated for a fully
    compliant PDF/A file; omitting them results in validation errors.

    :param metadata_content: bytes of the metadata to add to the pdf.
    """
    # See https://wwwimages2.adobe.com/content/dam/acom/en/devnet/xmp/pdfs/XMP%20SDK%20Release%20cc-2016-08/XMPSpecificationPart1.pdf
    # Page 10/11
    xpacket_begin = b'<?xpacket begin="" id="W5M0MpCehiHzreSzNTczkc9d"?>'
    xpacket_end = b'<?xpacket end="w"?>'
    payload = xpacket_begin + metadata_content + xpacket_end

    # Build the /Metadata stream object holding the wrapped XMP payload.
    stream_obj = DecodedStreamObject()
    stream_obj.setData(payload)
    stream_obj.update({
        NameObject("/Type"): NameObject("/Metadata"),
        NameObject("/Subtype"): NameObject("/XML"),
        NameObject("/Length"): NameObject(str(len(payload))),
    })

    # Register it as an indirect object and point the document catalog at it.
    metadata_ref = self._addObject(stream_obj)
    self._root_object.update({NameObject("/Metadata"): metadata_ref})
def _create_attachment_object(self, attachment):
    ''' Create a PyPDF2 generic object representing an embedded file.

    :param attachment: A dictionary containing:
        * filename: The name of the file to embed (required)
        * content: The bytes of the file to embed (required)
        * subtype: The mime-type of the file to embed (optional)
        * description: A human-readable description shown by viewers (optional)
    :return: the indirect reference to the created /Filespec object.
    '''
    # The /EmbeddedFile stream carries the raw bytes plus integrity metadata.
    file_entry = DecodedStreamObject()
    file_entry.setData(attachment['content'])
    file_entry.update({
        NameObject("/Type"): NameObject("/EmbeddedFile"),
        NameObject("/Params"):
            DictionaryObject({
                NameObject('/CheckSum'): createStringObject(md5(attachment['content']).hexdigest()),
                NameObject('/ModDate'): createStringObject(datetime.now().strftime(DEFAULT_PDF_DATETIME_FORMAT)),
                # NOTE(review): the PDF spec defines /Params /Size as an integer,
                # but this writes it as a name object ("/<len>") — confirm whether
                # that is intentional before changing it.
                NameObject('/Size'): NameObject(f"/{len(attachment['content'])}"),
            }),
    })
    if attachment.get('subtype'):
        file_entry.update({
            NameObject("/Subtype"): NameObject(attachment['subtype']),
        })
    file_entry_object = self._addObject(file_entry)
    filename_object = createStringObject(attachment['filename'])
    # The /Filespec dictionary ties the file name (/F and unicode /UF) to the
    # embedded stream via /EF.
    filespec_object = DictionaryObject({
        NameObject("/AFRelationship"): NameObject("/Data"),
        NameObject("/Type"): NameObject("/Filespec"),
        NameObject("/F"): filename_object,
        NameObject("/EF"):
            DictionaryObject({
                NameObject("/F"): file_entry_object,
                NameObject('/UF'): file_entry_object,
            }),
        NameObject("/UF"): filename_object,
    })
    if attachment.get('description'):
        filespec_object.update({NameObject("/Desc"): createStringObject(attachment['description'])})
    return self._addObject(filespec_object)

View file

@ -1,23 +1,32 @@
# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.
import re
import warnings
from zlib import crc32
from .func import lazy_property
IDENT_RE = re.compile(r'^[a-z_][a-z0-9_$]*$', re.I)
from odoo.tools.sql import make_identifier, SQL, IDENT_RE
def _from_table(table, alias):
""" Return a FROM clause element from ``table`` and ``alias``. """
if alias == table:
return f'"{alias}"'
elif IDENT_RE.match(table):
return f'"{table}" AS "{alias}"'
else:
return f'({table}) AS "{alias}"'
def _sql_table(table: str | SQL | None) -> SQL | None:
""" Wrap an optional table as an SQL object. """
if isinstance(table, str):
return SQL.identifier(table) if IDENT_RE.match(table) else SQL(f"({table})")
return table
def _sql_from_table(alias: str, table: SQL | None) -> SQL:
""" Return a FROM clause element from ``alias`` and ``table``. """
if table is None:
return SQL.identifier(alias)
return SQL("%s AS %s", table, SQL.identifier(alias))
def _sql_from_join(kind: SQL, alias: str, table: SQL | None, condition: SQL) -> SQL:
""" Return a FROM clause element for a JOIN. """
return SQL("%s %s ON (%s)", kind, _sql_from_table(alias, table), condition)
_SQL_JOINS = {
"JOIN": SQL("JOIN"),
"LEFT JOIN": SQL("LEFT JOIN"),
}
def _generate_table_alias(src_table_alias, link):
@ -37,14 +46,7 @@ def _generate_table_alias(src_table_alias, link):
:param str link: field name
:return str: alias
"""
alias = "%s__%s" % (src_table_alias, link)
# Use an alternate alias scheme if length exceeds the PostgreSQL limit
# of 63 characters.
if len(alias) >= 64:
# We have to fit a crc32 hash and one underscore into a 63 character
# alias. The remaining space we can use to add a human readable prefix.
alias = "%s_%08x" % (alias[:54], crc32(alias.encode('utf-8')))
return alias
return make_identifier(f"{src_table_alias}__{link}")
class Query(object):
@ -54,43 +56,63 @@ class Query(object):
:param cr: database cursor (for lazy evaluation)
:param alias: name or alias of the table
:param table: if given, a table expression (identifier or query)
:param table: a table expression (``str`` or ``SQL`` object), optional
"""
def __init__(self, cr, alias, table=None):
def __init__(self, cr, alias: str, table: (str | SQL | None) = None):
# database cursor
self._cr = cr
# tables {alias: table}
self._tables = {alias: table or alias}
# tables {alias: table(SQL|None)}
self._tables = {alias: _sql_table(table)}
# joins {alias: (kind, table, condition, condition_params)}
# joins {alias: (kind(SQL), table(SQL|None), condition(SQL))}
self._joins = {}
# holds the list of WHERE clause elements (to be joined with 'AND'), and
# the list of parameters
# holds the list of WHERE conditions (to be joined with 'AND')
self._where_clauses = []
self._where_params = []
# order, limit, offset
self.order = None
self._order = None
self.limit = None
self.offset = None
def add_table(self, alias, table=None):
# memoized result
self._ids = None
def make_alias(self, alias: str, link: str) -> str:
""" Return an alias based on ``alias`` and ``link``. """
return _generate_table_alias(alias, link)
def add_table(self, alias: str, table: (str | SQL | None) = None):
""" Add a table with a given alias to the from clause. """
assert alias not in self._tables and alias not in self._joins, "Alias %r already in %s" % (alias, str(self))
self._tables[alias] = table or alias
assert alias not in self._tables and alias not in self._joins, f"Alias {alias!r} already in {self}"
self._tables[alias] = _sql_table(table)
self._ids = None
def add_where(self, where_clause, where_params=()):
def add_join(self, kind: str, alias: str, table: str | SQL | None, condition: SQL):
""" Add a join clause with the given alias, table and condition. """
sql_kind = _SQL_JOINS.get(kind.upper())
assert sql_kind is not None, f"Invalid JOIN type {kind!r}"
assert alias not in self._tables, f"Alias {alias!r} already used"
table = _sql_table(table)
if alias in self._joins:
assert self._joins[alias] == (sql_kind, table, condition)
else:
self._joins[alias] = (sql_kind, table, condition)
self._ids = None
def add_where(self, where_clause: str | SQL, where_params=()):
""" Add a condition to the where clause. """
self._where_clauses.append(where_clause)
self._where_params.extend(where_params)
self._where_clauses.append(SQL(where_clause, *where_params))
self._ids = None
def join(self, lhs_alias, lhs_column, rhs_table, rhs_column, link, extra=None, extra_params=()):
def join(self, lhs_alias: str, lhs_column: str, rhs_table: str, rhs_column: str, link: str):
"""
Perform a join between a table already present in the current Query object and
another table.
another table. This method is essentially a shortcut for methods :meth:`~.make_alias`
and :meth:`~.add_join`.
:param str lhs_alias: alias of a table already defined in the current Query object.
:param str lhs_column: column of `lhs_alias` to be used for the join's ON condition.
@ -98,151 +120,156 @@ class Query(object):
:param str rhs_column: column of `rhs_alias` to be used for the join's ON condition.
:param str link: used to generate the alias for the joined table, this string should
represent the relationship (the link) between both tables.
:param str extra: an sql string of a predicate or series of predicates to append to the
join's ON condition, `lhs_alias` and `rhs_alias` can be injected if the string uses
the `lhs` and `rhs` variables with the `str.format` syntax. e.g.::
query.join(..., extra="{lhs}.name != {rhs}.name OR ...", ...)
:param tuple extra_params: a tuple of values to be interpolated into `extra`, this is
done by psycopg2.
Full example:
>>> rhs_alias = query.join(
... "res_users",
... "partner_id",
... "res_partner",
... "id",
... "partner_id", # partner_id is the "link" from res_users to res_partner
... "{lhs}.\"name\" != %s",
... ("Mitchell Admin",),
... )
>>> rhs_alias
res_users_res_partner__partner_id
From the example above, the resulting query would be something like::
SELECT ...
FROM "res_users" AS "res_users"
JOIN "res_partner" AS "res_users_res_partner__partner_id"
ON "res_users"."partner_id" = "res_users_res_partner__partner_id"."id"
AND "res_users"."name" != 'Mitchell Admin'
WHERE ...
"""
return self._join('JOIN', lhs_alias, lhs_column, rhs_table, rhs_column, link, extra, extra_params)
assert lhs_alias in self._tables or lhs_alias in self._joins, "Alias %r not in %s" % (lhs_alias, str(self))
rhs_alias = self.make_alias(lhs_alias, link)
condition = SQL("%s = %s", SQL.identifier(lhs_alias, lhs_column), SQL.identifier(rhs_alias, rhs_column))
self.add_join('JOIN', rhs_alias, rhs_table, condition)
return rhs_alias
def left_join(self, lhs_alias, lhs_column, rhs_table, rhs_column, link, extra=None, extra_params=()):
def left_join(self, lhs_alias: str, lhs_column: str, rhs_table: str, rhs_column: str, link: str):
""" Add a LEFT JOIN to the current table (if necessary), and return the
alias corresponding to ``rhs_table``.
See the documentation of :meth:`join` for a better overview of the
arguments and what they do.
"""
return self._join('LEFT JOIN', lhs_alias, lhs_column, rhs_table, rhs_column, link, extra, extra_params)
def _join(self, kind, lhs_alias, lhs_column, rhs_table, rhs_column, link, extra=None, extra_params=()):
assert lhs_alias in self._tables or lhs_alias in self._joins, "Alias %r not in %s" % (lhs_alias, str(self))
rhs_alias = _generate_table_alias(lhs_alias, link)
assert rhs_alias not in self._tables, "Alias %r already in %s" % (rhs_alias, str(self))
if rhs_alias not in self._joins:
condition = f'"{lhs_alias}"."{lhs_column}" = "{rhs_alias}"."{rhs_column}"'
condition_params = []
if extra:
condition = condition + " AND " + extra.format(lhs=lhs_alias, rhs=rhs_alias)
condition_params = list(extra_params)
if kind:
self._joins[rhs_alias] = (kind, rhs_table, condition, condition_params)
else:
self._tables[rhs_alias] = rhs_table
self.add_where(condition, condition_params)
rhs_alias = self.make_alias(lhs_alias, link)
condition = SQL("%s = %s", SQL.identifier(lhs_alias, lhs_column), SQL.identifier(rhs_alias, rhs_column))
self.add_join('LEFT JOIN', rhs_alias, rhs_table, condition)
return rhs_alias
def select(self, *args):
""" Return the SELECT query as a pair ``(query_string, query_params)``. """
from_clause, where_clause, params = self.get_sql()
query_str = 'SELECT {} FROM {} WHERE {}{}{}{}'.format(
", ".join(args or [f'"{next(iter(self._tables))}".id']),
from_clause,
where_clause or "TRUE",
(" ORDER BY %s" % self.order) if self.order else "",
(" LIMIT %d" % self.limit) if self.limit else "",
(" OFFSET %d" % self.offset) if self.offset else "",
)
return query_str, params
@property
def order(self) -> SQL | None:
return self._order
def subselect(self, *args):
@order.setter
def order(self, value: SQL | str | None):
self._order = SQL(value) if value is not None else None
@property
def table(self) -> str:
""" Return the query's main table, i.e., the first one in the FROM clause. """
return next(iter(self._tables))
@property
def from_clause(self) -> SQL:
""" Return the FROM clause of ``self``, without the FROM keyword. """
tables = SQL(", ").join(
_sql_from_table(alias, table)
for alias, table in self._tables.items()
)
if not self._joins:
return tables
items = [tables]
for alias, (kind, table, condition) in self._joins.items():
items.append(_sql_from_join(kind, alias, table, condition))
return SQL(" ").join(items)
@property
def where_clause(self) -> SQL:
""" Return the WHERE condition of ``self``, without the WHERE keyword. """
return SQL(" AND ").join(self._where_clauses)
def is_empty(self):
""" Return whether the query is known to return nothing. """
return self._ids == ()
def select(self, *args: str | SQL) -> SQL:
""" Return the SELECT query as an ``SQL`` object. """
sql_args = map(SQL, args) if args else [SQL.identifier(self.table, 'id')]
return SQL(
"%s%s%s%s%s%s",
SQL("SELECT %s", SQL(", ").join(sql_args)),
SQL(" FROM %s", self.from_clause),
SQL(" WHERE %s", self.where_clause) if self._where_clauses else SQL(),
SQL(" ORDER BY %s", self._order) if self._order else SQL(),
SQL(" LIMIT %s", self.limit) if self.limit else SQL(),
SQL(" OFFSET %s", self.offset) if self.offset else SQL(),
)
def subselect(self, *args: str | SQL) -> SQL:
""" Similar to :meth:`.select`, but for sub-queries.
This one avoids the ORDER BY clause when possible.
This one avoids the ORDER BY clause when possible,
and includes parentheses around the subquery.
"""
if self._ids is not None and not args:
# inject the known result instead of the subquery
return SQL("%s", self._ids or (None,))
if self.limit or self.offset:
# in this case, the ORDER BY clause is necessary
return self.select(*args)
return SQL("(%s)", self.select(*args))
from_clause, where_clause, params = self.get_sql()
query_str = 'SELECT {} FROM {} WHERE {}'.format(
", ".join(args or [f'"{next(iter(self._tables))}".id']),
from_clause,
where_clause or "TRUE",
sql_args = map(SQL, args) if args else [SQL.identifier(self.table, 'id')]
return SQL(
"(%s%s%s)",
SQL("SELECT %s", SQL(", ").join(sql_args)),
SQL(" FROM %s", self.from_clause),
SQL(" WHERE %s", self.where_clause) if self._where_clauses else SQL(),
)
return query_str, params
def get_sql(self):
""" Returns (query_from, query_where, query_params). """
tables = [_from_table(table, alias) for alias, table in self._tables.items()]
joins = []
params = []
for alias, (kind, table, condition, condition_params) in self._joins.items():
joins.append(f'{kind} {_from_table(table, alias)} ON ({condition})')
params.extend(condition_params)
from_string, from_params = self.from_clause
where_string, where_params = self.where_clause
return from_string, where_string, from_params + where_params
from_clause = " ".join([", ".join(tables)] + joins)
where_clause = " AND ".join(self._where_clauses)
return from_clause, where_clause, params + self._where_params
def get_result_ids(self):
""" Return the result of ``self.select()`` as a tuple of ids. The result
is memoized for future use, which avoids making the same query twice.
"""
if self._ids is None:
self._cr.execute(self.select())
self._ids = tuple(row[0] for row in self._cr.fetchall())
return self._ids
@lazy_property
def _result(self):
query_str, params = self.select()
self._cr.execute(query_str, params)
return [row[0] for row in self._cr.fetchall()]
def set_result_ids(self, ids, ordered=True):
""" Set up the query to return the lines given by ``ids``. The parameter
``ordered`` tells whether the query must be ordered to match exactly the
sequence ``ids``.
"""
assert not (self._joins or self._where_clauses or self.limit or self.offset), \
"Method set_result_ids() can only be called on a virgin Query"
ids = tuple(ids)
if not ids:
self.add_where("FALSE")
elif ordered:
# This guarantees that self.select() returns the results in the
# expected order of ids:
# SELECT "stuff".id
# FROM "stuff"
# JOIN (SELECT * FROM unnest(%s) WITH ORDINALITY) AS "stuff__ids"
# ON ("stuff"."id" = "stuff__ids"."unnest")
# ORDER BY "stuff__ids"."ordinality"
alias = self.join(
self.table, 'id',
SQL('(SELECT * FROM unnest(%s) WITH ORDINALITY)', list(ids)), 'unnest',
'ids',
)
self.order = SQL.identifier(alias, 'ordinality')
else:
self.add_where(SQL("%s IN %s", SQL.identifier(self.table, 'id'), ids))
self._ids = ids
def __str__(self):
return '<osv.Query: %r with params: %r>' % self.select()
sql = self.select()
return f"<Query: {sql.code!r} with params: {sql.params!r}>"
def __bool__(self):
return bool(self._result)
return bool(self.get_result_ids())
def __len__(self):
return len(self._result)
if self._ids is None:
if self.limit or self.offset:
# optimization: generate a SELECT FROM, and then count the rows
sql = SQL("SELECT COUNT(*) FROM (%s) t", self.select(""))
else:
sql = self.select('COUNT(*)')
self._cr.execute(sql)
return self._cr.fetchone()[0]
return len(self.get_result_ids())
def __iter__(self):
return iter(self._result)
#
# deprecated attributes and methods
#
@property
def tables(self):
warnings.warn("deprecated Query.tables, use Query.get_sql() instead",
DeprecationWarning)
return tuple(_from_table(table, alias) for alias, table in self._tables.items())
@property
def where_clause(self):
return tuple(self._where_clauses)
@property
def where_clause_params(self):
return tuple(self._where_params)
def add_join(self, connection, implicit=True, outer=False, extra=None, extra_params=()):
warnings.warn("deprecated Query.add_join, use Query.join() or Query.left_join() instead",
DeprecationWarning)
lhs_alias, rhs_table, lhs_column, rhs_column, link = connection
kind = '' if implicit else ('LEFT JOIN' if outer else 'JOIN')
rhs_alias = self._join(kind, lhs_alias, lhs_column, rhs_table, rhs_column, link, extra, extra_params)
return rhs_alias, _from_table(rhs_table, rhs_alias)
return iter(self.get_result_ids())

View file

@ -125,7 +125,7 @@ class Speedscope:
"""
:param stack: A list of hashable frame
:param context: an iterable of (level, value) ordered by level
:param stack_offset: offeset level for stack
:param stack_offset: offset level for stack
Assemble stack and context and return a list of ids representing
this stack, adding each corresponding context at the corresponding
@ -146,7 +146,7 @@ class Speedscope:
return stack_ids
def process(self, entries, continuous=True, hide_gaps=False, use_context=True, constant_time=False):
# constant_time parameters is mainly usefull to hide temporality when focussing on sql determinism
# the constant_time parameter is mainly useful to hide temporality when focusing on sql determinism
entry_end = previous_end = None
if not entries:
return []

View file

@ -1,20 +1,24 @@
# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.
# pylint: disable=sql-injection
from __future__ import annotations
import logging
import enum
import json
import logging
import re
import psycopg2
from psycopg2.sql import SQL, Identifier
import odoo.sql_db
from binascii import crc32
from collections import defaultdict
from contextlib import closing
from typing import Iterable, Union
import psycopg2
from .misc import named_to_positional_printf
_schema = logging.getLogger('odoo.schema')
IDENT_RE = re.compile(r'^[a-z0-9_][a-z0-9_$\-]*$', re.I)
_CONFDELTYPES = {
'RESTRICT': 'r',
'NO ACTION': 'a',
@ -23,37 +27,188 @@ _CONFDELTYPES = {
'SET DEFAULT': 'd',
}
class SQL:
""" An object that wraps SQL code with its parameters, like::
sql = SQL("UPDATE TABLE foo SET a = %s, b = %s", 'hello', 42)
cr.execute(sql)
The code is given as a ``%``-format string, and supports either positional
arguments (with `%s`) or named arguments (with `%(name)s`). Escaped
characters (like ``"%%"``) are not supported, though. The arguments are
meant to be merged into the code using the `%` formatting operator.
The SQL wrapper is designed to be composable: the arguments can be either
actual parameters, or SQL objects themselves::
sql = SQL(
"UPDATE TABLE %s SET %s",
SQL.identifier(tablename),
SQL("%s = %s", SQL.identifier(columnname), value),
)
The combined SQL code is given by ``sql.code``, while the corresponding
combined parameters are given by the list ``sql.params``. This allows to
combine any number of SQL terms without having to separately combine their
parameters, which can be tedious, bug-prone, and is the main downside of
`psycopg2.sql <https://www.psycopg.org/docs/sql.html>`.
The second purpose of the wrapper is to discourage SQL injections. Indeed,
if ``code`` is a string literal (not a dynamic string), then the SQL object
made with ``code`` is guaranteed to be safe, provided the SQL objects
within its parameters are themselves safe.
"""
__slots__ = ('__code', '__args')
# pylint: disable=keyword-arg-before-vararg
def __new__(cls, code: (str | SQL) = "", /, *args, **kwargs):
if isinstance(code, SQL):
return code
# validate the format of code and parameters
if args and kwargs:
raise TypeError("SQL() takes either positional arguments, or named arguments")
if args:
code % tuple("" for arg in args)
elif kwargs:
code, args = named_to_positional_printf(code, kwargs)
self = object.__new__(cls)
self.__code = code
self.__args = args
return self
@property
def code(self) -> str:
""" Return the combined SQL code string. """
stack = [] # stack of intermediate results
for node in self.__postfix():
if not isinstance(node, SQL):
stack.append("%s")
elif arity := len(node.__args):
stack[-arity:] = [node.__code % tuple(stack[-arity:])]
else:
stack.append(node.__code)
return stack[0]
@property
def params(self) -> list:
""" Return the combined SQL code params as a list of values. """
return [node for node in self.__postfix() if not isinstance(node, SQL)]
def __postfix(self):
""" Return a postfix iterator for the SQL tree ``self``. """
stack = [(self, False)]
while stack:
node, ispostfix = stack.pop()
if ispostfix or not isinstance(node, SQL):
yield node
else:
stack.append((node, True))
stack.extend((arg, False) for arg in reversed(node.__args))
def __repr__(self):
return f"SQL({', '.join(map(repr, [self.code, *self.params]))})"
def __bool__(self):
return bool(self.__code)
def __eq__(self, other):
return self.code == other.code and self.params == other.params
def __iter__(self):
""" Yields ``self.code`` and ``self.params``. This was introduced for
backward compatibility, as it enables to access the SQL and parameters
by deconstructing the object::
sql = SQL(...)
code, params = sql
"""
yield self.code
yield self.params
def join(self, args: Iterable) -> SQL:
""" Join SQL objects or parameters with ``self`` as a separator. """
args = list(args)
# optimizations for special cases
if len(args) == 0:
return SQL()
if len(args) == 1:
return args[0]
if not self.__args:
return SQL(self.__code.join("%s" for arg in args), *args)
# general case: alternate args with self
items = [self] * (len(args) * 2 - 1)
for index, arg in enumerate(args):
items[index * 2] = arg
return SQL("%s" * len(items), *items)
@classmethod
def identifier(cls, name: str, subname: (str | None) = None) -> SQL:
""" Return an SQL object that represents an identifier. """
assert name.isidentifier() or IDENT_RE.match(name), f"{name!r} invalid for SQL.identifier()"
if subname is None:
return cls(f'"{name}"')
assert subname.isidentifier() or IDENT_RE.match(subname), f"{subname!r} invalid for SQL.identifier()"
return cls(f'"{name}"."{subname}"')
def existing_tables(cr, tablenames):
""" Return the names of existing tables among ``tablenames``. """
query = """
cr.execute(SQL("""
SELECT c.relname
FROM pg_class c
JOIN pg_namespace n ON (n.oid = c.relnamespace)
WHERE c.relname IN %s
AND c.relkind IN ('r', 'v', 'm')
AND n.nspname = current_schema
"""
cr.execute(query, [tuple(tablenames)])
""", tuple(tablenames)))
return [row[0] for row in cr.fetchall()]
def table_exists(cr, tablename):
""" Return whether the given table exists. """
return len(existing_tables(cr, {tablename})) == 1
def table_kind(cr, tablename):
""" Return the kind of a table: ``'r'`` (regular table), ``'v'`` (view),
``'f'`` (foreign table), ``'t'`` (temporary table),
``'m'`` (materialized view), or ``None``.
class TableKind(enum.Enum):
Regular = 'r'
Temporary = 't'
View = 'v'
Materialized = 'm'
Foreign = 'f'
Other = None
def table_kind(cr, tablename: str) -> Union[TableKind, None]:
""" Return the kind of a table, if ``tablename`` is a regular or foreign
table, or a view (ignores indexes, sequences, toast tables, and partitioned
tables; unlogged tables are considered regular)
"""
query = """
SELECT c.relkind
cr.execute(SQL("""
SELECT c.relkind, c.relpersistence
FROM pg_class c
JOIN pg_namespace n ON (n.oid = c.relnamespace)
WHERE c.relname = %s
AND n.nspname = current_schema
"""
cr.execute(query, (tablename,))
return cr.fetchone()[0] if cr.rowcount else None
""", tablename))
if not cr.rowcount:
return None
kind, persistence = cr.fetchone()
# special case: permanent, temporary, and unlogged tables differ by their
# relpersistence, they're all "ordinary" (relkind = r)
if kind == 'r':
return TableKind.Temporary if persistence == 't' else TableKind.Regular
try:
return TableKind(kind)
except ValueError:
# NB: or raise? unclear if it makes sense to allow table_kind to
# "work" with something like an index or sequence
return TableKind.Other
# prescribed column order by type: columns aligned on 4 bytes, columns aligned
# on 1 byte, columns aligned on 8 bytes(values have been chosen to minimize
@ -70,26 +225,32 @@ SQL_ORDER_BY_TYPE = defaultdict(lambda: 16, {
'float8': 9, # 8 bytes aligned on 8 bytes
})
def create_model_table(cr, tablename, comment=None, columns=()):
""" Create the table for a model. """
colspecs = ['id SERIAL NOT NULL'] + [
'"{}" {}'.format(columnname, columntype)
for columnname, columntype, columncomment in columns
colspecs = [
SQL('id SERIAL NOT NULL'),
*(SQL("%s %s", SQL.identifier(colname), SQL(coltype)) for colname, coltype, _ in columns),
SQL('PRIMARY KEY(id)'),
]
queries = [
SQL("CREATE TABLE %s (%s)", SQL.identifier(tablename), SQL(", ").join(colspecs)),
]
cr.execute('CREATE TABLE "{}" ({}, PRIMARY KEY(id))'.format(tablename, ", ".join(colspecs)))
queries, params = [], []
if comment:
queries.append('COMMENT ON TABLE "{}" IS %s'.format(tablename))
params.append(comment)
for columnname, columntype, columncomment in columns:
queries.append('COMMENT ON COLUMN "{}"."{}" IS %s'.format(tablename, columnname))
params.append(columncomment)
if queries:
cr.execute("; ".join(queries), params)
queries.append(SQL(
"COMMENT ON TABLE %s IS %s",
SQL.identifier(tablename), comment,
))
for colname, _, colcomment in columns:
queries.append(SQL(
"COMMENT ON COLUMN %s IS %s",
SQL.identifier(tablename, colname), colcomment,
))
cr.execute(SQL("; ").join(queries))
_schema.debug("Table %r: created", tablename)
def table_columns(cr, tablename):
""" Return a dict mapping column names to their configuration. The latter is
a dict with the data from the table ``information_schema.columns``.
@ -97,51 +258,78 @@ def table_columns(cr, tablename):
# Do not select the field `character_octet_length` from `information_schema.columns`
# because specific access right restriction in the context of shared hosting (Heroku, OVH, ...)
# might prevent a postgres user to read this field.
query = '''SELECT column_name, udt_name, character_maximum_length, is_nullable
FROM information_schema.columns WHERE table_name=%s'''
cr.execute(query, (tablename,))
cr.execute(SQL(
''' SELECT column_name, udt_name, character_maximum_length, is_nullable
FROM information_schema.columns WHERE table_name=%s ''',
tablename,
))
return {row['column_name']: row for row in cr.dictfetchall()}
def column_exists(cr, tablename, columnname):
""" Return whether the given column exists. """
query = """ SELECT 1 FROM information_schema.columns
WHERE table_name=%s AND column_name=%s """
cr.execute(query, (tablename, columnname))
cr.execute(SQL(
""" SELECT 1 FROM information_schema.columns
WHERE table_name=%s AND column_name=%s """,
tablename, columnname,
))
return cr.rowcount
def create_column(cr, tablename, columnname, columntype, comment=None):
""" Create a column with the given type. """
coldefault = (columntype.upper()=='BOOLEAN') and 'DEFAULT false' or ''
cr.execute('ALTER TABLE "{}" ADD COLUMN "{}" {} {}'.format(tablename, columnname, columntype, coldefault))
sql = SQL(
"ALTER TABLE %s ADD COLUMN %s %s %s",
SQL.identifier(tablename),
SQL.identifier(columnname),
SQL(columntype),
SQL("DEFAULT false" if columntype.upper() == 'BOOLEAN' else ""),
)
if comment:
cr.execute('COMMENT ON COLUMN "{}"."{}" IS %s'.format(tablename, columnname), (comment,))
sql = SQL("%s; %s", sql, SQL(
"COMMENT ON COLUMN %s IS %s",
SQL.identifier(tablename, columnname), comment,
))
cr.execute(sql)
_schema.debug("Table %r: added column %r of type %s", tablename, columnname, columntype)
def rename_column(cr, tablename, columnname1, columnname2):
""" Rename the given column. """
cr.execute('ALTER TABLE "{}" RENAME COLUMN "{}" TO "{}"'.format(tablename, columnname1, columnname2))
cr.execute(SQL(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
SQL.identifier(tablename),
SQL.identifier(columnname1),
SQL.identifier(columnname2),
))
_schema.debug("Table %r: renamed column %r to %r", tablename, columnname1, columnname2)
def convert_column(cr, tablename, columnname, columntype):
""" Convert the column to the given type. """
using = f'"{columnname}"::{columntype}'
using = SQL("%s::%s", SQL.identifier(columnname), SQL(columntype))
_convert_column(cr, tablename, columnname, columntype, using)
def convert_column_translatable(cr, tablename, columnname, columntype):
""" Convert the column from/to a 'jsonb' translated field column. """
drop_index(cr, f"{tablename}_{columnname}_index", tablename)
drop_index(cr, make_index_name(tablename, columnname), tablename)
if columntype == "jsonb":
using = f"""CASE WHEN "{columnname}" IS NOT NULL THEN jsonb_build_object('en_US', "{columnname}"::varchar) END"""
using = SQL(
"CASE WHEN %s IS NOT NULL THEN jsonb_build_object('en_US', %s::varchar) END",
SQL.identifier(columnname), SQL.identifier(columnname),
)
else:
using = f""""{columnname}"->>'en_US'"""
using = SQL("%s->>'en_US'", SQL.identifier(columnname))
_convert_column(cr, tablename, columnname, columntype, using)
def _convert_column(cr, tablename, columnname, columntype, using):
query = f'''
ALTER TABLE "{tablename}"
ALTER COLUMN "{columnname}" DROP DEFAULT,
ALTER COLUMN "{columnname}" TYPE {columntype} USING {using}
'''
def _convert_column(cr, tablename, columnname, columntype, using: SQL):
query = SQL(
"ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT, ALTER COLUMN %s TYPE %s USING %s",
SQL.identifier(tablename), SQL.identifier(columnname),
SQL.identifier(columnname), SQL(columntype), using,
)
try:
with cr.savepoint(flush=False):
cr.execute(query, log_exceptions=False)
@ -150,15 +338,21 @@ def _convert_column(cr, tablename, columnname, columntype, using):
cr.execute(query)
_schema.debug("Table %r: column %r changed to type %s", tablename, columnname, columntype)
def drop_depending_views(cr, table, column):
"""drop views depending on a field to allow the ORM to resize it in-place"""
for v, k in get_depending_views(cr, table, column):
cr.execute("DROP {0} VIEW IF EXISTS {1} CASCADE".format("MATERIALIZED" if k == "m" else "", v))
cr.execute(SQL(
"DROP %s IF EXISTS %s CASCADE",
SQL("MATERIALIZED VIEW" if k == "m" else "VIEW"),
SQL.identifier(v),
))
_schema.debug("Drop view %r", v)
def get_depending_views(cr, table, column):
# http://stackoverflow.com/a/11773226/75349
q = """
cr.execute(SQL("""
SELECT distinct quote_ident(dependee.relname), dependee.relkind
FROM pg_depend
JOIN pg_rewrite ON pg_depend.objid = pg_rewrite.oid
@ -170,13 +364,16 @@ def get_depending_views(cr, table, column):
AND pg_attribute.attnum > 0
AND pg_attribute.attname = %s
AND dependee.relkind in ('v', 'm')
"""
cr.execute(q, [table, column])
""", table, column))
return cr.fetchall()
def set_not_null(cr, tablename, columnname):
""" Add a NOT NULL constraint on the given column. """
query = 'ALTER TABLE "{}" ALTER COLUMN "{}" SET NOT NULL'.format(tablename, columnname)
query = SQL(
"ALTER TABLE %s ALTER COLUMN %s SET NOT NULL",
SQL.identifier(tablename), SQL.identifier(columnname),
)
try:
with cr.savepoint(flush=False):
cr.execute(query, log_exceptions=False)
@ -184,53 +381,76 @@ def set_not_null(cr, tablename, columnname):
except Exception:
raise Exception("Table %r: unable to set NOT NULL on column %r", tablename, columnname)
def drop_not_null(cr, tablename, columnname):
""" Drop the NOT NULL constraint on the given column. """
cr.execute('ALTER TABLE "{}" ALTER COLUMN "{}" DROP NOT NULL'.format(tablename, columnname))
cr.execute(SQL(
"ALTER TABLE %s ALTER COLUMN %s DROP NOT NULL",
SQL.identifier(tablename), SQL.identifier(columnname),
))
_schema.debug("Table %r: column %r: dropped constraint NOT NULL", tablename, columnname)
def constraint_definition(cr, tablename, constraintname):
""" Return the given constraint's definition. """
query = """
cr.execute(SQL("""
SELECT COALESCE(d.description, pg_get_constraintdef(c.oid))
FROM pg_constraint c
JOIN pg_class t ON t.oid = c.conrelid
LEFT JOIN pg_description d ON c.oid = d.objoid
WHERE t.relname = %s AND conname = %s;"""
cr.execute(query, (tablename, constraintname))
WHERE t.relname = %s AND conname = %s
""", tablename, constraintname))
return cr.fetchone()[0] if cr.rowcount else None
def add_constraint(cr, tablename, constraintname, definition):
""" Add a constraint on the given table. """
query1 = 'ALTER TABLE "{}" ADD CONSTRAINT "{}" {}'.format(tablename, constraintname, definition)
query2 = 'COMMENT ON CONSTRAINT "{}" ON "{}" IS %s'.format(constraintname, tablename)
query1 = SQL(
"ALTER TABLE %s ADD CONSTRAINT %s %s",
SQL.identifier(tablename), SQL.identifier(constraintname), SQL(definition),
)
query2 = SQL(
"COMMENT ON CONSTRAINT %s ON %s IS %s",
SQL.identifier(constraintname), SQL.identifier(tablename), definition,
)
try:
with cr.savepoint(flush=False):
cr.execute(query1, log_exceptions=False)
cr.execute(query2, (definition,), log_exceptions=False)
cr.execute(query2, log_exceptions=False)
_schema.debug("Table %r: added constraint %r as %s", tablename, constraintname, definition)
except Exception:
raise Exception("Table %r: unable to add constraint %r as %s", tablename, constraintname, definition)
def drop_constraint(cr, tablename, constraintname):
""" drop the given constraint. """
try:
with cr.savepoint(flush=False):
cr.execute('ALTER TABLE "{}" DROP CONSTRAINT "{}"'.format(tablename, constraintname))
cr.execute(SQL(
"ALTER TABLE %s DROP CONSTRAINT %s",
SQL.identifier(tablename), SQL.identifier(constraintname),
))
_schema.debug("Table %r: dropped constraint %r", tablename, constraintname)
except Exception:
_schema.warning("Table %r: unable to drop constraint %r!", tablename, constraintname)
def add_foreign_key(cr, tablename1, columnname1, tablename2, columnname2, ondelete):
""" Create the given foreign key, and return ``True``. """
query = 'ALTER TABLE "{}" ADD FOREIGN KEY ("{}") REFERENCES "{}"("{}") ON DELETE {}'
cr.execute(query.format(tablename1, columnname1, tablename2, columnname2, ondelete))
cr.execute(SQL(
"ALTER TABLE %s ADD FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE %s",
SQL.identifier(tablename1), SQL.identifier(columnname1),
SQL.identifier(tablename2), SQL.identifier(columnname2),
SQL(ondelete),
))
_schema.debug("Table %r: added foreign key %r references %r(%r) ON DELETE %s",
tablename1, columnname1, tablename2, columnname2, ondelete)
return True
def get_foreign_keys(cr, tablename1, columnname1, tablename2, columnname2, ondelete):
cr.execute(
deltype = _CONFDELTYPES[ondelete.upper()]
cr.execute(SQL(
"""
SELECT fk.conname as name
FROM pg_constraint AS fk
@ -244,25 +464,29 @@ def get_foreign_keys(cr, tablename1, columnname1, tablename2, columnname2, ondel
AND c2.relname = %s
AND a2.attname = %s
AND fk.confdeltype = %s
""", [tablename1, columnname1, tablename2, columnname2, _CONFDELTYPES[ondelete.upper()]]
)
""",
tablename1, columnname1, tablename2, columnname2, deltype,
))
return [r[0] for r in cr.fetchall()]
def fix_foreign_key(cr, tablename1, columnname1, tablename2, columnname2, ondelete):
""" Update the foreign keys between tables to match the given one, and
return ``True`` if the given foreign key has been recreated.
"""
# Do not use 'information_schema' here, as those views are awfully slow!
deltype = _CONFDELTYPES.get(ondelete.upper(), 'a')
query = """ SELECT con.conname, c2.relname, a2.attname, con.confdeltype as deltype
FROM pg_constraint as con, pg_class as c1, pg_class as c2,
pg_attribute as a1, pg_attribute as a2
WHERE con.contype='f' AND con.conrelid=c1.oid AND con.confrelid=c2.oid
AND array_lower(con.conkey, 1)=1 AND con.conkey[1]=a1.attnum
AND array_lower(con.confkey, 1)=1 AND con.confkey[1]=a2.attnum
AND a1.attrelid=c1.oid AND a2.attrelid=c2.oid
AND c1.relname=%s AND a1.attname=%s """
cr.execute(query, (tablename1, columnname1))
cr.execute(SQL(
""" SELECT con.conname, c2.relname, a2.attname, con.confdeltype as deltype
FROM pg_constraint as con, pg_class as c1, pg_class as c2,
pg_attribute as a1, pg_attribute as a2
WHERE con.contype='f' AND con.conrelid=c1.oid AND con.confrelid=c2.oid
AND array_lower(con.conkey, 1)=1 AND con.conkey[1]=a1.attnum
AND array_lower(con.confkey, 1)=1 AND con.confkey[1]=a2.attnum
AND a1.attrelid=c1.oid AND a2.attrelid=c2.oid
AND c1.relname=%s AND a1.attname=%s """,
tablename1, columnname1,
))
found = False
for fk in cr.fetchall():
if not found and fk[1:] == (tablename2, columnname2, deltype):
@ -272,47 +496,63 @@ def fix_foreign_key(cr, tablename1, columnname1, tablename2, columnname2, ondele
if not found:
return add_foreign_key(cr, tablename1, columnname1, tablename2, columnname2, ondelete)
def index_exists(cr, indexname):
""" Return whether the given index exists. """
cr.execute("SELECT 1 FROM pg_indexes WHERE indexname=%s", (indexname,))
cr.execute(SQL("SELECT 1 FROM pg_indexes WHERE indexname=%s", indexname))
return cr.rowcount
def check_index_exist(cr, indexname):
assert index_exists(cr, indexname), f"{indexname} does not exist"
def create_index(cr, indexname, tablename, expressions, method='btree', where=''):
""" Create the given index unless it exists. """
if index_exists(cr, indexname):
return
args = ', '.join(expressions)
if where:
where = f' WHERE {where}'
cr.execute(f'CREATE INDEX "{indexname}" ON "{tablename}" USING {method} ({args}){where}')
_schema.debug("Table %r: created index %r (%s)", tablename, indexname, args)
cr.execute(SQL(
"CREATE INDEX %s ON %s USING %s (%s)%s",
SQL.identifier(indexname),
SQL.identifier(tablename),
SQL(method),
SQL(", ").join(SQL(expression) for expression in expressions),
SQL(" WHERE %s", SQL(where)) if where else SQL(),
))
_schema.debug("Table %r: created index %r (%s)", tablename, indexname, ", ".join(expressions))
def create_unique_index(cr, indexname, tablename, expressions):
""" Create the given index unless it exists. """
if index_exists(cr, indexname):
return
args = ', '.join(expressions)
cr.execute('CREATE UNIQUE INDEX "{}" ON "{}" ({})'.format(indexname, tablename, args))
_schema.debug("Table %r: created index %r (%s)", tablename, indexname, args)
cr.execute(SQL(
"CREATE UNIQUE INDEX %s ON %s (%s)",
SQL.identifier(indexname),
SQL.identifier(tablename),
SQL(", ").join(SQL(expression) for expression in expressions),
))
_schema.debug("Table %r: created index %r (%s)", tablename, indexname, ", ".join(expressions))
def drop_index(cr, indexname, tablename):
""" Drop the given index if it exists. """
cr.execute('DROP INDEX IF EXISTS "{}"'.format(indexname))
cr.execute(SQL("DROP INDEX IF EXISTS %s", SQL.identifier(indexname)))
_schema.debug("Table %r: dropped index %r", tablename, indexname)
def drop_view_if_exists(cr, viewname):
kind = table_kind(cr, viewname)
if kind == 'v':
cr.execute("DROP VIEW {} CASCADE".format(viewname))
elif kind == 'm':
cr.execute("DROP MATERIALIZED VIEW {} CASCADE".format(viewname))
if kind == TableKind.View:
cr.execute(SQL("DROP VIEW %s CASCADE", SQL.identifier(viewname)))
elif kind == TableKind.Materialized:
cr.execute(SQL("DROP MATERIALIZED VIEW %s CASCADE", SQL.identifier(viewname)))
def escape_psql(to_escape):
return to_escape.replace('\\', r'\\').replace('%', r'\%').replace('_', r'\_')
def pg_varchar(size=0):
""" Returns the VARCHAR declaration for the provided size:
@ -330,6 +570,7 @@ def pg_varchar(size=0):
return 'VARCHAR(%d)' % size
return 'VARCHAR'
def reverse_order(order):
""" Reverse an ORDER BY clause """
items = []
@ -357,20 +598,22 @@ def increment_fields_skiplock(records, *fields):
for field in fields:
assert records._fields[field].type == 'integer'
query = SQL("""
UPDATE {table}
SET {sets}
WHERE id IN (SELECT id FROM {table} WHERE id = ANY(%(ids)s) FOR UPDATE SKIP LOCKED)
""").format(
table=Identifier(records._table),
sets=SQL(', ').join(map(
SQL('{0} = COALESCE({0}, 0) + 1').format,
map(Identifier, fields)
))
)
cr = records._cr
cr.execute(query, {'ids': records.ids})
tablename = records._table
cr.execute(SQL(
"""
UPDATE %s
SET %s
WHERE id IN (SELECT id FROM %s WHERE id = ANY(%s) FOR UPDATE SKIP LOCKED)
""",
SQL.identifier(tablename),
SQL(', ').join(
SQL("%s = COALESCE(%s, 0) + 1", SQL.identifier(field), SQL.identifier(field))
for field in fields
),
SQL.identifier(tablename),
records.ids,
))
return bool(cr.rowcount)
@ -431,3 +674,20 @@ def pattern_to_translated_trigram_pattern(pattern):
# replace the original wildcard characters by %
return f"%{'%'.join(wildcard_escaped)}%" if wildcard_escaped else "%"
def make_identifier(identifier: str) -> str:
""" Return ``identifier``, possibly modified to fit PostgreSQL's identifier size limitation.
If too long, ``identifier`` is truncated and padded with a hash to make it mostly unique.
"""
# if length exceeds the PostgreSQL limit of 63 characters.
if len(identifier) > 63:
# We have to fit a crc32 hash and one underscore into a 63 character
# alias. The remaining space we can use to add a human readable prefix.
return f"{identifier[:54]}_{crc32(identifier.encode()):08x}"
return identifier
def make_index_name(table_name: str, column_name: str) -> str:
""" Return an index name according to conventions for the given table and column. """
return make_identifier(f"{table_name}__{column_name}_index")

View file

@ -22,7 +22,8 @@ def add_stripped_items_before(node, spec, extract):
parent = node.getparent()
result = parent.text and RSTRIP_REGEXP.search(parent.text)
before_text = result.group(0) if result else ''
parent.text = (parent.text or '').rstrip() + text
fallback_text = None if spec.text is None else ''
parent.text = ((parent.text or '').rstrip() + text) or fallback_text
else:
result = prev.tail and RSTRIP_REGEXP.search(prev.tail)
before_text = result.group(0) if result else ''
@ -89,13 +90,7 @@ def locate_node(arch, spec):
return None
for node in arch.iter(spec.tag):
if isinstance(node, SKIPPED_ELEMENT_TYPES):
continue
if all(node.get(attr) == spec.get(attr) for attr in spec.attrib
if attr not in ('position', 'version')):
# Version spec should match parent's root element's version
if spec.get('version') and spec.get('version') != arch.get('version'):
return None
if all(node.get(attr) == spec.get(attr) for attr in spec.attrib if attr != 'position'):
return node
return None

View file

@ -141,7 +141,7 @@ def try_report_action(cr, uid, action_id, active_model=None, active_ids=None,
env = env(context=context)
if action['type'] in ['ir.actions.act_window', 'ir.actions.submenu']:
for key in ('res_id', 'res_model', 'view_mode',
'limit', 'search_view', 'search_view_id'):
'limit', 'search_view_id'):
datas[key] = action.get(key, datas.get(key, None))
view_id = False

View file

@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.
from __future__ import annotations
import codecs
import fnmatch
import functools
@ -17,19 +20,20 @@ import tarfile
import threading
import warnings
from collections import defaultdict, namedtuple
from contextlib import suppress
from datetime import datetime
from os.path import join
from pathlib import Path
from babel.messages import extract
from lxml import etree, html
from markupsafe import escape, Markup
from psycopg2.extras import Json
import odoo
from odoo.exceptions import UserError
from odoo.modules.module import get_resource_path
from . import config, pycompat
from .misc import file_open, get_iso_codes, SKIPPED_ELEMENT_TYPES
from .misc import file_open, file_path, get_iso_codes, SKIPPED_ELEMENT_TYPES
_logger = logging.getLogger(__name__)
@ -209,7 +213,13 @@ def translate_xml_node(node, callback, parse, serialize):
def translatable(node):
""" Return whether the given node can be translated as a whole. """
return (
node.tag in TRANSLATED_ELEMENTS
# Some specific nodes (e.g., text highlights) have an auto-updated
# DOM structure that makes them impossible to translate.
# The introduction of a translation `<span>` in the middle of their
# hierarchy breaks their functionalities. We need to force them to
# be translated as a whole using the `o_translate_inline` class.
"o_translate_inline" in node.attrib.get("class", "").split()
or node.tag in TRANSLATED_ELEMENTS
and not any(key.startswith("t-") for key in node.attrib)
and all(translatable(child) for child in node)
)
@ -326,8 +336,6 @@ def xml_term_adapter(term_en):
right_iter = right.iterchildren()
for lc, rc in zip(left_iter, right_iter):
yield from same_struct_iter(lc, rc)
if next(left_iter, None) is not None or next(right_iter, None) is not None:
raise ValueError("Non matching struct")
def adapter(term):
new_node = parse_xml(f"<div>{term}</div>")
@ -525,6 +533,8 @@ class GettextAlias(object):
translation = self._get_translation(source)
assert not (args and kwargs)
if args or kwargs:
if any(isinstance(a, Markup) for a in itertools.chain(args, kwargs.values())):
translation = escape(translation)
try:
return translation % (args or kwargs)
except (TypeError, ValueError, KeyError):
@ -913,6 +923,13 @@ def trans_export(lang, modules, buffer, format, cr):
writer = TranslationFileWriter(buffer, fileformat=format, lang=lang)
writer.write_rows(reader)
# pylint: disable=redefined-builtin
def trans_export_records(lang, model_name, ids, buffer, format, cr):
reader = TranslationRecordReader(cr, model_name, ids, lang=lang)
writer = TranslationFileWriter(buffer, fileformat=format, lang=lang)
writer.write_rows(reader)
def _push(callback, term, source_line):
""" Sanity check before pushing translation terms """
term = (term or "").strip()
@ -1002,28 +1019,28 @@ def extract_spreadsheet_terms(fileobj, keywords, comment_tags, options):
:return: an iterator over ``(lineno, funcname, message, comments)``
tuples
"""
terms = []
terms = set()
data = json.load(fileobj)
for sheet in data.get('sheets', []):
for cell in sheet['cells'].values():
content = cell.get('content', '')
if content.startswith('='):
terms += extract_formula_terms(content)
terms.update(extract_formula_terms(content))
else:
markdown_link = re.fullmatch(r'\[(.+)\]\(.+\)', content)
if markdown_link:
terms.append(markdown_link[1])
terms.add(markdown_link[1])
for figure in sheet['figures']:
terms.append(figure['data']['title'])
terms.add(figure['data']['title'])
if 'baselineDescr' in figure['data']:
terms.append(figure['data']['baselineDescr'])
terms.add(figure['data']['baselineDescr'])
pivots = data.get('pivots', {}).values()
lists = data.get('lists', {}).values()
for data_source in itertools.chain(lists, pivots):
if 'name' in data_source:
terms.append(data_source['name'])
terms.add(data_source['name'])
for global_filter in data.get('globalFilters', []):
terms.append(global_filter['label'])
terms.add(global_filter['label'])
return (
(0, None, term, [])
for term in terms
@ -1033,33 +1050,14 @@ def extract_spreadsheet_terms(fileobj, keywords, comment_tags, options):
ImdInfo = namedtuple('ExternalId', ['name', 'model', 'res_id', 'module'])
class TranslationModuleReader:
""" Retrieve translated records per module
:param cr: cursor to database to export
:param modules: list of modules to filter the exported terms, can be ['all']
records with no external id are always ignored
:param lang: language code to retrieve the translations
retrieve source terms only if not set
"""
def __init__(self, cr, modules=None, lang=None):
class TranslationReader:
def __init__(self, cr, lang=None):
self._cr = cr
self._modules = modules or ['all']
self._lang = lang or 'en_US'
self.env = odoo.api.Environment(cr, odoo.SUPERUSER_ID, {})
self._to_translate = []
self._path_list = [(path, True) for path in odoo.addons.__path__]
self._installed_modules = [
m['name']
for m in self.env['ir.module.module'].search_read([('state', '=', 'installed')], fields=['name'])
]
self._export_translatable_records()
self._export_translatable_resources()
def __iter__(self):
""" Export ir.translation values for all retrieved records """
for module, source, name, res_id, ttype, comments, _record_id, value in self._to_translate:
yield (module, ttype, name, res_id, source, encode(odoo.tools.ustr(value)), comments)
@ -1080,6 +1078,44 @@ class TranslationModuleReader:
return
self._to_translate.append((module, source, name, res_id, ttype, tuple(comments or ()), record_id, value))
def _export_imdinfo(self, model: str, imd_per_id: dict[int, ImdInfo]):
records = self._get_translatable_records(imd_per_id.values())
if not records:
return
env = records.env
for record in records.with_context(check_translations=True):
module = imd_per_id[record.id].module
xml_name = "%s.%s" % (module, imd_per_id[record.id].name)
for field_name, field in record._fields.items():
# ir_actions_actions.name is filtered because unlike other inherited fields,
# this field is inherited as postgresql inherited columns.
# From our business perspective, the parent column is no need to be translated,
# but it is need to be set to jsonb column, since the child columns need to be translated
# And export the parent field may make one value to be translated twice in transifex
#
# Some ir_model_fields.field_description are filtered
# because their fields have falsy attribute export_string_translation
if (
not (field.translate and field.store)
or str(field) == 'ir.actions.actions.name'
or (str(field) == 'ir.model.fields.field_description'
and not env[record.model]._fields[record.name].export_string_translation)
):
continue
name = model + "," + field_name
value_en = record[field_name] or ''
value_lang = record.with_context(lang=self._lang)[field_name] or ''
trans_type = 'model_terms' if callable(field.translate) else 'model'
try:
translation_dictionary = field.get_translation_dictionary(value_en, {self._lang: value_lang})
except Exception:
_logger.exception("Failed to extract terms from %s %s", xml_name, name)
continue
for term_en, term_langs in translation_dictionary.items():
term_lang = term_langs.get(self._lang)
self._push_translation(module, trans_type, name, xml_name, term_en, record_id=record.id, value=term_lang if term_lang != term_en else '')
def _get_translatable_records(self, imd_records):
""" Filter the records that are translatable
@ -1131,6 +1167,84 @@ class TranslationModuleReader:
return records
class TranslationRecordReader(TranslationReader):
""" Retrieve translations for specified records, the reader will
1. create external ids for records without external ids
2. export translations for stored translated and inherited translated fields
:param cr: cursor to database to export
:param model_name: model_name for the records to export
:param ids: ids of the records to export
:param field_names: field names to export, if not set, export all translatable fields
:param lang: language code to retrieve the translations retrieve source terms only if not set
"""
def __init__(self, cr, model_name, ids, field_names=None, lang=None):
super().__init__(cr, lang)
self._records = self.env[model_name].browse(ids)
self._field_names = field_names or list(self._records._fields.keys())
self._export_translatable_records(self._records, self._field_names)
def _export_translatable_records(self, records, field_names):
""" Export translations of all stored/inherited translated fields. Create external id if needed. """
if not records:
return
fields = records._fields
if records._inherits:
inherited_fields = defaultdict(list)
for field_name in field_names:
field = records._fields[field_name]
if field.translate and not field.store and field.inherited_field:
inherited_fields[field.inherited_field.model_name].append(field_name)
for parent_mname, parent_fname in records._inherits.items():
if parent_mname in inherited_fields:
self._export_translatable_records(records[parent_fname], inherited_fields[parent_mname])
if not any(fields[field_name].translate and fields[field_name].store for field_name in field_names):
return
records._BaseModel__ensure_xml_id()
model_name = records._name
query = """SELECT min(concat(module, '.', name)), res_id
FROM ir_model_data
WHERE model = %s
AND res_id = ANY(%s)
GROUP BY model, res_id"""
self._cr.execute(query, (model_name, records.ids))
imd_per_id = {
res_id: ImdInfo((tmp := module_xml_name.split('.', 1))[1], model_name, res_id, tmp[0])
for module_xml_name, res_id in self._cr.fetchall()
}
self._export_imdinfo(model_name, imd_per_id)
class TranslationModuleReader(TranslationReader):
""" Retrieve translated records per module
:param cr: cursor to database to export
:param modules: list of modules to filter the exported terms, can be ['all']
records with no external id are always ignored
:param lang: language code to retrieve the translations
retrieve source terms only if not set
"""
def __init__(self, cr, modules=None, lang=None):
super().__init__(cr, lang)
self._modules = modules or ['all']
self._path_list = [(path, True) for path in odoo.addons.__path__]
self._installed_modules = [
m['name']
for m in self.env['ir.module.module'].search_read([('state', '=', 'installed')], fields=['name'])
]
self._export_translatable_records()
self._export_translatable_resources()
def _export_translatable_records(self):
""" Export translations of all translated records having an external id """
@ -1152,30 +1266,7 @@ class TranslationModuleReader:
records_per_model[model][res_id] = ImdInfo(xml_name, model, res_id, module)
for model, imd_per_id in records_per_model.items():
records = self._get_translatable_records(imd_per_id.values())
if not records:
continue
for record in records:
module = imd_per_id[record.id].module
xml_name = "%s.%s" % (module, imd_per_id[record.id].name)
for field_name, field in record._fields.items():
# ir_actions_actions.name is filtered because unlike other inherited fields,
# this field is inherited as postgresql inherited columns.
# From our business perspective, the parent column is no need to be translated,
# but it is need to be set to jsonb column, since the child columns need to be translated
# And export the parent field may make one value to be translated twice in transifex
if field.translate and field.store and str(field) != 'ir.actions.actions.name':
name = model + "," + field_name
try:
value_en = record[field_name] or ''
value_lang = record.with_context(lang=self._lang)[field_name] or ''
except Exception:
continue
trans_type = 'model_terms' if callable(field.translate) else 'model'
for term_en, term_langs in field.get_translation_dictionary(value_en, {self._lang: value_lang}).items():
term_lang = term_langs.get(self._lang)
self._push_translation(module, trans_type, name, xml_name, term_en, record_id=record.id, value=term_lang if term_lang != term_en else '')
self._export_imdinfo(model, imd_per_id)
def _get_module_from_path(self, path):
for (mp, rec) in self._path_list:
@ -1204,7 +1295,7 @@ class TranslationModuleReader:
if not module:
return
extra_comments = extra_comments or []
src_file = open(fabsolutepath, 'rb')
src_file = file_open(fabsolutepath, 'rb')
options = {}
if extract_method == 'python':
options['encoding'] = 'UTF-8'
@ -1243,6 +1334,8 @@ class TranslationModuleReader:
self._path_list.append((config['root_path'], False))
_logger.debug("Scanning modules at paths: %s", self._path_list)
spreadsheet_files_regex = re.compile(r".*_dashboard(\.osheet)?\.json$")
for (path, recursive) in self._path_list:
_logger.debug("Scanning files of modules at %s", path)
for root, dummy, files in os.walk(path, followlinks=True):
@ -1261,7 +1354,7 @@ class TranslationModuleReader:
self._babel_extract_terms(fname, path, root, 'odoo.tools.translate:babel_extract_qweb',
extra_comments=[JAVASCRIPT_TRANSLATION_COMMENT])
if fnmatch.fnmatch(root, '*/data/*'):
for fname in fnmatch.filter(files, '*_dashboard.json'):
for fname in filter(spreadsheet_files_regex.match, files):
self._babel_extract_terms(fname, path, root, 'odoo.tools.translate:extract_spreadsheet_terms',
extra_comments=[JAVASCRIPT_TRANSLATION_COMMENT])
if not recursive:
@ -1297,7 +1390,8 @@ class TranslationImporter:
the language must be present and activated in the database
:param xmlids: if given, only translations for records with xmlid in xmlids will be loaded
"""
with file_open(filepath, mode='rb') as fileobj:
with suppress(FileNotFoundError), file_open(filepath, mode='rb', env=self.env) as fileobj:
_logger.info('loading base translation file %s for language %s', filepath, lang)
fileformat = os.path.splitext(filepath)[-1][1:].lower()
self.load(fileobj, fileformat, lang, xmlids=xmlids)
@ -1332,7 +1426,6 @@ class TranslationImporter:
continue
if row.get('type') == 'code': # ignore code translations
continue
# TODO: CWG if the po file should not be trusted, we need to check each model term
model_name = row.get('imd_model')
module_name = row['module']
if model_name not in self.env:
@ -1390,16 +1483,20 @@ class TranslationImporter:
for id_, xmlid, values, noupdate in cr.fetchall():
if not values:
continue
value_en = values.get('en_US')
if not value_en:
_value_en = values.get('_en_US', values['en_US'])
if not _value_en:
continue
# {src: {lang: value}}
record_dictionary = field_dictionary[xmlid]
langs = {lang for translations in record_dictionary.values() for lang in translations.keys()}
translation_dictionary = field.get_translation_dictionary(
value_en,
{k: v for k, v in values.items() if k in langs}
_value_en,
{
k: values.get(f'_{k}', v)
for k, v in values.items()
if k in langs
}
)
if force_overwrite or (not noupdate and overwrite):
@ -1413,7 +1510,9 @@ class TranslationImporter:
translation_dictionary[term_en] = translations
for lang in langs:
values[lang] = field.translate(lambda term: translation_dictionary.get(term, {}).get(lang), value_en)
# translate and confirm model_terms translations
values[lang] = field.translate(lambda term: translation_dictionary.get(term, {}).get(lang), _value_en)
values.pop(f'_{lang}', None)
params.extend((id_, Json(values)))
if params:
env.cr.execute(f"""
@ -1456,7 +1555,7 @@ class TranslationImporter:
self.model_translations.clear()
env.invalidate_all()
env.registry.clear_caches()
env.registry.clear_cache()
if self.verbose:
_logger.info("translations are loaded successfully")
@ -1530,6 +1629,27 @@ def load_language(cr, lang):
installer.lang_install()
def get_po_paths(module_name: str, lang: str):
return get_po_paths_env(module_name, lang)
def get_po_paths_env(module_name: str, lang: str, env: odoo.api.Environment | None = None):
lang_base = lang.split('_', 1)[0]
# Load the base as a fallback in case a translation is missing:
po_names = [lang_base, lang]
# Exception for Spanish locales: they have two bases, es and es_419:
if lang_base == 'es' and lang not in ('es_ES', 'es_419'):
po_names.insert(1, 'es_419')
po_paths = (
join(module_name, dir_, filename + '.po')
for filename in po_names
for dir_ in ('i18n', 'i18n_extra')
)
for path in po_paths:
with suppress(FileNotFoundError):
yield file_path(path, env=env)
class CodeTranslations:
def __init__(self):
# {(module_name, lang): {src: value}}
@ -1537,15 +1657,6 @@ class CodeTranslations:
# {(module_name, lang): {'message': [{'id': src, 'string': value}]}
self.web_translations = {}
@staticmethod
def _get_po_paths(mod, lang):
lang_base = lang.split('_')[0]
po_paths = [get_resource_path(mod, 'i18n', lang_base + '.po'),
get_resource_path(mod, 'i18n', lang + '.po'),
get_resource_path(mod, 'i18n_extra', lang_base + '.po'),
get_resource_path(mod, 'i18n_extra', lang + '.po')]
return [path for path in po_paths if path]
@staticmethod
def _read_code_translations_file(fileobj, filter_func):
""" read and return code translations from fileobj with filter filter_func
@ -1564,7 +1675,7 @@ class CodeTranslations:
@staticmethod
def _get_code_translations(module_name, lang, filter_func):
po_paths = CodeTranslations._get_po_paths(module_name, lang)
po_paths = get_po_paths(module_name, lang)
translations = {}
for po_path in po_paths:
try:

View file

@ -8,6 +8,7 @@ import re
from lxml import etree
from odoo import tools
from odoo.osv.expression import DOMAIN_OPERATORS
_logger = logging.getLogger(__name__)
@ -17,63 +18,238 @@ _relaxng_cache = {}
READONLY = re.compile(r"\breadonly\b")
def _get_attrs_symbols():
""" Return a set of predefined symbols for evaluating attrs. """
return {
'True', 'False', 'None', # those are identifiers in Python 2.7
'self',
'id',
'uid',
'context',
'context_today',
'active_id',
'active_ids',
'allowed_company_ids',
'current_company_id',
'active_model',
'time',
'datetime',
'relativedelta',
'current_date',
'today',
'now',
'abs',
'len',
'bool',
'float',
'str',
'unicode',
}
# predefined symbols for evaluating attributes (invisible, readonly...);
# names rooted in one of these are not reported as contextual values
IGNORED_IN_EXPRESSION = {
    'True', 'False', 'None',  # those are identifiers in Python 2.7
    'self',
    'uid',
    'context',
    'context_today',
    'allowed_company_ids',
    'current_company_id',
    # date/time helpers available in the evaluation context
    'time',
    'datetime',
    'relativedelta',
    'current_date',
    'today',
    'now',
    # builtins allowed in attribute expressions
    'abs',
    'len',
    'bool',
    'float',
    'str',
    'unicode',
    'set',
}
def get_variable_names(expr):
""" Return the subexpressions of the kind "VARNAME(.ATTNAME)*" in the given
string or AST node.
def get_domain_value_names(domain):
""" Return all field name used by this domain
eg: [
('id', 'in', [1, 2, 3]),
('field_a', 'in', parent.truc),
('field_b', 'in', context.get('b')),
(1, '=', 1),
bool(context.get('c')),
]
returns {'id', 'field_a', 'field_b'}, {'parent', 'parent.truc', 'context'}
:param domain: list(tuple) or str
:return: set(str), set(str)
"""
IGNORED = _get_attrs_symbols()
names = set()
contextual_values = set()
field_names = set()
def get_name_seq(node):
if isinstance(node, ast.Name):
return [node.id]
elif isinstance(node, ast.Attribute):
left = get_name_seq(node.value)
return left and left + [node.attr]
try:
if isinstance(domain, list):
for leaf in domain:
if leaf in DOMAIN_OPERATORS or leaf in (True, False):
# "&", "|", "!", True, False
continue
left, _operator, _right = leaf
if isinstance(left, str):
field_names.add(left)
elif left not in (1, 0):
# deprecate: True leaf and False leaf
raise ValueError()
def process(node):
seq = get_name_seq(node)
if seq and seq[0] not in IGNORED:
names.add('.'.join(seq))
else:
for child in ast.iter_child_nodes(node):
process(child)
elif isinstance(domain, str):
def extract_from_domain(ast_domain):
if isinstance(ast_domain, ast.IfExp):
# [] if condition else []
extract_from_domain(ast_domain.body)
extract_from_domain(ast_domain.orelse)
return
if isinstance(ast_domain, ast.BoolOp):
# condition and []
# this formating don't check returned domain syntax
for value in ast_domain.values:
if isinstance(value, (ast.List, ast.IfExp, ast.BoolOp, ast.BinOp)):
extract_from_domain(value)
else:
contextual_values.update(_get_expression_contextual_values(value))
return
if isinstance(ast_domain, ast.BinOp):
# [] + []
# this formating don't check returned domain syntax
if isinstance(ast_domain.left, (ast.List, ast.IfExp, ast.BoolOp, ast.BinOp)):
extract_from_domain(ast_domain.left)
else:
contextual_values.update(_get_expression_contextual_values(ast_domain.left))
if isinstance(expr, str):
expr = ast.parse(expr.strip(), mode='eval').body
process(expr)
if isinstance(ast_domain.right, (ast.List, ast.IfExp, ast.BoolOp, ast.BinOp)):
extract_from_domain(ast_domain.right)
else:
contextual_values.update(_get_expression_contextual_values(ast_domain.right))
return
for ast_item in ast_domain.elts:
if isinstance(ast_item, ast.Constant):
# "&", "|", "!", True, False
if ast_item.value not in DOMAIN_OPERATORS and ast_item.value not in (True, False):
raise ValueError()
elif isinstance(ast_item, (ast.List, ast.Tuple)):
left, _operator, right = ast_item.elts
contextual_values.update(_get_expression_contextual_values(right))
if isinstance(left, ast.Constant) and isinstance(left.value, str):
field_names.add(left.value)
elif isinstance(left, ast.Constant) and left.value in (1, 0):
# deprecate: True leaf (1, '=', 1) and False leaf (0, '=', 1)
pass
elif isinstance(right, ast.Constant) and right.value == 1:
# deprecate: True/False leaf (py expression, '=', 1)
contextual_values.update(_get_expression_contextual_values(left))
else:
raise ValueError()
else:
raise ValueError()
return names
expr = domain.strip()
item_ast = ast.parse(f"({expr})", mode='eval').body
if isinstance(item_ast, ast.Name):
# domain="other_field_domain"
contextual_values.update(_get_expression_contextual_values(item_ast))
else:
extract_from_domain(item_ast)
except ValueError:
raise ValueError("Wrong domain formatting.") from None
value_names = set()
for name in contextual_values:
if name == 'parent':
continue
root = name.split('.')[0]
if root not in IGNORED_IN_EXPRESSION:
value_names.add(name if root == 'parent' else root)
return field_names, value_names
def _get_expression_contextual_values(item_ast):
""" Return all contextual value this ast
eg: ast from '''(
id in [1, 2, 3]
and field_a in parent.truc
and field_b in context.get('b')
or (
True
and bool(context.get('c'))
)
)
returns {'parent', 'parent.truc', 'context', 'bool'}
:param item_ast: ast
:return: set(str)
"""
if isinstance(item_ast, ast.Constant):
return set()
if isinstance(item_ast, (ast.List, ast.Tuple)):
values = set()
for item in item_ast.elts:
values |= _get_expression_contextual_values(item)
return values
if isinstance(item_ast, ast.Name):
return {item_ast.id}
if isinstance(item_ast, ast.Attribute):
values = _get_expression_contextual_values(item_ast.value)
if len(values) == 1:
path = sorted(list(values)).pop()
values = {f"{path}.{item_ast.attr}"}
return values
return values
if isinstance(item_ast, ast.Index): # deprecated python ast class for Subscript key
return _get_expression_contextual_values(item_ast.value)
if isinstance(item_ast, ast.Subscript):
values = _get_expression_contextual_values(item_ast.value)
values |= _get_expression_contextual_values(item_ast.slice)
return values
if isinstance(item_ast, ast.Compare):
values = _get_expression_contextual_values(item_ast.left)
for sub_ast in item_ast.comparators:
values |= _get_expression_contextual_values(sub_ast)
return values
if isinstance(item_ast, ast.BinOp):
values = _get_expression_contextual_values(item_ast.left)
values |= _get_expression_contextual_values(item_ast.right)
return values
if isinstance(item_ast, ast.BoolOp):
values = set()
for ast_value in item_ast.values:
values |= _get_expression_contextual_values(ast_value)
return values
if isinstance(item_ast, ast.UnaryOp):
return _get_expression_contextual_values(item_ast.operand)
if isinstance(item_ast, ast.Call):
values = _get_expression_contextual_values(item_ast.func)
for ast_arg in item_ast.args:
values |= _get_expression_contextual_values(ast_arg)
return values
if isinstance(item_ast, ast.IfExp):
values = _get_expression_contextual_values(item_ast.test)
values |= _get_expression_contextual_values(item_ast.body)
values |= _get_expression_contextual_values(item_ast.orelse)
return values
if isinstance(item_ast, ast.Dict):
values = set()
for item in item_ast.keys:
values |= _get_expression_contextual_values(item)
for item in item_ast.values:
values |= _get_expression_contextual_values(item)
return values
raise ValueError(f"Undefined item {item_ast!r}.")
def get_expression_field_names(expression):
    """ Return the names of contextual values used by a python expression.

    The expression is parsed and every "VAR(.ATTR)*" path is collected; then
    names whose root is in IGNORED_IN_EXPRESSION are dropped, a bare 'parent'
    is dropped, dotted 'parent.*' paths are kept whole, and any other dotted
    path is reduced to its root name.

    :param expression: str
    :return: set(str)
    """
    tree = ast.parse(expression.strip(), mode='eval').body
    names = set()
    for full_name in _get_expression_contextual_values(tree):
        if full_name == 'parent':
            continue
        root = full_name.split('.', 1)[0]
        if root in IGNORED_IN_EXPRESSION:
            continue
        # keep the whole path only for parent.* references
        names.add(full_name if root == 'parent' else root)
    return names
def get_dict_asts(expr):
@ -86,9 +262,9 @@ def get_dict_asts(expr):
if not isinstance(expr, ast.Dict):
raise ValueError("Non-dict expression")
if not all(isinstance(key, ast.Str) for key in expr.keys):
if not all((isinstance(key, ast.Constant) and isinstance(key.value, str)) for key in expr.keys):
raise ValueError("Non-string literal dict key")
return {key.s: val for key, val in zip(expr.keys, expr.values)}
return {key.value: val for key, val in zip(expr.keys, expr.values)}
def _check(condition, explanation):
@ -96,53 +272,12 @@ def _check(condition, explanation):
raise ValueError("Expression is not a valid domain: %s" % explanation)
def get_domain_identifiers(expr):
    """ Check that the given string or AST node represents a domain expression,
    and return a pair of sets ``(fields, vars)`` where ``fields`` are the field
    names on the left-hand side of conditions, and ``vars`` are the variable
    names on the right-hand side of conditions.

    :param expr: domain as a string or an ast node
    :raise ValueError: (via _check) when the domain structure is invalid
    """
    if not expr:  # case of expr=""
        return (set(), set())
    if isinstance(expr, str):
        expr = ast.parse(expr.strip(), mode='eval').body

    fnames = set()
    vnames = set()

    if isinstance(expr, ast.List):
        for elem in expr.elts:
            if isinstance(elem, ast.Str):
                # bare string elements are the prefix logical operators
                # note: this doesn't check the and/or structure
                _check(elem.s in ('&', '|', '!'),
                       f"logical operators should be '&', '|', or '!', found {elem.s!r}")
                continue
            if not isinstance(elem, (ast.List, ast.Tuple)):
                # non-literal elements are not validated here
                continue
            _check(len(elem.elts) == 3,
                   f"segments should have 3 elements, found {len(elem.elts)}")
            lhs, operator, rhs = elem.elts
            _check(isinstance(operator, ast.Str),
                   f"operator should be a string, found {type(operator).__name__}")
            if isinstance(lhs, ast.Str):
                fnames.add(lhs.s)

    # right-hand sides and any other subexpressions are collected in one pass
    vnames.update(get_variable_names(expr))

    return (fnames, vnames)
def valid_view(arch, **kwargs):
for pred in _validators[arch.tag]:
check = pred(arch, **kwargs)
if not check:
_logger.error("Invalid XML: %s", pred.__doc__)
return False
if check == "Warning":
_logger.warning("Invalid XML: %s", pred.__doc__)
return "Warning"
return False
return True
@ -176,7 +311,7 @@ def schema_valid(arch, **kwargs):
if validator and not validator.validate(arch):
result = True
for error in validator.error_log:
_logger.error(tools.ustr(error))
_logger.warning(tools.ustr(error))
result = False
return result
return True

View file

@ -140,7 +140,8 @@ def cleanup_xml_node(xml_node_or_string, remove_blank_text=True, remove_blank_no
if isinstance(xml_node, str):
xml_node = xml_node.encode() # misnomer: fromstring actually reads bytes
if isinstance(xml_node, bytes):
xml_node = etree.fromstring(remove_control_characters(xml_node))
parser = etree.XMLParser(recover=True, resolve_entities=False)
xml_node = etree.fromstring(remove_control_characters(xml_node), parser=parser)
# Process leaf nodes iteratively
# Depth-first, so any inner node may become a leaf too (if children are removed)
@ -302,3 +303,8 @@ def validate_xml_from_attachment(env, xml_content, xsd_name, reload_files_functi
_logger.info("XSD validation successful!")
except FileNotFoundError:
_logger.info("XSD file not found, skipping validation")
def find_xml_value(xpath, xml_element, namespaces=None):
    """ Return the text of the first node matching `xpath` under `xml_element`,
    or None when nothing matches.

    :param xpath: xpath expression evaluated with ``xml_element.xpath``
    :param xml_element: element exposing an ``xpath(expr, namespaces=...)`` method
    :param namespaces: optional namespace mapping forwarded to ``xpath``
    """
    matches = xml_element.xpath(xpath, namespaces=namespaces)
    if not matches:
        return None
    return matches[0].text