mgbam commited on
Commit
e995e3b
·
verified ·
1 Parent(s): e2c04b6

Create providers.py

Browse files
Files changed (1) hide show
  1. genesis/providers.py +133 -0
genesis/providers.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import asyncio
5
+ from typing import List, Tuple
6
+
7
+ import httpx
8
+
9
+ # Optional: Gemini and DeepSeek post-processors to polish final text ONLY.
10
+ # They must never add wet-lab protocols or operational steps.
11
+
12
+
13
+ async def gemini_postprocess(text: str, citations: List[dict]) -> str:
14
+ """
15
+ Polish for clarity/flow using Gemini (sync SDK wrapped in a thread).
16
+ Falls back to original text if not configured or on error.
17
+ """
18
+ api_key = os.getenv("GEMINI_API_KEY")
19
+ if not api_key:
20
+ return text
21
+
22
+ try:
23
+ import google.generativeai as genai
24
+
25
+ genai.configure(api_key=api_key)
26
+ model = genai.GenerativeModel("gemini-1.5-flash")
27
+
28
+ prompt = (
29
+ "Polish the following high-level scientific synthesis for clarity and flow. "
30
+ "Do NOT add wet-lab procedures or any operational/step-by-step details. "
31
+ "Preserve factual claims and cautious tone.\n\n"
32
+ f"{text}"
33
+ )
34
+
35
+ def _call_sync() -> str:
36
+ resp = model.generate_content(prompt)
37
+ return getattr(resp, "text", None) or text
38
+
39
+ return await asyncio.to_thread(_call_sync)
40
+
41
+ except Exception:
42
+ return text
43
+
44
+
45
+ async def deepseek_postprocess(text: str, citations: List[dict]) -> str:
46
+ """
47
+ Polish using a generic OpenAI-compatible DeepSeek endpoint.
48
+ Configure DEEPSEEK_BASE_URL and DEEPSEEK_API_KEY (and optionally DEEPSEEK_MODEL).
49
+ """
50
+ base = os.getenv("DEEPSEEK_BASE_URL")
51
+ key = os.getenv("DEEPSEEK_API_KEY")
52
+ if not base or not key:
53
+ return text
54
+
55
+ try:
56
+ async with httpx.AsyncClient(timeout=60.0) as http:
57
+ r = await http.post(
58
+ f"{base.rstrip('/')}/v1/chat/completions",
59
+ headers={
60
+ "Authorization": f"Bearer {key}",
61
+ "Content-Type": "application/json",
62
+ },
63
+ json={
64
+ "model": os.getenv("DEEPSEEK_MODEL", "deepseek-chat"),
65
+ "messages": [
66
+ {
67
+ "role": "system",
68
+ "content": (
69
+ "You are a scientific editor. Improve structure and clarity only. "
70
+ "Never add wet-lab protocols, experimental steps, or operational advice."
71
+ ),
72
+ },
73
+ {
74
+ "role": "user",
75
+ "content": (
76
+ "Polish the following high-level synthesis without adding operational details.\n\n"
77
+ f"{text}"
78
+ ),
79
+ },
80
+ ],
81
+ "temperature": 0.3,
82
+ },
83
+ )
84
+ data = r.json()
85
+ return (
86
+ data.get("choices", [{}])[0]
87
+ .get("message", {})
88
+ .get("content", text)
89
+ )
90
+ except Exception:
91
+ return text
92
+
93
+
94
+ async def postprocess_summary(base_text: str, citations: List[dict], engine: str = "none") -> str:
95
+ """
96
+ Dispatch to the chosen post-processor (none|gemini|deepseek).
97
+ Always returns safe, high-level text.
98
+ """
99
+ engine = (engine or "none").lower()
100
+ if engine == "gemini":
101
+ return await gemini_postprocess(base_text, citations)
102
+ if engine == "deepseek":
103
+ return await deepseek_postprocess(base_text, citations)
104
+ return base_text
105
+
106
+
107
+ async def synthesize_tts(text: str) -> Tuple[bytes | None, str]:
108
+ """
109
+ ElevenLabs TTS → (audio_bytes, mime). Returns (None, "") if not configured.
110
+ Requires ELEVEN_LABS_API_KEY and optional ELEVEN_LABS_VOICE_ID.
111
+ """
112
+ key = os.getenv("ELEVEN_LABS_API_KEY")
113
+ voice_id = os.getenv("ELEVEN_LABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")
114
+ if not key:
115
+ return None, ""
116
+
117
+ payload = {
118
+ "text": text[:6000], # safety limit
119
+ "model_id": "eleven_multilingual_v2",
120
+ "voice_settings": {"stability": 0.5, "similarity_boost": 0.75},
121
+ }
122
+
123
+ url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
124
+ async with httpx.AsyncClient(timeout=120.0) as http:
125
+ r = await http.post(
126
+ url,
127
+ headers={"xi-api-key": key, "Accept": "audio/mpeg"},
128
+ json=payload,
129
+ )
130
+ if r.status_code >= 400:
131
+ # Propagate a readable error to the UI layer
132
+ raise RuntimeError(f"TTS API error {r.status_code}: {r.text[:200]}")
133
+ return r.content, r.headers.get("content-type", "audio/mpeg")