|
import threading |
|
from typing import Any, TypeVar, cast |
|
from streamlit.errors import NoSessionContext |
|
from streamlit.runtime.scriptrunner.script_run_context import SCRIPT_RUN_CONTEXT_ATTR_NAME, get_script_run_ctx |
|
|
|
# Generic type variable: lets ``with_streamlit_context`` declare that it
# returns an object of the same static type as the callable it receives.
T = TypeVar("T")
|
|
|
def with_streamlit_context(fn: T) -> T:
    """Wrap ``fn`` so it runs with the current Streamlit script-run context.

    Works around ``streamlit.errors.NoSessionContext`` being raised when
    ``fn`` is later invoked on a thread that has no script-run context
    attached.  The context is captured once, at the moment this function is
    called, and is attached to whichever thread eventually executes the
    wrapper for the duration of that call only.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper that accepts the same arguments as ``fn`` and returns its
        result, typed as ``fn`` itself.

    Raises:
        NoSessionContext: If ``get_script_run_ctx()`` returns ``None`` at
            wrap time, i.e. there is no context to capture.
    """
    # Function-scope import keeps this block self-contained without touching
    # the module's import section.
    from functools import wraps

    ctx = get_script_run_ctx()
    if ctx is None:
        raise NoSessionContext(
            "with_streamlit_context must be called inside a context; "
            "construct your function on the fly, not earlier."
        )

    # Sentinel distinguishing "thread has no context attribute" from any
    # real attribute value (including ``None``).
    missing = object()

    @wraps(cast(Any, fn))
    def _cb(*args: Any, **kwargs: Any) -> Any:
        """Call ``fn`` with the captured context attached to this thread."""
        thread = threading.current_thread()
        previous = getattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, missing)

        # Nothing to do when the executing thread already carries the
        # captured context.
        if previous is not missing and previous == ctx:
            return fn(*args, **kwargs)

        setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, ctx)
        try:
            return fn(*args, **kwargs)
        finally:
            # Always undo the injection, even when ``fn`` raises, so the
            # context never leaks onto a reused thread.  If the thread had a
            # different context beforehand, put it back rather than
            # stripping the thread of any context at all.
            if previous is missing:
                delattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME)
            else:
                setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, previous)

    return cast(T, _cb)