from __future__ import annotations
import logging
import warnings
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from time import perf_counter
from typing import Dict, Iterable, List, Tuple
from rics.performance import format_perf_counter
from rics.translation.fetching import Fetcher, exceptions
from rics.translation.fetching.types import IdsToFetch
from rics.translation.offline.types import SourcePlaceholderTranslations
from rics.translation.types import IdType, SourceType
from rics.utility.action_level import ActionLevel, ActionLevelHelper
from rics.utility.collections.dicts import reverse_dict
from rics.utility.misc import tname
LOGGER = logging.getLogger(__package__).getChild("MultiFetcher")
FetchResult = Tuple[int, SourcePlaceholderTranslations[SourceType]]
_ACTION_LEVEL_HELPER = ActionLevelHelper(
duplicate_translation_action=ActionLevel.IGNORE,
duplicate_source_discovered_action=None,
)
[docs]class MultiFetcher(Fetcher[SourceType, IdType]):
"""Fetcher which combines the results of other fetchers.
Args:
*fetchers: Fetchers to wrap.
max_workers: Number of threads to use for fetching. Fetch instructions will be dispatched using a
:py:class:`~concurrent.futures.ThreadPoolExecutor`. Individual fetchers will be called at most once per
``fetch()`` or ``fetch_all()`` call made with the ``MultiFetcher``.
duplicate_translation_action: Action to take when multiple fetchers return translations for the same source.
duplicate_source_discovered_action: Action to take when multiple fetchers claim the same source.
"""
def __init__(
self,
*fetchers: Fetcher[SourceType, IdType],
max_workers: int = 2,
duplicate_translation_action: ActionLevel = ActionLevel.WARN,
duplicate_source_discovered_action: ActionLevel = ActionLevel.IGNORE,
) -> None:
for pos, f in enumerate(fetchers):
if not isinstance(f, Fetcher): # pragma: no cover
raise TypeError(f"Argument {pos} is of type {type(f)}, expected Fetcher subtype.")
self._id_to_rank: Dict[int, int] = {id(f): rank for rank, f in enumerate(fetchers)}
self._id_to_fetcher: Dict[int, Fetcher[SourceType, IdType]] = {id(f): f for f in fetchers}
self._source_to_fetcher_id_actual: Dict[SourceType, int] = {}
self.max_workers: int = max_workers
self._duplicate_translation_action = _ACTION_LEVEL_HELPER.verify(
duplicate_translation_action, "duplicate_translation_action"
)
self._duplicate_source_discovered_action = _ACTION_LEVEL_HELPER.verify(
duplicate_source_discovered_action, "duplicate_source_discovered_action"
)
if len(self.fetchers) != len(fetchers):
raise ValueError("Repeat fetcher instance(s)!") # pragma: no cover
@property
def allow_fetch_all(self) -> bool:
return all(f.allow_fetch_all for f in self._id_to_fetcher.values()) # pragma: no cover
[docs] def online(self) -> bool:
return all(f.online for f in self._id_to_fetcher.values()) # pragma: no cover
@property
def placeholders(self) -> Dict[SourceType, List[str]]:
return {
source: self._id_to_fetcher[self._source_to_fetcher_id[source]].placeholders[source]
for source in self.sources
}
@property
def fetchers(self) -> List[Fetcher[SourceType, IdType]]:
"""Return child fetchers."""
return list(self._id_to_fetcher.values())
@property
def sources(self) -> List[SourceType]:
return list(self._source_to_fetcher_id)
[docs] def fetch(
self,
ids_to_fetch: Iterable[IdsToFetch[SourceType, IdType]],
placeholders: Iterable[str] = (),
required: Iterable[str] = (),
) -> SourcePlaceholderTranslations[SourceType]:
tasks: Dict[int, List[IdsToFetch[SourceType, IdType]]] = {}
num_instructions = 0
for idt in ids_to_fetch:
tasks.setdefault(self._source_to_fetcher_id[idt.source], []).append(idt)
num_instructions += 1
placeholders = tuple(placeholders)
required = tuple(required)
start = perf_counter()
LOGGER.debug(f"Dispatch {num_instructions} jobs to {len(tasks)} fetchers using {self.max_workers} threads.")
def fetch(fid: int) -> FetchResult[SourceType]:
return fid, self._id_to_fetcher[fid].fetch(tasks[fid], placeholders, required=required)
with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix=tname(self)) as executor:
futures = [executor.submit(fetch, fid) for fid in tasks]
ans = self._gather(futures)
LOGGER.debug(f"Completed {num_instructions} jobs in {format_perf_counter(start)}.")
return ans
[docs] def fetch_all(
self, placeholders: Iterable[str] = (), required: Iterable[str] = ()
) -> SourcePlaceholderTranslations[SourceType]:
placeholders = tuple(placeholders)
required = tuple(required)
start = perf_counter()
LOGGER.debug(f"Dispatch FETCH_ALL jobs to {len(self.fetchers)} fetchers using {self.max_workers} threads.")
def fetch_all(fetcher: Fetcher[SourceType, IdType]) -> FetchResult[SourceType]:
return id(fetcher), fetcher.fetch_all(placeholders, required=required)
with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix=tname(self)) as executor:
futures = [executor.submit(fetch_all, fetcher) for fetcher in self.fetchers]
ans = self._gather(futures)
LOGGER.debug(f"Completed all FETCH_ALL jobs completed in {format_perf_counter(start)}.")
return ans
@property
def duplicate_translation_action(self) -> ActionLevel:
"""Return action to take when multiple fetchers return translations for the same source."""
return self._duplicate_translation_action
@property
def duplicate_source_discovered_action(self) -> ActionLevel:
"""Return action to take when multiple fetchers claim the same source."""
return self._duplicate_source_discovered_action
@property
def _source_to_fetcher_id(self) -> Dict[SourceType, int]:
if not self._source_to_fetcher_id_actual:
source_ranks: Dict[SourceType, int] = {}
source_to_fetcher_id: Dict[SourceType, int] = {}
for fid, fetcher in list(self._id_to_fetcher.items()):
sources = fetcher.sources
if not sources:
LOGGER.warning(f"Discarding {self._fmt_fetcher(fetcher)}: No sources.")
fetcher.close()
del self._id_to_rank[fid]
del self._id_to_fetcher[fid]
continue
rank = self._id_to_rank[fid]
for source in sources:
if source in source_to_fetcher_id:
self._log_rejection(source, rank, source_ranks[source], translation=False)
else:
source_to_fetcher_id[source] = fid
source_ranks[source] = rank
self._source_to_fetcher_id_actual = source_to_fetcher_id
return self._source_to_fetcher_id_actual
def _gather(self, futures: Iterable[Future[FetchResult[SourceType]]]) -> SourcePlaceholderTranslations[SourceType]:
ans: SourcePlaceholderTranslations[SourceType] = {}
source_ranks: Dict[SourceType, int] = {}
for future in as_completed(futures):
fid, translations = future.result()
rank = self._id_to_rank[fid]
self._process_future_result(translations, rank, source_ranks, ans)
return ans
def _process_future_result(
self,
translations: SourcePlaceholderTranslations[SourceType],
rank: int,
source_ranks: Dict[SourceType, int],
ans: SourcePlaceholderTranslations[SourceType],
) -> None:
for source_translations in translations.values():
source = source_translations.source
other_rank = source_ranks.setdefault(source, rank)
if other_rank != rank:
self._log_rejection(source, rank, other_rank, translation=True)
if rank > other_rank:
continue # Don't save -- other rank is greater (lower-is-better).
ans[source] = source_translations
def _log_rejection(self, source: SourceType, rank0: int, rank1: int, translation: bool) -> None: # pragma: no cover
accepted_rank, rejected_rank = (rank0, rank1) if rank0 < rank1 else (rank1, rank0)
rank_to_id = reverse_dict(self._id_to_rank)
accepted = self._fmt_fetcher(self._id_to_fetcher[rank_to_id[accepted_rank]])
rejected = self._fmt_fetcher(self._id_to_fetcher[rank_to_id[rejected_rank]])
msg = (
f"Discarded translations for {source=} retrieved from {rejected} since the {accepted} returned "
"translations for the same source."
if translation
else f"Discarded {source=} retrieved from {rejected} since the {accepted} already claimed same source."
)
msg += " Hint: Rank is determined input order at initialization."
action = self.duplicate_translation_action if translation else self.duplicate_source_discovered_action
if action is ActionLevel.IGNORE:
LOGGER.debug(msg)
else:
if action is ActionLevel.RAISE:
LOGGER.error(msg)
raise exceptions.DuplicateSourceError(msg)
else:
warnings.warn(msg, exceptions.DuplicateSourceWarning)
LOGGER.warning(msg)
def __repr__(self) -> str:
fetchers = list(self._id_to_fetcher.values())
max_workers = self.max_workers
return f"{tname(self)}({max_workers=}, {fetchers=})"
def _fmt_fetcher(self, fetcher: Fetcher[SourceType, IdType]) -> str:
"""Format a managed fetcher with rank and hex ID."""
fetcher_id = id(fetcher)
rank = self._id_to_rank[fetcher_id]
return f"rank-{rank} fetcher {fetcher} at {hex(fetcher_id)}"