Initial commit

This commit is contained in:
2025-10-29 14:37:08 +08:00
commit 034509181f
4131 changed files with 555736 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
# flake8: noqa: F403
from ..compat.compat_utils import passthrough_module
passthrough_module(__name__, '._deprecated')
del passthrough_module
# isort: off
from .traversal import *
from ._utils import *
from ._utils import _configuration_args, _get_exe_version_output # noqa: F401

View File

@@ -0,0 +1,49 @@
"""Deprecated - New code should avoid these"""
import base64
import hashlib
import hmac
import json
import warnings
from ..compat.compat_utils import passthrough_module
# XXX: Implement this the same way as other DeprecationWarnings without circular import
passthrough_module(__name__, '.._legacy', callback=lambda attr: warnings.warn(
DeprecationWarning(f'{__name__}.{attr} is deprecated'), stacklevel=6))
del passthrough_module
import re
import struct
def bytes_to_intlist(bs):
if not bs:
return []
if isinstance(bs[0], int): # Python 3
return list(bs)
else:
return [ord(c) for c in bs]
def intlist_to_bytes(xs):
if not xs:
return b''
return struct.pack('%dB' % len(xs), *xs)
def jwt_encode_hs256(payload_data, key, headers={}):
header_data = {
'alg': 'HS256',
'typ': 'JWT',
}
if headers:
header_data.update(headers)
header_b64 = base64.b64encode(json.dumps(header_data).encode())
payload_b64 = base64.b64encode(json.dumps(payload_data).encode())
h = hmac.new(key.encode(), header_b64 + b'.' + payload_b64, hashlib.sha256)
signature_b64 = base64.b64encode(h.digest())
return header_b64 + b'.' + payload_b64 + b'.' + signature_b64
compiled_regex_type = type(re.compile(''))

View File

@@ -0,0 +1,269 @@
"""No longer used and new code should not use. Exists only for API compat."""
import platform
import struct
import sys
import urllib.error
import urllib.parse
import urllib.request
import zlib
from ._utils import Popen, decode_base_n, preferredencoding
from .traversal import traverse_obj
from ..dependencies import certifi, websockets
from ..networking._helper import make_ssl_context
from ..networking._urllib import HTTPHandler
# isort: split
from .networking import escape_rfc3986 # noqa: F401
from .networking import normalize_url as escape_url
from .networking import random_user_agent, std_headers # noqa: F401
from ..cookies import YoutubeDLCookieJar # noqa: F401
from ..networking._urllib import PUTRequest # noqa: F401
from ..networking._urllib import SUPPORTED_ENCODINGS, HEADRequest # noqa: F401
from ..networking._urllib import ProxyHandler as PerRequestProxyHandler # noqa: F401
from ..networking._urllib import RedirectHandler as YoutubeDLRedirectHandler # noqa: F401
from ..networking._urllib import ( # noqa: F401
make_socks_conn_class,
update_Request,
)
from ..networking.exceptions import HTTPError, network_exceptions # noqa: F401
has_certifi = bool(certifi)
has_websockets = bool(websockets)
def load_plugins(name, suffix, namespace):
from ..plugins import load_plugins
ret = load_plugins(name, suffix)
namespace.update(ret)
return ret
def traverse_dict(dictn, keys, casesense=True):
return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True)
def decode_base(value, digits):
return decode_base_n(value, table=digits)
def platform_name():
""" Returns the platform name as a str """
return platform.platform()
def get_subprocess_encoding():
if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
# For subprocess calls, encode with locale encoding
# Refer to http://stackoverflow.com/a/9951851/35070
encoding = preferredencoding()
else:
encoding = sys.getfilesystemencoding()
if encoding is None:
encoding = 'utf-8'
return encoding
# UNUSED
# Based on png2str() written by @gdkchan and improved by @yokrysty
# Originally posted at https://github.com/ytdl-org/youtube-dl/issues/9706
def decode_png(png_data):
# Reference: https://www.w3.org/TR/PNG/
header = png_data[8:]
if png_data[:8] != b'\x89PNG\x0d\x0a\x1a\x0a' or header[4:8] != b'IHDR':
raise OSError('Not a valid PNG file.')
int_map = {1: '>B', 2: '>H', 4: '>I'}
unpack_integer = lambda x: struct.unpack(int_map[len(x)], x)[0]
chunks = []
while header:
length = unpack_integer(header[:4])
header = header[4:]
chunk_type = header[:4]
header = header[4:]
chunk_data = header[:length]
header = header[length:]
header = header[4:] # Skip CRC
chunks.append({
'type': chunk_type,
'length': length,
'data': chunk_data,
})
ihdr = chunks[0]['data']
width = unpack_integer(ihdr[:4])
height = unpack_integer(ihdr[4:8])
idat = b''
for chunk in chunks:
if chunk['type'] == b'IDAT':
idat += chunk['data']
if not idat:
raise OSError('Unable to read PNG data.')
decompressed_data = bytearray(zlib.decompress(idat))
stride = width * 3
pixels = []
def _get_pixel(idx):
x = idx % stride
y = idx // stride
return pixels[y][x]
for y in range(height):
base_pos = y * (1 + stride)
filter_type = decompressed_data[base_pos]
current_row = []
pixels.append(current_row)
for x in range(stride):
color = decompressed_data[1 + base_pos + x]
basex = y * stride + x
left = 0
up = 0
if x > 2:
left = _get_pixel(basex - 3)
if y > 0:
up = _get_pixel(basex - stride)
if filter_type == 1: # Sub
color = (color + left) & 0xff
elif filter_type == 2: # Up
color = (color + up) & 0xff
elif filter_type == 3: # Average
color = (color + ((left + up) >> 1)) & 0xff
elif filter_type == 4: # Paeth
a = left
b = up
c = 0
if x > 2 and y > 0:
c = _get_pixel(basex - stride - 3)
p = a + b - c
pa = abs(p - a)
pb = abs(p - b)
pc = abs(p - c)
if pa <= pb and pa <= pc:
color = (color + a) & 0xff
elif pb <= pc:
color = (color + b) & 0xff
else:
color = (color + c) & 0xff
current_row.append(color)
return width, height, pixels
def register_socks_protocols():
# "Register" SOCKS protocols
# In Python < 2.6.5, urlsplit() suffers from bug https://bugs.python.org/issue7904
# URLs with protocols not in urlparse.uses_netloc are not handled correctly
for scheme in ('socks', 'socks4', 'socks4a', 'socks5'):
if scheme not in urllib.parse.uses_netloc:
urllib.parse.uses_netloc.append(scheme)
def handle_youtubedl_headers(headers):
filtered_headers = headers
if 'Youtubedl-no-compression' in filtered_headers:
filtered_headers = {k: v for k, v in filtered_headers.items() if k.lower() != 'accept-encoding'}
del filtered_headers['Youtubedl-no-compression']
return filtered_headers
def request_to_url(req):
if isinstance(req, urllib.request.Request):
return req.get_full_url()
else:
return req
def sanitized_Request(url, *args, **kwargs):
from ..utils import extract_basic_auth, sanitize_url
url, auth_header = extract_basic_auth(escape_url(sanitize_url(url)))
if auth_header is not None:
headers = args[1] if len(args) >= 2 else kwargs.setdefault('headers', {})
headers['Authorization'] = auth_header
return urllib.request.Request(url, *args, **kwargs)
class YoutubeDLHandler(HTTPHandler):
def __init__(self, params, *args, **kwargs):
self._params = params
super().__init__(*args, **kwargs)
YoutubeDLHTTPSHandler = YoutubeDLHandler
class YoutubeDLCookieProcessor(urllib.request.HTTPCookieProcessor):
def __init__(self, cookiejar=None):
urllib.request.HTTPCookieProcessor.__init__(self, cookiejar)
def http_response(self, request, response):
return urllib.request.HTTPCookieProcessor.http_response(self, request, response)
https_request = urllib.request.HTTPCookieProcessor.http_request
https_response = http_response
def make_HTTPS_handler(params, **kwargs):
return YoutubeDLHTTPSHandler(params, context=make_ssl_context(
verify=not params.get('nocheckcertificate'),
client_certificate=params.get('client_certificate'),
client_certificate_key=params.get('client_certificate_key'),
client_certificate_password=params.get('client_certificate_password'),
legacy_support=params.get('legacyserverconnect'),
use_certifi='no-certifi' not in params.get('compat_opts', []),
), **kwargs)
def process_communicate_or_kill(p, *args, **kwargs):
return Popen.communicate_or_kill(p, *args, **kwargs)
def encodeFilename(s, for_subprocess=False):
assert isinstance(s, str)
return s
def decodeFilename(b, for_subprocess=False):
return b
def decodeArgument(b):
return b
def decodeOption(optval):
if optval is None:
return optval
if isinstance(optval, bytes):
optval = optval.decode(preferredencoding())
assert isinstance(optval, str)
return optval
def error_to_compat_str(err):
return str(err)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
# Utility functions for handling web input based on commonly used JavaScript libraries

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import array
import base64
import datetime as dt
import math
import re
from .._utils import parse_iso8601
TYPE_CHECKING = False
if TYPE_CHECKING:
import collections.abc
import typing
T = typing.TypeVar('T')
_ARRAY_TYPE_LOOKUP = {
'Int8Array': 'b',
'Uint8Array': 'B',
'Uint8ClampedArray': 'B',
'Int16Array': 'h',
'Uint16Array': 'H',
'Int32Array': 'i',
'Uint32Array': 'I',
'Float32Array': 'f',
'Float64Array': 'd',
'BigInt64Array': 'l',
'BigUint64Array': 'L',
'ArrayBuffer': 'B',
}
def parse_iter(parsed: typing.Any, /, *, revivers: dict[str, collections.abc.Callable[[list], typing.Any]] | None = None):
# based on https://github.com/Rich-Harris/devalue/blob/f3fd2aa93d79f21746555671f955a897335edb1b/src/parse.js
resolved = {
-1: None,
-2: None,
-3: math.nan,
-4: math.inf,
-5: -math.inf,
-6: -0.0,
}
if isinstance(parsed, int) and not isinstance(parsed, bool):
if parsed not in resolved or parsed == -2:
raise ValueError('invalid integer input')
return resolved[parsed]
elif not isinstance(parsed, list):
raise ValueError('expected int or list as input')
elif not parsed:
raise ValueError('expected a non-empty list as input')
if revivers is None:
revivers = {}
return_value = [None]
stack: list[tuple] = [(return_value, 0, 0)]
while stack:
target, index, source = stack.pop()
if isinstance(source, tuple):
name, source, reviver = source
try:
resolved[source] = target[index] = reviver(target[index])
except Exception as error:
yield TypeError(f'failed to parse {source} as {name!r}: {error}')
resolved[source] = target[index] = None
continue
if source in resolved:
target[index] = resolved[source]
continue
# guard against Python negative indexing
if source < 0:
yield IndexError(f'invalid index: {source!r}')
continue
try:
value = parsed[source]
except IndexError as error:
yield error
continue
if isinstance(value, list):
if value and isinstance(value[0], str):
# TODO: implement zips `strict=True`
if reviver := revivers.get(value[0]):
if value[1] == source:
# XXX: avoid infinite loop
yield IndexError(f'{value[0]!r} cannot point to itself (index: {source})')
continue
# inverse order: resolve index, revive value
stack.append((target, index, (value[0], value[1], reviver)))
stack.append((target, index, value[1]))
continue
elif value[0] == 'Date':
try:
result = dt.datetime.fromtimestamp(parse_iso8601(value[1]), tz=dt.timezone.utc)
except Exception:
yield ValueError(f'invalid date: {value[1]!r}')
result = None
elif value[0] == 'Set':
result = [None] * (len(value) - 1)
for offset, new_source in enumerate(value[1:]):
stack.append((result, offset, new_source))
elif value[0] == 'Map':
result = []
for key, new_source in zip(*(iter(value[1:]),) * 2, strict=True):
pair = [None, None]
stack.append((pair, 0, key))
stack.append((pair, 1, new_source))
result.append(pair)
elif value[0] == 'RegExp':
# XXX: use jsinterp to translate regex flags
# currently ignores `value[2]`
result = re.compile(value[1])
elif value[0] == 'Object':
result = value[1]
elif value[0] == 'BigInt':
result = int(value[1])
elif value[0] == 'null':
result = {}
for key, new_source in zip(*(iter(value[1:]),) * 2, strict=True):
stack.append((result, key, new_source))
elif value[0] in _ARRAY_TYPE_LOOKUP:
typecode = _ARRAY_TYPE_LOOKUP[value[0]]
data = base64.b64decode(value[1])
result = array.array(typecode, data).tolist()
else:
yield TypeError(f'invalid type at {source}: {value[0]!r}')
result = None
else:
result = len(value) * [None]
for offset, new_source in enumerate(value):
stack.append((result, offset, new_source))
elif isinstance(value, dict):
result = {}
for key, new_source in value.items():
stack.append((result, key, new_source))
else:
result = value
target[index] = resolved[source] = result
return return_value[0]
def parse(parsed: typing.Any, /, *, revivers: dict[str, collections.abc.Callable[[typing.Any], typing.Any]] | None = None):
generator = parse_iter(parsed, revivers=revivers)
while True:
try:
raise generator.send(None)
except StopIteration as error:
return error.value

View File

@@ -0,0 +1,256 @@
from __future__ import annotations
import collections
import collections.abc
import random
import typing
import urllib.parse
import urllib.request
if typing.TYPE_CHECKING:
T = typing.TypeVar('T')
from ._utils import NO_DEFAULT, remove_start, format_field
from .traversal import traverse_obj
def random_user_agent():
USER_AGENT_TMPL = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/{} Safari/537.36'
# Target versions released within the last ~6 months
CHROME_MAJOR_VERSION_RANGE = (134, 140)
return USER_AGENT_TMPL.format(f'{random.randint(*CHROME_MAJOR_VERSION_RANGE)}.0.0.0')
class HTTPHeaderDict(dict):
"""
Store and access keys case-insensitively.
The constructor can take multiple dicts, in which keys in the latter are prioritised.
Retains a case sensitive mapping of the headers, which can be accessed via `.sensitive()`.
"""
def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
obj = dict.__new__(cls, *args, **kwargs)
obj.__sensitive_map = {}
return obj
def __init__(self, /, *args, **kwargs):
super().__init__()
self.__sensitive_map = {}
for dct in filter(None, args):
self.update(dct)
if kwargs:
self.update(kwargs)
def sensitive(self, /) -> dict[str, str]:
return {
self.__sensitive_map[key]: value
for key, value in self.items()
}
def __contains__(self, key: str, /) -> bool:
return super().__contains__(key.title() if isinstance(key, str) else key)
def __delitem__(self, key: str, /) -> None:
key = key.title()
del self.__sensitive_map[key]
super().__delitem__(key)
def __getitem__(self, key, /) -> str:
return super().__getitem__(key.title())
def __ior__(self, other, /):
if isinstance(other, type(self)):
other = other.sensitive()
if isinstance(other, dict):
self.update(other)
return
return NotImplemented
def __or__(self, other, /) -> typing.Self:
if isinstance(other, type(self)):
other = other.sensitive()
if isinstance(other, dict):
return type(self)(self.sensitive(), other)
return NotImplemented
def __ror__(self, other, /) -> typing.Self:
if isinstance(other, type(self)):
other = other.sensitive()
if isinstance(other, dict):
return type(self)(other, self.sensitive())
return NotImplemented
def __setitem__(self, key: str, value, /) -> None:
if isinstance(value, bytes):
value = value.decode('latin-1')
key_title = key.title()
self.__sensitive_map[key_title] = key
super().__setitem__(key_title, str(value).strip())
def clear(self, /) -> None:
self.__sensitive_map.clear()
super().clear()
def copy(self, /) -> typing.Self:
return type(self)(self.sensitive())
@typing.overload
def get(self, key: str, /) -> str | None: ...
@typing.overload
def get(self, key: str, /, default: T) -> str | T: ...
def get(self, key, /, default=NO_DEFAULT):
key = key.title()
if default is NO_DEFAULT:
return super().get(key)
return super().get(key, default)
@typing.overload
def pop(self, key: str, /) -> str: ...
@typing.overload
def pop(self, key: str, /, default: T) -> str | T: ...
def pop(self, key, /, default=NO_DEFAULT):
key = key.title()
if default is NO_DEFAULT:
self.__sensitive_map.pop(key)
return super().pop(key)
self.__sensitive_map.pop(key, default)
return super().pop(key, default)
def popitem(self) -> tuple[str, str]:
self.__sensitive_map.popitem()
return super().popitem()
@typing.overload
def setdefault(self, key: str, /) -> str: ...
@typing.overload
def setdefault(self, key: str, /, default) -> str: ...
def setdefault(self, key, /, default=None) -> str:
key = key.title()
if key in self.__sensitive_map:
return super().__getitem__(key)
self[key] = default or ''
return self[key]
def update(self, other, /, **kwargs) -> None:
if isinstance(other, type(self)):
other = other.sensitive()
if isinstance(other, collections.abc.Mapping):
for key, value in other.items():
self[key] = value
elif hasattr(other, 'keys'):
for key in other.keys(): # noqa: SIM118
self[key] = other[key]
else:
for key, value in other:
self[key] = value
for key, value in kwargs.items():
self[key] = value
std_headers = HTTPHeaderDict({
'User-Agent': random_user_agent(),
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
'Accept-Language': 'en-us,en;q=0.5',
'Sec-Fetch-Mode': 'navigate',
})
def clean_proxies(proxies: dict, headers: HTTPHeaderDict):
req_proxy = headers.pop('Ytdl-Request-Proxy', None)
if req_proxy:
proxies.clear() # XXX: compat: Ytdl-Request-Proxy takes preference over everything, including NO_PROXY
proxies['all'] = req_proxy
for proxy_key, proxy_url in proxies.items():
if proxy_url == '__noproxy__':
proxies[proxy_key] = None
continue
if proxy_key == 'no': # special case
continue
if proxy_url is not None:
# Ensure proxies without a scheme are http.
try:
proxy_scheme = urllib.request._parse_proxy(proxy_url)[0]
except ValueError:
# Ignore invalid proxy URLs. Sometimes these may be introduced through environment
# variables unrelated to proxy settings - e.g. Colab `COLAB_LANGUAGE_SERVER_PROXY`.
# If the proxy is going to be used, the Request Handler proxy validation will handle it.
continue
if proxy_scheme is None:
proxies[proxy_key] = 'http://' + remove_start(proxy_url, '//')
replace_scheme = {
'socks5': 'socks5h', # compat: socks5 was treated as socks5h
'socks': 'socks4', # compat: non-standard
}
if proxy_scheme in replace_scheme:
proxies[proxy_key] = urllib.parse.urlunparse(
urllib.parse.urlparse(proxy_url)._replace(scheme=replace_scheme[proxy_scheme]))
def clean_headers(headers: HTTPHeaderDict):
if 'Youtubedl-No-Compression' in headers: # compat
del headers['Youtubedl-No-Compression']
headers['Accept-Encoding'] = 'identity'
headers.pop('Ytdl-socks-proxy', None)
def remove_dot_segments(path):
# Implements RFC3986 5.2.4 remote_dot_segments
# Pseudo-code: https://tools.ietf.org/html/rfc3986#section-5.2.4
# https://github.com/urllib3/urllib3/blob/ba49f5c4e19e6bca6827282feb77a3c9f937e64b/src/urllib3/util/url.py#L263
output = []
segments = path.split('/')
for s in segments:
if s == '.':
continue
elif s == '..':
if output:
output.pop()
else:
output.append(s)
if not segments[0] and (not output or output[0]):
output.insert(0, '')
if segments[-1] in ('.', '..'):
output.append('')
return '/'.join(output)
def escape_rfc3986(s):
"""Escape non-ASCII characters as suggested by RFC 3986"""
return urllib.parse.quote(s, b"%/;:@&=+$,!~*'()?#[]")
def normalize_url(url):
"""Normalize URL as suggested by RFC 3986"""
url_parsed = urllib.parse.urlparse(url)
return url_parsed._replace(
netloc=url_parsed.netloc.encode('idna').decode('ascii'),
path=escape_rfc3986(remove_dot_segments(url_parsed.path)),
params=escape_rfc3986(url_parsed.params),
query=escape_rfc3986(url_parsed.query),
fragment=escape_rfc3986(url_parsed.fragment),
).geturl()
def select_proxy(url, proxies):
"""Unified proxy selector for all backends"""
url_components = urllib.parse.urlparse(url)
if 'no' in proxies:
hostport = url_components.hostname + format_field(url_components.port, None, ':%s')
if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}):
return
elif urllib.request.proxy_bypass(hostport): # check system settings
return
return traverse_obj(proxies, url_components.scheme or 'http', 'all')

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
import bisect
import threading
import time
class ProgressCalculator:
# Time to calculate the speed over (seconds)
SAMPLING_WINDOW = 3
# Minimum timeframe before to sample next downloaded bytes (seconds)
SAMPLING_RATE = 0.05
# Time before showing eta (seconds)
GRACE_PERIOD = 1
def __init__(self, initial: int):
self._initial = initial or 0
self.downloaded = self._initial
self.elapsed: float = 0
self.speed = SmoothValue(0, smoothing=0.7)
self.eta = SmoothValue(None, smoothing=0.9)
self._total = 0
self._start_time = time.monotonic()
self._last_update = self._start_time
self._lock = threading.Lock()
self._thread_sizes: dict[int, int] = {}
self._times = [self._start_time]
self._downloaded = [self.downloaded]
@property
def total(self):
return self._total
@total.setter
def total(self, value: int | None):
with self._lock:
if value is not None and value < self.downloaded:
value = self.downloaded
self._total = value
def thread_reset(self):
current_thread = threading.get_ident()
with self._lock:
self._thread_sizes[current_thread] = 0
def update(self, size: int | None):
if not size:
return
current_thread = threading.get_ident()
with self._lock:
last_size = self._thread_sizes.get(current_thread, 0)
self._thread_sizes[current_thread] = size
self._update(size - last_size)
def _update(self, size: int):
current_time = time.monotonic()
self.downloaded += size
self.elapsed = current_time - self._start_time
if self.total is not None and self.downloaded > self.total:
self._total = self.downloaded
if self._last_update + self.SAMPLING_RATE > current_time:
return
self._last_update = current_time
self._times.append(current_time)
self._downloaded.append(self.downloaded)
offset = bisect.bisect_left(self._times, current_time - self.SAMPLING_WINDOW)
del self._times[:offset]
del self._downloaded[:offset]
if len(self._times) < 2:
self.speed.reset()
self.eta.reset()
return
download_time = current_time - self._times[0]
if not download_time:
return
self.speed.set((self.downloaded - self._downloaded[0]) / download_time)
if self.total and self.speed.value and self.elapsed > self.GRACE_PERIOD:
self.eta.set((self.total - self.downloaded) / self.speed.value)
else:
self.eta.reset()
class SmoothValue:
def __init__(self, initial: float | None, smoothing: float):
self.value = self.smooth = self._initial = initial
self._smoothing = smoothing
def set(self, value: float):
self.value = value
if self.smooth is None:
self.smooth = self.value
else:
self.smooth = (1 - self._smoothing) * value + self._smoothing * self.smooth
def reset(self):
self.value = self.smooth = self._initial

View File

@@ -0,0 +1,477 @@
from __future__ import annotations
import collections
import collections.abc
import contextlib
import functools
import http.cookies
import inspect
import itertools
import re
import typing
import xml.etree.ElementTree
from ._utils import (
IDENTITY,
NO_DEFAULT,
ExtractorError,
LazyList,
deprecation_warning,
get_elements_html_by_class,
get_elements_html_by_attribute,
get_elements_by_attribute,
get_element_by_class,
get_element_html_by_attribute,
get_element_by_attribute,
get_element_html_by_id,
get_element_by_id,
get_element_html_by_class,
get_elements_by_class,
get_element_text_and_html_by_tag,
is_iterable_like,
try_call,
url_or_none,
variadic,
)
def traverse_obj(
obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
"""
Safely traverse nested `dict`s and `Iterable`s
>>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key"))
'value'
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
`xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
The keys in the path can be one of:
- `None`: Return the current object.
- `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{type, type, ...}`/`{func}`. If a `type`, return only
values of this type. If a function, returns `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values.
- `tuple`/`list`: Branch out and return a list of all matching values.
Read as: `[traverse_obj(obj, branch) for branch in branches]`.
- `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`.
For `Iterable`s, `key` is the index of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict`: Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
- `any`-builtin: Take the first matching object and return it, resetting branching.
- `all`-builtin: Take all matching objects and return them as a list, resetting branching.
- `filter`-builtin: Return the value if it is truthy, `None` otherwise.
`tuple`, `list`, and `dict` all support nested paths and branches.
@params paths Paths by which to traverse.
@param default Value to return if the paths do not match.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, depth first. Try to avoid if using nested `dict` keys.
@param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, recursively. This does respect branching paths.
@param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive.
`traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API
@param traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be
converted into a string and then traversed into.
The return value of that path will be a string instead,
not respecting any further branching.
@returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead.
If no `default` is given and the last path branches, a `list` of results
is always returned. If a path ends on a `dict` that result will always be a `dict`.
"""
if is_user_input is not NO_DEFAULT:
deprecation_warning('The is_user_input parameter is deprecated and no longer works')
casefold = lambda k: k.casefold() if isinstance(k, str) else k
if isinstance(expected_type, type):
type_test = lambda val: val if isinstance(val, expected_type) else None
else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def apply_key(key, obj, is_last):
branching = False
result = None
if obj is None and traverse_string:
if key is ... or callable(key) or isinstance(key, slice):
branching = True
result = ()
elif key is None:
result = obj
elif isinstance(key, set):
item = next(iter(key))
if len(key) > 1 or isinstance(item, type):
assert all(isinstance(item, type) for item in key)
if isinstance(obj, tuple(key)):
result = obj
else:
result = try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)):
branching = True
result = itertools.chain.from_iterable(
apply_path(obj, branch, is_last)[0] for branch in key)
elif key is ...:
branching = True
if isinstance(obj, http.cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, collections.abc.Mapping):
result = obj.values()
elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
result = obj
elif isinstance(obj, re.Match):
result = obj.groups()
elif traverse_string:
branching = False
result = str(obj)
else:
result = ()
elif callable(key):
branching = True
if isinstance(obj, http.cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
iter_obj = enumerate(obj)
elif isinstance(obj, re.Match):
iter_obj = itertools.chain(
enumerate((obj.group(), *obj.groups())),
obj.groupdict().items())
elif traverse_string:
branching = False
iter_obj = enumerate(str(obj))
else:
iter_obj = ()
result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
if not branching: # string traversal
result = ''.join(result)
elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
result = {
k: v if v is not None else default for k, v in iter_obj
if v is not None or default is not NO_DEFAULT
} or None
elif isinstance(obj, collections.abc.Mapping):
if isinstance(obj, http.cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, re.Match):
if isinstance(key, int) or casesense:
with contextlib.suppress(IndexError):
result = obj.group(key)
elif isinstance(key, str):
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
elif isinstance(key, (int, slice)):
if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
branching = isinstance(key, slice)
with contextlib.suppress(IndexError):
result = obj[key]
elif traverse_string:
with contextlib.suppress(IndexError):
result = str(obj)[key]
elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
xpath, _, special = key.rpartition('/')
if not special.startswith('@') and not special.endswith('()'):
xpath = key
special = None
# Allow abbreviations of relative paths, absolute paths error
if xpath.startswith('/'):
xpath = f'.{xpath}'
elif xpath and not xpath.startswith('./'):
xpath = f'./{xpath}'
def apply_specials(element):
if special is None:
return element
if special == '@':
return element.attrib
if special.startswith('@'):
return try_call(element.attrib.get, args=(special[1:],))
if special == 'text()':
return element.text
raise SyntaxError(f'apply_specials is missing case for {special!r}')
if xpath:
result = list(map(apply_specials, obj.iterfind(xpath)))
else:
result = apply_specials(obj)
return branching, result if branching else (result,)
def lazy_last(iterable):
iterator = iter(iterable)
prev = next(iterator, NO_DEFAULT)
if prev is NO_DEFAULT:
return
for item in iterator:
yield False, prev
prev = item
yield True, prev
def apply_path(start_obj, path, test_type):
objs = (start_obj,)
has_branched = False
key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
if not casesense and isinstance(key, str):
key = key.casefold()
if key in (any, all):
has_branched = False
filtered_objs = (obj for obj in objs if obj not in (None, {}))
if key is any:
objs = (next(filtered_objs, None),)
else:
objs = (list(filtered_objs),)
continue
if key is filter:
objs = filter(None, objs)
continue
if __debug__ and callable(key):
# Verify function signature
inspect.signature(key).bind(None, None)
new_objs = []
for obj in objs:
branching, results = apply_key(key, obj, last)
has_branched |= branching
new_objs.append(results)
objs = itertools.chain.from_iterable(new_objs)
if test_type and not isinstance(key, (dict, list, tuple)):
objs = map(type_test, objs)
return objs, has_branched, isinstance(key, dict)
def _traverse_obj(obj, path, allow_empty, test_type):
results, has_branched, is_dict = apply_path(obj, path, test_type)
results = LazyList(item for item in results if item not in (None, {}))
if get_all and has_branched:
if results:
return results.exhaust()
if allow_empty:
return [] if default is NO_DEFAULT else default
return None
return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1):
is_last = index == len(paths)
try:
result = _traverse_obj(obj, path, is_last, True)
if result is not None:
return result
except _RequiredError as e:
if is_last:
# Reraise to get cleaner stack trace
raise ExtractorError(e.orig_msg, expected=e.expected) from None
return None if default is NO_DEFAULT else default
def value(value, /):
return lambda _: value
def require(name, /, *, expected=False):
def func(value):
if value is None:
raise _RequiredError(f'Unable to extract {name}', expected=expected)
return value
return func
class _RequiredError(ExtractorError):
pass
@typing.overload
def subs_list_to_dict(*, lang: str | None = 'und', ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ...
@typing.overload
def subs_list_to_dict(subs: list[dict] | None, /, *, lang: str | None = 'und', ext: str | None = None) -> dict[str, list[dict]]: ...
def subs_list_to_dict(subs: list[dict] | None = None, /, *, lang='und', ext=None):
"""
Convert subtitles from a traversal into a subtitle dict.
The path should have an `all` immediately before this function.
Arguments:
`ext` The default value for `ext` in the subtitle dict
In the dict you can set the following additional items:
`id` The subtitle id to sort the dict into
`quality` The sort order for each subtitle
"""
if subs is None:
return functools.partial(subs_list_to_dict, lang=lang, ext=ext)
result = collections.defaultdict(list)
for sub in subs:
if not url_or_none(sub.get('url')) and not sub.get('data'):
continue
sub_id = sub.pop('id', None)
if not isinstance(sub_id, str):
if not lang:
continue
sub_id = lang
sub_ext = sub.get('ext')
if not isinstance(sub_ext, str):
if not ext:
sub.pop('ext', None)
else:
sub['ext'] = ext
result[sub_id].append(sub)
result = dict(result)
for subs in result.values():
subs.sort(key=lambda x: x.pop('quality', 0) or 0)
return result
@typing.overload
def find_element(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ...
@typing.overload
def find_element(*, cls: str, html=False): ...
@typing.overload
def find_element(*, id: str, tag: str | None = None, html=False, regex=False): ...
@typing.overload
def find_element(*, tag: str, html=False, regex=False): ...
def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False, regex=False):
# deliberately using `id=` and `cls=` for ease of readability
assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required'
ANY_TAG = r'[\w:.-]+'
if attr and value:
assert not cls, 'Cannot match both attr and cls'
assert not id, 'Cannot match both attr and id'
func = get_element_html_by_attribute if html else get_element_by_attribute
return functools.partial(func, attr, value, tag=tag or ANY_TAG, escape_value=not regex)
elif cls:
assert not id, 'Cannot match both cls and id'
assert tag is None, 'Cannot match both cls and tag'
assert not regex, 'Cannot use regex with cls'
func = get_element_html_by_class if html else get_element_by_class
return functools.partial(func, cls)
elif id:
func = get_element_html_by_id if html else get_element_by_id
return functools.partial(func, id, tag=tag or ANY_TAG, escape_value=not regex)
index = int(bool(html))
return lambda html: get_element_text_and_html_by_tag(tag, html)[index]
@typing.overload
def find_elements(*, cls: str, html=False): ...
@typing.overload
def find_elements(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ...
def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False, regex=False):
# deliberately using `cls=` for ease of readability
assert cls or (attr and value), 'One of cls or (attr AND value) is required'
if attr and value:
assert not cls, 'Cannot match both attr and cls'
func = get_elements_html_by_attribute if html else get_elements_by_attribute
return functools.partial(func, attr, value, tag=tag or r'[\w:.-]+', escape_value=not regex)
assert not tag, 'Cannot match both cls and tag'
assert not regex, 'Cannot use regex with cls'
func = get_elements_html_by_class if html else get_elements_by_class
return functools.partial(func, cls)
def trim_str(*, start=None, end=None):
def trim(s):
if s is None:
return None
start_idx = 0
if start and s.startswith(start):
start_idx = len(start)
if end and s.endswith(end):
return s[start_idx:-len(end)]
return s[start_idx:]
return trim
def unpack(func, **kwargs):
@functools.wraps(func)
def inner(items):
return func(*items, **kwargs)
return inner
def get_first(obj, *paths, **kwargs):
return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
for val in map(d.get, variadic(key_or_keys)):
if val is not None and (val or not skip_false_values):
return val
return default