kardosdrur commited on
Commit
97054aa
·
1 Parent(s): 899954b
Files changed (2) hide show
  1. Dockerfile +33 -0
  2. main.py +116 -0
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ FROM python:3.11-slim-bullseye
4
+
5
+ RUN apt update
6
+ RUN apt install -y build-essential
7
+
8
+ RUN pip install gunicorn==20.1.0
9
+ RUN pip install typing-extensions
10
+ RUN pip install topic-wizard
11
+ RUN pip install "turftopic>=0.13.0"
12
+
13
+ RUN useradd -m -u 1000 user
14
+ # Switch to the "user" user
15
+ USER user
16
+ # Set home to the user's home directory
17
+ ENV HOME=/home/user \
18
+ PATH=/home/user/.local/bin:$PATH
19
+
20
+ COPY --chown=user . $HOME/app
21
+
22
+
23
+ RUN mkdir /home/user/numba_cache
24
+ RUN chmod 777 /home/user/numba_cache
25
+
26
+ ENV NUMBA_CACHE_DIR=/home/user/numba_cache
27
+
28
+ # Set the working directory to the user's home directory
29
+ WORKDIR $HOME/app
30
+ RUN git clone https://github.com/x-tabdeveloping/topicwizard
31
+ RUN git checkout topic-arena
32
+ EXPOSE 7860
33
+ CMD gunicorn --timeout 0 -b 0.0.0.0:7860 --workers=2 --threads=4 --worker-class=gthread main:server
main.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dash_mantine_components as dmc
2
+ import joblib
3
+ import numpy as np
4
+ from dash_extensions.enrich import (Dash, DashBlueprint, Input, Output, State,
5
+ dcc, exceptions, html)
6
+ from sentence_transformers import SentenceTransformer
7
+ from sklearn.datasets import fetch_20newsgroups
8
+ from topicwizard.widgets import (ConceptClusters, DocumentClusters,
9
+ TopicBrowser, TopicHierarchy,
10
+ create_widget_container)
11
+ from turftopic import ClusteringTopicModel, KeyNMF
12
+
13
+
14
+ def create_app(blueprint):
15
+ app = Dash(
16
+ __name__,
17
+ blueprint=blueprint,
18
+ title="topicwizard",
19
+ external_scripts=[
20
+ {
21
+ "src": "https://cdn.tailwindcss.com",
22
+ },
23
+ {
24
+ "src": "https://kit.fontawesome.com/9640e5cd85.js",
25
+ "crossorigin": "anonymous",
26
+ },
27
+ ],
28
+ )
29
+ return app
30
+
31
+
32
+ print("Fetching data")
33
+ newsgroups = fetch_20newsgroups(
34
+ subset="all",
35
+ remove=("headers", "footers", "quotes"),
36
+ categories=["alt.atheism", "sci.space"],
37
+ )
38
+ corpus = newsgroups.data
39
+
40
+ print("Calculating embeddings")
41
+ encoder = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1")
42
+ embeddings = encoder.encode(corpus, show_progress_bar=True)
43
+
44
+ print("Fitting keynmf")
45
+ keynmf = KeyNMF(5, encoder=encoder, random_state=42)
46
+ keynmf_data = keynmf.prepare_topic_data(corpus, embeddings=embeddings)
47
+ keynmf_data.hierarchy.divide_children(5)
48
+
49
+ print("Fitting top2vec")
50
+ top2vec = ClusteringTopicModel(
51
+ n_reduce_to=5,
52
+ feature_importance="centroid",
53
+ encoder=encoder,
54
+ random_state=0,
55
+ )
56
+ top2vec_data = top2vec.prepare_topic_data(corpus, embeddings=embeddings)
57
+
58
+ print("Building blueprints.")
59
+ keynmf_blueprint = create_widget_container(
60
+ [TopicBrowser(), ConceptClusters(), TopicHierarchy()],
61
+ keynmf_data,
62
+ app_id="keynmf",
63
+ )
64
+ top2vec_blueprint = create_widget_container(
65
+ [TopicBrowser(), DocumentClusters(), TopicHierarchy()],
66
+ top2vec_data,
67
+ app_id="top2vec",
68
+ )
69
+
70
+ app_blueprint = DashBlueprint()
71
+ app_blueprint.layout = html.Div(
72
+ dmc.Group(
73
+ [
74
+ html.Div(
75
+ [
76
+ dmc.Text(
77
+ "KeyNMF",
78
+ size="xl",
79
+ fw=700,
80
+ color="blue.9",
81
+ align="center",
82
+ className="pt-8",
83
+ ),
84
+ keynmf_blueprint.layout,
85
+ ],
86
+ className="h-full flex-1 items-center",
87
+ ),
88
+ html.Div(
89
+ [
90
+ dmc.Text(
91
+ "Top2Vec",
92
+ size="xl",
93
+ fw=700,
94
+ color="teal.9",
95
+ align="center",
96
+ className="pt-8",
97
+ ),
98
+ top2vec_blueprint.layout,
99
+ ],
100
+ className="h-full flex-1 items-center",
101
+ ),
102
+ ],
103
+ grow=True,
104
+ className="h-full flex-1",
105
+ ),
106
+ className="""
107
+ w-full h-full flex-col flex items-stretch fixed
108
+ bg-white
109
+ """,
110
+ )
111
+
112
+ app = create_app(app_blueprint)
113
+ server = app.server
114
+
115
+ if __name__ == "__main__":
116
+ app.run_server(debug=False, port=7860)