stillerman commited on
Commit
76887e4
·
1 Parent(s): e91ced9

fastapi server

Browse files
Files changed (4) hide show
  1. README.md +3 -0
  2. api_server.py +68 -0
  3. db/wiki_db_api.py +60 -0
  4. engine.py +6 -2
README.md CHANGED
@@ -2,3 +2,6 @@ wget https://dumps.wikimedia.org/simplewiki/20250420/simplewiki-20250420-pages-a
2
 
3
 
4
  python db/wiki_parser_sqlite.py simplewiki-20250420-pages-articles-multistream.xml.bz2 db/data/wikihop.db --batch-size 10000
 
 
 
 
2
 
3
 
4
  python db/wiki_parser_sqlite.py simplewiki-20250420-pages-articles-multistream.xml.bz2 db/data/wikihop.db --batch-size 10000
5
+
6
+
7
+ WIKI_DB_PATH=db/data/wikihop.db python api_server.py
api_server.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Query
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import uvicorn
4
+ import os
5
+ from typing import List, Dict, Any, Optional
6
+ from db.wiki_db_sqlite import WikiDBSqlite
7
+
8
+ app = FastAPI(title="Wiki API Server")
9
+
10
+ # Add CORS middleware
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_credentials=True,
15
+ allow_methods=["*"],
16
+ allow_headers=["*"],
17
+ )
18
+
19
+ # Global database connection
20
+ db = None
21
+
22
+ @app.on_event("startup")
23
+ async def startup_db_client():
24
+ global db
25
+ db_path = os.environ.get("WIKI_DB_PATH")
26
+ if not db_path:
27
+ raise ValueError("WIKI_DB_PATH environment variable not set")
28
+ db = WikiDBSqlite(db_path)
29
+
30
+ @app.get("/article_count")
31
+ async def get_article_count() -> Dict[str, int]:
32
+ """Get the number of articles in the database"""
33
+ return {"count": db.get_article_count()}
34
+
35
+ @app.get("/article_titles")
36
+ async def get_article_titles() -> Dict[str, List[str]]:
37
+ """Get all article titles"""
38
+ return {"titles": db.get_all_article_titles()}
39
+
40
+ @app.get("/article")
41
+ async def get_article(title: str = Query(..., description="Article title")) -> Dict[str, Any]:
42
+ """Get article data by title"""
43
+ article = db.get_article(title)
44
+ if not article:
45
+ raise HTTPException(status_code=404, detail=f"Article '{title}' not found")
46
+ return article
47
+
48
+ @app.get("/article_exists")
49
+ async def article_exists(title: str = Query(..., description="Article title")) -> Dict[str, bool]:
50
+ """Check if an article exists"""
51
+ return {"exists": db.article_exists(title)}
52
+
53
+ @app.get("/article_text")
54
+ async def get_article_text(title: str = Query(..., description="Article title")) -> Dict[str, str]:
55
+ """Get the text of an article"""
56
+ if not db.article_exists(title):
57
+ raise HTTPException(status_code=404, detail=f"Article '{title}' not found")
58
+ return {"text": db.get_article_text(title)}
59
+
60
+ @app.get("/article_links")
61
+ async def get_article_links(title: str = Query(..., description="Article title")) -> Dict[str, List[str]]:
62
+ """Get the links of an article"""
63
+ if not db.article_exists(title):
64
+ raise HTTPException(status_code=404, detail=f"Article '{title}' not found")
65
+ return {"links": db.get_article_links(title)}
66
+
67
+ if __name__ == "__main__":
68
+ uvicorn.run("api_server:app", host="0.0.0.0", port=8000, reload=True)
db/wiki_db_api.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ class WikiDBAPI:
4
+ def __init__(self, api_endpoint):
5
+ """Initialize the database with API endpoint URL"""
6
+ self.api_endpoint = api_endpoint.rstrip('/')
7
+ # Verify connection and get initial article count
8
+ self._article_count = self._get_article_count()
9
+ print(f"Connected to Wiki API at {api_endpoint} with {self._article_count} articles")
10
+
11
+ def __del__(self):
12
+ """Clean up resources when object is destroyed"""
13
+ # No persistent connection to close in the API version
14
+ pass
15
+
16
+ def _get_article_count(self):
17
+ """Get the number of articles via API"""
18
+ response = requests.get(f"{self.api_endpoint}/article_count")
19
+ response.raise_for_status()
20
+ return response.json()["count"]
21
+
22
+ def get_article_count(self):
23
+ """Return the number of articles"""
24
+ return self._article_count
25
+
26
+ def get_all_article_titles(self):
27
+ """Return a list of all article titles"""
28
+ response = requests.get(f"{self.api_endpoint}/article_titles")
29
+ response.raise_for_status()
30
+ return response.json()["titles"]
31
+
32
+ def get_article(self, title):
33
+ """Get article data by title"""
34
+ response = requests.get(f"{self.api_endpoint}/article", params={"title": title})
35
+ if response.status_code == 404:
36
+ return {}
37
+ response.raise_for_status()
38
+ return response.json()
39
+
40
+ def article_exists(self, title):
41
+ """Check if an article exists"""
42
+ response = requests.get(f"{self.api_endpoint}/article_exists", params={"title": title})
43
+ response.raise_for_status()
44
+ return response.json()["exists"]
45
+
46
+ def get_article_text(self, title):
47
+ """Get the text of an article"""
48
+ response = requests.get(f"{self.api_endpoint}/article_text", params={"title": title})
49
+ if response.status_code == 404:
50
+ return ''
51
+ response.raise_for_status()
52
+ return response.json()["text"]
53
+
54
+ def get_article_links(self, title):
55
+ """Get the links of an article"""
56
+ response = requests.get(f"{self.api_endpoint}/article_links", params={"title": title})
57
+ if response.status_code == 404:
58
+ return []
59
+ response.raise_for_status()
60
+ return response.json()["links"]
engine.py CHANGED
@@ -6,14 +6,18 @@
6
  import random
7
  from db.wiki_db_sqlite import WikiDBSqlite
8
  from db.wiki_db_json import WikiDBJson
9
-
10
  class WikiRunEnvironment:
11
  def __init__(self, wiki_data_path):
12
  """Initialize with path to Wikipedia data"""
13
  if wiki_data_path.endswith('.json'):
14
  self.db = WikiDBJson(wiki_data_path)
15
- else:
16
  self.db = WikiDBSqlite(wiki_data_path)
 
 
 
 
17
 
18
  self.current_article = None
19
  self.target_article = None
 
6
  import random
7
  from db.wiki_db_sqlite import WikiDBSqlite
8
  from db.wiki_db_json import WikiDBJson
9
+ from db.wiki_db_api import WikiDBAPI
10
  class WikiRunEnvironment:
11
  def __init__(self, wiki_data_path):
12
  """Initialize with path to Wikipedia data"""
13
  if wiki_data_path.endswith('.json'):
14
  self.db = WikiDBJson(wiki_data_path)
15
+ elif wiki_data_path.endswith('.db'):
16
  self.db = WikiDBSqlite(wiki_data_path)
17
+ elif wiki_data_path.startswith('http'):
18
+ self.db = WikiDBAPI(wiki_data_path)
19
+ else:
20
+ raise ValueError("Invalid file extension. Must be .json or .db")
21
 
22
  self.current_article = None
23
  self.target_article = None