24Sureshkumar committed on
Commit
b7ac1a7
·
verified ·
1 Parent(s): 56cd428

Create core.py

Browse files
Files changed (1) hide show
  1. core.py +81 -0
core.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # core.py (All logic here)
2
+ import os
3
+ import time
4
+ import tempfile
5
+ import torch
6
+ from transformers import (
7
+ MBartForConditionalGeneration,
8
+ MBart50TokenizerFast,
9
+ AutoTokenizer,
10
+ AutoModelForCausalLM
11
+ )
12
+ from diffusers import StableDiffusionPipeline
13
+ from PIL import Image
14
+ from rouge_score import rouge_scorer
15
+
16
# Device setup: prefer GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face access token, supplied via the HF Spaces secret HF_TOKEN.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("❌ Please set your HF_TOKEN in the HF Spaces secrets.")

# Tamil -> English translator (mBART-50, many-to-many checkpoint).
# src_lang is fixed to Tamil; the target language is forced at generate() time.
translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
translator_tokenizer.src_lang = "ta_IN"

# English creative-text generator (GPT-2).
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Stable Diffusion 2.1 text-to-image pipeline.
# fp16 weights only when CUDA is present; fp16 is not supported on CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=token,  # `use_auth_token` is deprecated in recent diffusers releases
).to(device)
pipe.safety_checker = None  # NOTE(review): deliberately disables NSFW filtering
40
def translate_tamil_to_english(text, reference=None):
    """Translate a Tamil string to English with the mBART-50 model.

    Args:
        text: Tamil source text (the tokenizer's src_lang is set to
            "ta_IN" at load time).
        reference: Optional English reference translation. When given,
            a ROUGE-L F-measure is computed against the model output.

    Returns:
        Tuple of (translation, elapsed_seconds, rouge_l). ``rouge_l`` is
        ``None`` when no reference is supplied.
    """
    t0 = time.time()

    encoded = translator_tokenizer(text, return_tensors="pt").to(device)
    # Force the decoder to start in English regardless of the source language.
    generated = translator_model.generate(
        **encoded,
        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"],
    )
    english = translator_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    elapsed = round(time.time() - t0, 2)

    rouge_l = None
    if reference:
        # Case-insensitive ROUGE-L between reference and hypothesis.
        metric = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        result = metric.score(reference.lower(), english.lower())
        rouge_l = round(result['rougeL'].fmeasure, 4)

    return english, elapsed, rouge_l
+
58
+
59
def generate_creative_text(prompt, max_length=100):
    """Generate creative English text from a prompt with GPT-2 sampling.

    Args:
        prompt: Seed text for generation.
        max_length: Maximum total token length (prompt included).

    Returns:
        Tuple of (text, elapsed_seconds, token_count, repetition_ratio),
        where repetition_ratio is the fraction of adjacent whitespace-split
        token pairs that are identical.
    """
    start = time.time()
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    out = gen_model.generate(
        input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9
    )
    text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
    duration = round(time.time() - start, 2)

    tokens = text.split()
    # Guard against empty output: the original divided by len(tokens)
    # unconditionally and raised ZeroDivisionError for an empty decode.
    if tokens:
        repetition = sum(a == b for a, b in zip(tokens, tokens[1:])) / len(tokens)
    else:
        repetition = 0.0
    return text, duration, len(tokens), round(repetition, 4)
+
71
+
72
def generate_image(prompt):
    """Render *prompt* with Stable Diffusion and save a 256x256 PNG.

    Args:
        prompt: Text prompt for the diffusion pipeline.

    Returns:
        (path, elapsed_seconds) on success, or (None, error_message)
        on failure. Note the second element's type differs between the
        two outcomes; callers must check the first element for None.
    """
    try:
        start = time.time()
        result = pipe(prompt)
        img = result.images[0].resize((256, 256))
        # delete=False keeps the file on disk after close. Close the handle
        # immediately so it is not leaked and so PIL can reopen the path
        # for writing (required on Windows, harmless elsewhere).
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        tmp.close()
        img.save(tmp.name)
        return tmp.name, round(time.time() - start, 2)
    except Exception as e:  # broad on purpose: surface any failure to the UI
        return None, f"Image generation failed: {e}"