[InfoExtractor] Correctly resolve BaseURL in DASH manifest

Specs:
* ISO/IEC 23009-1:2012 section 5.6
* RFC 3986 section 5.
This commit is contained in:
dirkf
2024-01-27 15:45:43 +00:00
parent 4eaeb9b2c6
commit 1fd8f802b8

View File

@@ -2262,9 +2262,24 @@ class InfoExtractor(object):
def is_drm_protected(element): def is_drm_protected(element):
return element.find(_add_ns('ContentProtection')) is not None return element.find(_add_ns('ContentProtection')) is not None
from ..utils import YoutubeDLHandler
fix_path = YoutubeDLHandler._fix_path
def resolve_base_url(element, parent_base_url=None):
# TODO: use native XML traversal when ready
b_url = traverse_obj(element, (
T(lambda e: e.find(_add_ns('BaseURL')).text)))
if parent_base_url and b_url:
if not parent_base_url[-1] in ('/', ':'):
parent_base_url += '/'
b_url = compat_urlparse.urljoin(parent_base_url, b_url)
if b_url:
b_url = fix_path(b_url)
return b_url or parent_base_url
def extract_multisegment_info(element, ms_parent_info): def extract_multisegment_info(element, ms_parent_info):
ms_info = ms_parent_info.copy() ms_info = ms_parent_info.copy()
base_url = ms_info.get('base_url') base_url = ms_info['base_url'] = resolve_base_url(element, ms_info.get('base_url'))
# As per [1, 5.3.9.2.2] SegmentList and SegmentTemplate share some # As per [1, 5.3.9.2.2] SegmentList and SegmentTemplate share some
# common attributes and elements. We will only extract relevant # common attributes and elements. We will only extract relevant
@@ -2336,11 +2351,13 @@ class InfoExtractor(object):
mpd_duration = parse_duration(mpd_doc.get('mediaPresentationDuration')) mpd_duration = parse_duration(mpd_doc.get('mediaPresentationDuration'))
formats, subtitles = [], {} formats, subtitles = [], {}
stream_numbers = collections.defaultdict(int) stream_numbers = collections.defaultdict(int)
mpd_base_url = resolve_base_url(mpd_doc, mpd_base_url or mpd_url)
for period in mpd_doc.findall(_add_ns('Period')): for period in mpd_doc.findall(_add_ns('Period')):
period_duration = parse_duration(period.get('duration')) or mpd_duration period_duration = parse_duration(period.get('duration')) or mpd_duration
period_ms_info = extract_multisegment_info(period, { period_ms_info = extract_multisegment_info(period, {
'start_number': 1, 'start_number': 1,
'timescale': 1, 'timescale': 1,
'base_url': mpd_base_url,
}) })
for adaptation_set in period.findall(_add_ns('AdaptationSet')): for adaptation_set in period.findall(_add_ns('AdaptationSet')):
if is_drm_protected(adaptation_set): if is_drm_protected(adaptation_set):