rdave88 committed on
Commit 92174e9 · verified · 1 Parent(s): 5459fc5

Update model_tools.py

Files changed (1)
  1. model_tools.py +57 -66
model_tools.py CHANGED
@@ -1,66 +1,57 @@
- # model_tools.py
-
- import ollama
- import requests
- from bs4 import BeautifulSoup
-
- # ---- LLM Task Extractor ----
-
- def extract_task(user_input: str) -> str:
-     """
-     Use local Ollama LLM to classify user query into Hugging Face task.
-     """
-     prompt = f"""
-     You are an AI agent helping a developer select the right ML model.
-     Given this request: "{user_input}"
-
-     Reply with only the corresponding Hugging Face task like:
-     - text-classification
-     - summarization
-     - translation
-     - image-classification
-     - etc.
-     Only reply with the task name, and nothing else.
-     """
-
-     response = ollama.chat(
-         model="mistral",  # Replace with llama3, phi3, etc. if needed
-         messages=[{"role": "user", "content": prompt}]
-     )
-
-     return response['message']['content'].strip().lower()
-
- # ---- Hugging Face Scraper ----
-
- def scrape_huggingface_models(task: str, max_results=5) -> list[dict]:
-     """
-     Scrapes Hugging Face for top models for a given task.
-     """
-     url = f"https://huggingface.co/models?pipeline_tag={task}&sort=downloads"
-
-     try:
-         resp = requests.get(url)
-         soup = BeautifulSoup(resp.text, "html.parser")
-         model_cards = soup.find_all("article", class_="model-card")[:max_results]
-
-         results = []
-         for card in model_cards:
-             name_tag = card.find("a", class_="model-link")
-             model_name = name_tag.text.strip() if name_tag else "unknown"
-
-             task_div = card.find("div", class_="task-tag")
-             task_name = task_div.text.strip() if task_div else task
-
-             arch = "encoder-decoder" if "bart" in model_name.lower() or "t5" in model_name.lower() else "unknown"
-
-             results.append({
-                 "model_name": model_name,
-                 "task": task_name,
-                 "architecture": arch
-             })
-
-         return results
-
-     except Exception as e:
-         print(f"Scraping error: {e}")
-         return []
 
+ # model_tools.py
+
+ import ollama
+ import requests
+ from bs4 import BeautifulSoup
+ from transformers import pipeline
+
+ # ---- LLM Task Extractor ----
+
+ # Load T5-based model once
+ task_extractor = pipeline("text2text-generation", model="google/flan-t5-small")
+
+ def extract_task(user_input):
+     prompt = f"Classify the following ML task: {user_input}. Just reply with the task name."
+     result = task_extractor(prompt, max_new_tokens=10)
+     return result[0]["generated_text"].strip().lower()
+
+ # ---- Hugging Face Scraper ----
+
+ def scrape_huggingface_models(task: str, max_results=5) -> list[dict]:
+     """
+     Scrapes Hugging Face for top models for a given task.
+     """
+     url = f"https://huggingface.co/models?pipeline_tag={task}&sort=downloads"
+
+     try:
+         resp = requests.get(url)
+         soup = BeautifulSoup(resp.text, "html.parser")
+         model_cards = soup.find_all("article", class_="model-card")[:max_results]
+
+         results = []
+         for card in model_cards:
+             name_tag = card.find("a", class_="model-link")
+             model_name = name_tag.text.strip() if name_tag else "unknown"
+
+             task_div = card.find("div", class_="task-tag")
+             task_name = task_div.text.strip() if task_div else task
+
+             arch = "encoder-decoder" if "bart" in model_name.lower() or "t5" in model_name.lower() else "unknown"
+
+             results.append({
+                 "model_name": model_name,
+                 "task": task_name,
+                 "architecture": arch
+             })
+
+         return results
+
+     except Exception as e:
+         print(f"Scraping error: {e}")
+         return []
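
For reference, a minimal sketch of how the two helpers in the updated file might be wired together. The script name, the sample query, and the printed fields are illustrative only, not part of this commit; it assumes model_tools.py is importable from the working directory and that network access and the flan-t5-small weights are available.

# example_usage.py -- hypothetical driver script, not part of this commit
from model_tools import extract_task, scrape_huggingface_models

if __name__ == "__main__":
    query = "I need to condense long news articles into short summaries"
    task = extract_task(query)  # expected to yield something like "summarization"
    print(f"Detected task: {task}")

    # scrape_huggingface_models returns dicts with model_name, task, architecture keys
    for model in scrape_huggingface_models(task, max_results=3):
        print(f"- {model['model_name']} ({model['task']}, {model['architecture']})")

Actual output depends on what flan-t5-small generates for the prompt and on whether the scraped page still exposes the CSS classes the scraper looks for.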