oberbics committed
Commit 90d7012 · verified · 1 Parent(s): d446745

Update app.py

Files changed (1): app.py +31 -36
app.py CHANGED
@@ -1,57 +1,52 @@
 import gradio as gr
-import torch
 import json
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import requests
+import os
 
-# Initialize model with error handling
-try:
-    tokenizer = AutoTokenizer.from_pretrained("numind/NuExtract-1.5")
-    model = AutoModelForCausalLM.from_pretrained(
-        "numind/NuExtract-1.5",
-        device_map="auto",
-        torch_dtype=torch.float16,
-        trust_remote_code=True
-    )
-    MODEL_LOADED = True
-    print("Model loaded successfully!")
-except Exception as e:
-    MODEL_LOADED = False
-    print(f"Model loading failed: {e}")
+# Use the Hugging Face Inference API instead of loading the model
+API_URL = "https://api-inference.huggingface.co/models/numind/NuExtract-1.5"
+headers = {"Authorization": f"Bearer {os.environ.get('HF_TOKEN', '')}"}
 
 def test_function(template, text):
     print(f"Test function called with template: {template[:30]} and text: {text[:30]}")
     return "Button clicked successfully", "Function was called"
 
 def extract_info(template, text):
-    if not MODEL_LOADED:
-        return "❌ Model not loaded", "{}"
-
     try:
         # Format prompt according to NuExtract-1.5 requirements
         prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"
         print(f"Processing with prompt: {prompt[:100]}...")
 
-        # Tokenize
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        # Call API instead of using local model
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": 1000,
+                "do_sample": False
+            }
+        }
+
+        print("Calling API...")
+        response = requests.post(API_URL, headers=headers, json=payload)
 
-        # Generate with cache disabled
-        print("Generating output...")
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=1000,
-            do_sample=False,
-            use_cache=False  # This disables the problematic cache
-        )
+        if response.status_code != 200:
+            print(f"API error: {response.status_code}, {response.text}")
+            return f"❌ API Error: {response.status_code}", response.text
 
-        # Decode and extract result
-        print("Decoding output...")
-        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Process result
+        result = response.json()
+
+        # Handle different response formats
+        if isinstance(result, list) and len(result) > 0:
+            result_text = result[0].get("generated_text", "")
+        else:
+            result_text = str(result)
 
-        # Split at output marker
-        if "<|output|>" in result:
-            json_text = result.split("<|output|>")[1].strip()
+        # Split at output marker if present
+        if "<|output|>" in result_text:
+            json_text = result_text.split("<|output|>")[1].strip()
         else:
-            json_text = result
+            json_text = result_text
 
         # Try to parse as JSON
         print("Parsing JSON...")