File size: 910 Bytes
50bd1fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import streamlit as st
import torch

from src.inference import CrossAttentionInference


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

inference = CrossAttentionInference(
    model_path="best_attention_classifier.pth",
    device=device
)

st.title("Random Image Inference")

st.write(
    "Нажмите кнопку ниже, чтобы сгенерировать пару случайных изображений и получить предсказание модели."
)

if st.button("Сгенерировать изображения"):
    pred_label, (img1, img2) = inference.predict_random_pair()

    col1, col2 = st.columns(2)

    with col1:
        st.image(img1, caption="Image 1", use_container_width=True)
    with col2:
        st.image(img2, caption="Image 2", use_container_width=True)

    st.write(f"**Предсказанная метка**: {pred_label}")