#!/usr/bin/env python3
"""MedGenesis – **Gemini** (Google Generative AI) async helper.

Key behaviours
~~~~~~~~~~~~~~
* Tries the fast **`gemini-1.5-flash`** model first → falls back to
  **`gemini-pro`** when flash is unavailable or its quota is exceeded.
* Exponential back‑off retry (2 s, then 4 s) for transient 429/5xx errors.
* Singleton model cache to avoid re‑instantiation cost.
* Returns an **empty string** on irrecoverable errors so the orchestrator
  can gracefully pivot to OpenAI.
"""
from __future__ import annotations

import os, asyncio, functools

import google.generativeai as genai
from google.api_core import exceptions as gexc

_API_KEY = os.getenv("GEMINI_KEY")
if not _API_KEY:
    raise RuntimeError("GEMINI_KEY env variable missing – set it in HF Secrets")

genai.configure(api_key=_API_KEY)

# ---------------------------------------------------------------------
# Model cache
# ---------------------------------------------------------------------
@functools.lru_cache(maxsize=4)
def _get_model(name: str):
    return genai.GenerativeModel(name)


async def _generate(prompt: str, model_name: str, *, temperature: float = 0.3, retries: int = 3) -> str:
    """Run generation inside a ThreadPool – Gemini SDK is blocking."""
    delay = 2
    for _ in range(retries):
        try:
            resp = await asyncio.to_thread(
                _get_model(model_name).generate_content,
                prompt,
                generation_config={"temperature": temperature},
            )
            return resp.text.strip()
        except ValueError:
            return ""  # response blocked or empty – no .text available
        except (gexc.ResourceExhausted, gexc.ServiceUnavailable, gexc.InternalServerError):
            # transient 429 / 5xx – back off and retry
            await asyncio.sleep(delay)
            delay *= 2
        except (gexc.NotFound, gexc.PermissionDenied):
            return ""  # unrecoverable – model/key unavailable
    return ""  # after retries

# ---------------------------------------------------------------------
# Public wrappers
# ---------------------------------------------------------------------
async def gemini_summarize(text: str, *, words: int = 150) -> str:
    prompt = f"Summarize in ≤{words} words:\n\n{text[:12000]}"
    out = await _generate(prompt, "gemini-1.5-flash")
    if not out:
        out = await _generate(prompt, "gemini-pro")
    return out

async def gemini_qa(question: str, *, context: str = "") -> str:
    prompt = (
        "You are an advanced biomedical research agent. Use the context to answer concisely.\n\n"
        f"Context:\n{context[:10000]}\n\nQ: {question}\nA:"
    )
    out = await _generate(prompt, "gemini-1.5-flash")
    if not out:
        out = await _generate(prompt, "gemini-pro")
    return out or "Gemini could not answer (model/key unavailable)."
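
# ---------------------------------------------------------------------
# Manual smoke test
# ---------------------------------------------------------------------
# A minimal usage sketch, assuming GEMINI_KEY is set and outbound network
# access is available; the prompt text below is illustrative only.
if __name__ == "__main__":
    async def _demo() -> None:
        summary = await gemini_summarize(
            "Aspirin irreversibly acetylates cyclooxygenase, reducing "
            "prostaglandin and thromboxane synthesis."
        )
        print("Summary:", summary or "<empty – fallback also failed>")

        answer = await gemini_qa("Which enzyme does aspirin inhibit?")
        print("Answer:", answer)

    asyncio.run(_demo())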