Skip to content

Commit 56b6c2b

Browse files
committed
fix: initial work on handler fix
1 parent cd2671c commit 56b6c2b

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

aws_lambda_powertools/logging/logger.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,14 @@ def __init__(
233233
self.child = child
234234
self.logger_formatter = logger_formatter
235235
self._stream = stream or sys.stdout
236-
self.logger_handler = logger_handler or logging.StreamHandler(self._stream)
237236
self.log_uncaught_exceptions = log_uncaught_exceptions
238237

239238
self._is_deduplication_disabled = resolve_truthy_env_var_choice(
240239
env=os.getenv(constants.LOGGER_LOG_DEDUPLICATION_ENV, "false"),
241240
)
242241
self._default_log_keys = {"service": self.service, "sampling_rate": self.sampling_rate}
243242
self._logger = self._get_logger()
243+
self.logger_handler = self._get_handler(custom_handler=logger_handler)
244244

245245
# NOTE: This is primarily to improve UX, so IDEs can autocomplete LambdaPowertoolsFormatter options
246246
# previously, we masked all of them as kwargs thus limiting feature discovery
@@ -279,6 +279,15 @@ def _get_logger(self) -> logging.Logger:
279279

280280
return logging.getLogger(logger_name)
281281

282+
def _get_handler(self, custom_handler: logging.Handler):
283+
if self.child:
284+
return self._logger.parent.powertools_handler
285+
286+
if custom_handler:
287+
return custom_handler
288+
289+
return logging.StreamHandler(self._stream)
290+
282291
def _init_logger(
283292
self,
284293
formatter_options: Optional[Dict] = None,
@@ -317,6 +326,7 @@ def _init_logger(
317326
# std logging will return the same Logger with our attribute if name is reused
318327
logger.debug(f"Marking logger {self.service} as preconfigured")
319328
self._logger.init = True # type: ignore[attr-defined]
329+
self._logger.powertools_handler = self.logger_handler # type: ignore[attr-defined]
320330

321331
def _configure_sampling(self) -> None:
322332
"""Dynamically set log level based on sampling rate
@@ -672,10 +682,7 @@ def removeFilter(self, filter: logging._FilterType) -> None: # noqa: A002 # fil
672682
@property
673683
def registered_handler(self) -> logging.Handler:
674684
"""Convenience property to access the first logger handler"""
675-
# We ignore mypy here because self.child encodes whether or not self._logger.parent is
676-
# None, mypy can't see this from context but we can
677-
handlers = self._logger.parent.handlers if self.child else self._logger.handlers # type: ignore[union-attr]
678-
return handlers[0]
685+
return self.logger_handler
679686

680687
@property
681688
def registered_formatter(self) -> BasePowertoolsFormatter:

tests/functional/test_logger.py

+37
Original file line numberDiff line numberDiff line change
@@ -1176,3 +1176,40 @@ def test_logger_json_unicode(stdout, service_name):
11761176

11771177
assert log["message"] == non_ascii_chars
11781178
assert log[japanese_field] == japanese_string
1179+
1180+
1181+
def test_logger_registered_handler_is_custom_handler(service_name):
1182+
# GIVEN a library or environment pre-setup a logger for us using the same name (see #4277)
1183+
class ForeignHandler(logging.StreamHandler): ...
1184+
1185+
foreign_handler = ForeignHandler()
1186+
logging.getLogger(service_name).addHandler(foreign_handler)
1187+
1188+
# WHEN Logger init with a custom handler
1189+
custom_handler = logging.StreamHandler()
1190+
logger = Logger(service=service_name, logger_handler=custom_handler)
1191+
1192+
# THEN registered handler should always return what we provided
1193+
assert logger.registered_handler is not foreign_handler
1194+
assert logger.registered_handler is custom_handler
1195+
assert logger.logger_handler is custom_handler
1196+
1197+
1198+
def test_child_logger_registered_handler_is_custom_handler(service_name):
1199+
# GIVEN
1200+
class ForeignHandler(logging.StreamHandler): ...
1201+
1202+
foreign_handler = ForeignHandler()
1203+
logging.getLogger(service_name).addHandler(foreign_handler)
1204+
1205+
custom_handler = logging.StreamHandler()
1206+
custom_handler.name = "CUSTOM HANDLER"
1207+
parent = Logger(service=service_name, logger_handler=custom_handler)
1208+
1209+
# WHEN a child Logger init
1210+
child = Logger(service=service_name, child=True)
1211+
1212+
# THEN child registered handler should always return what we provided in the parent
1213+
assert child.registered_handler is not foreign_handler
1214+
assert child.registered_handler is custom_handler
1215+
assert child.registered_handler is parent.registered_handler

0 commit comments

Comments
 (0)