File size: 1,394 Bytes
0e4080b
 
 
 
 
 
 
 
 
 
613c8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e4080b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import asyncio
import functools
import os
from concurrent.futures import ThreadPoolExecutor

import torch
from dotenv import load_dotenv
from transformers import pipeline

load_dotenv()

class LLMPipeline:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline.

    The model id is read from the ``HF_MODEL_ID`` environment variable and
    falls back to a built-in default. The model is loaded once, in
    ``__init__``; ``generate`` then runs prompts through it.
    """

    # One worker thread shared by all instances: generation calls are moved
    # off the event loop but still run one at a time, exactly as they did
    # when the pipeline was invoked inline.
    _executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="llm-generate")

    def __init__(self):
        """Load the text-generation pipeline.

        Uses float16 with ``device_map="auto"`` when CUDA is available,
        otherwise float32 with no device map.

        Raises:
            Exception: whatever ``transformers.pipeline`` raised while
                loading; the error is printed first and then re-raised.
        """
        # NOTE(review): the default repo id ends in "-GGUF"; loading GGUF
        # checkpoints through transformers usually needs extra arguments
        # (e.g. a gguf file name) -- confirm this default actually loads.
        model_id = os.getenv("HF_MODEL_ID", "mradermacher/Huihui-gemma-3n-E4B-it-abliterated-GGUF")
        try:
            use_cuda = torch.cuda.is_available()
            self.pipeline = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                model_kwargs={"low_cpu_mem_usage": True},
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    async def generate(self, prompt: str, max_length: int = 100) -> str:
        """Generate text for *prompt* without blocking the event loop.

        The synchronous pipeline call is executed on a worker thread so that
        other coroutines keep running while the model generates.

        Args:
            prompt: Input text passed to the pipeline.
            max_length: Forwarded as the pipeline's ``max_length``. Per the
                transformers generation docs this limit counts the prompt
                tokens as well as the newly generated ones.

        Returns:
            The ``generated_text`` field of the first returned sequence, or
            an empty string if generation failed (the error is printed).
        """
        try:
            run_pipeline = functools.partial(
                self.pipeline,
                prompt,
                max_length=max_length,
                num_return_sequences=1,
                # temperature/top_p only take effect when sampling is on;
                # request it explicitly instead of relying on the model's
                # default generation config.
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(self._executor, run_pipeline)
            return result[0]['generated_text']
        except Exception as e:
            # Best-effort contract: callers receive "" rather than an exception.
            print(f"Error in LLM generation: {e}")
            return ""