oberbics committed
Commit 38f0b3d (verified) · Parent: 7778425

Update app.py

Files changed (1)
  1. app.py +38 -53
app.py CHANGED
@@ -16,7 +16,6 @@ import string
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import AutoConfig
 import torch
 
@@ -84,87 +83,73 @@ class SafeGeocoder:
 def load_model():
     global tokenizer, model
     try:
-        # First ensure we have the right tokenizer class available
-        from transformers import Qwen2Tokenizer
-    except ImportError:
-        # Fallback to AutoTokenizer if specific import fails
-        pass
-
-    try:
-        # Generate a random location and text each time
-        random_city = random.choice(["Berlin", "Paris", "London", "Tokyo", "Rome", "Madrid"])
-        random_suffix = ''.join(random.choices(string.ascii_lowercase, k=5))
-        test_text = f"Test in {random_city}_{random_suffix}."
-        test_template = '{"test_location": ""}'
-
-        # Initialize model if not already loaded
         if model is None:
-            # Load config first to check for tokenizer class
-            config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
-
-            # Load tokenizer with explicit class if needed
-            if hasattr(config, "tokenizer_class"):
-                tokenizer = AutoTokenizer.from_pretrained(
-                    MODEL_NAME,
-                    trust_remote_code=True,
-                    tokenizer_class=config.tokenizer_class
-                )
-            else:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    MODEL_NAME,
-                    trust_remote_code=True
-                )
+            # Special handling for NuExtract tokenizer
+            tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_NAME,
+                trust_remote_code=True
+            )
 
             model = AutoModelForCausalLM.from_pretrained(
                 MODEL_NAME,
                 torch_dtype=TORCH_DTYPE,
-                trust_remote_code=True,
-                device_map="auto"
+                device_map="auto",
+                trust_remote_code=True
             ).eval()
+
             print(f"✅ Loaded {MODEL_NAME} on {DEVICE}")
-
-        # Test the model
-        prompt = f"<|input|>\n### Template:\n{test_template}\n### Text:\n{test_text}\n\n<|output|>"
-        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-        outputs = model.generate(**inputs, max_new_tokens=50)
-        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        if "<|output|>" in result and random_city in result:
-            return "✅ Modell erfolgreich geladen und getestet! Sie können jetzt mit der Extraktion beginnen."
-        return "⚠️ Modell-Test nicht erfolgreich. Bitte versuchen Sie es erneut."
 
+        # Test the model
+        test_text = "Test in Berlin."
+        test_template = '{"test_location": ""}'
+        prompt = f"<|input|>\n### Template:\n{test_template}\n### Text:\n{test_text}\n\n<|output|>"
+
+        inputs = tokenizer(prompt, return_tensors="pt", max_length=20000, truncation=True).to(DEVICE)
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=50,
+            temperature=0.0,
+            do_sample=False
+        )
+        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        if "<|output|>" in result and "Berlin" in result:
+            return "✅ Modell erfolgreich geladen und getestet!"
+
+        return "⚠️ Modell-Test nicht erfolgreich. Bitte versuchen Sie es erneut."
+
     except Exception as e:
         return f"❌ Fehler beim Laden des Modells: {str(e)}"
+
 @spaces.GPU
 def extract_info(template, text):
     global tokenizer, model
-
     if model is None:
-        return "❌ Modell nicht geladen", "Bitte zuerst das Modell laden (1. Schritt)"
+        return "❌ Modell nicht geladen", "Bitte zuerst das Modell laden"
 
     try:
         prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_INPUT_LENGTH).to(DEVICE)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=20000
+        ).to(DEVICE)
 
         outputs = model.generate(
             **inputs,
-            max_new_tokens=MAX_NEW_TOKENS,
+            max_new_tokens=1000,
             temperature=0.0,
             do_sample=False,
             pad_token_id=tokenizer.eos_token_id
         )
 
         result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        if "<|output|>" in result_text:
-            json_text = result_text.split("<|output|>")[1].strip()
-        else:
-            json_text = result_text
+        json_text = result_text.split("<|output|>")[1].strip() if "<|output|>" in result_text else result_text
 
         try:
             extracted = json.loads(json_text)
-            formatted = json.dumps(extracted, indent=2)
-            return "✅ Erfolgreich extrahiert", formatted
+            return "✅ Erfolgreich extrahiert", json.dumps(extracted, indent=2)
         except json.JSONDecodeError:
             return "❌ JSON Parsing Fehler", json_text
 
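For reference, both functions build on NuExtract's template-based prompt format: a JSON skeleton whose empty fields the model fills in, wrapped in <|input|> and <|output|> markers. A minimal self-contained sketch of that flow, assuming a NuExtract checkpoint (the model name, dtype, and device handling below are illustrative, not the values defined in app.py):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "numind/NuExtract-tiny"  # assumption: app.py points at a NuExtract model

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # use float16/bfloat16 on GPU
    device_map="auto",
    trust_remote_code=True,
).eval()

# The template tells the model which JSON fields to fill.
template = '{"test_location": ""}'
text = "Test in Berlin."
prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

One detail worth noting: with do_sample=False, generation is already greedy, so the temperature=0.0 passed in the committed code is redundant (recent transformers versions emit a warning about it), though harmless.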
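The rewritten extract_info collapses the old four-line marker check into a single conditional expression and pretty-prints the parsed JSON directly. The same parsing step, pulled out as a standalone helper for illustration (parse_model_output is a hypothetical name, not a function in app.py):

import json

def parse_model_output(result_text):
    # Everything after the first <|output|> marker is treated as the JSON
    # payload; if the marker is missing, the whole decoded string is tried.
    json_text = (
        result_text.split("<|output|>")[1].strip()
        if "<|output|>" in result_text
        else result_text
    )
    try:
        return "✅ Erfolgreich extrahiert", json.dumps(json.loads(json_text), indent=2)
    except json.JSONDecodeError:
        # Hand back the raw text so the caller can inspect what failed to parse.
        return "❌ JSON Parsing Fehler", json_text

For example, parse_model_output('<|output|>\n{"test_location": "Berlin"}') returns the success status with indented JSON, while malformed output falls through to the error branch with the offending text attached.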
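Structurally, the commit keeps the app's lazy-loading scheme: tokenizer and model live in module-level globals that load_model populates once, and extract_info (decorated with @spaces.GPU so a ZeroGPU Space attaches a GPU only for the duration of the call) reuses them. A trimmed sketch of that pattern, with hypothetical names (get_model is not part of app.py):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = None
model = None

def get_model(model_name):
    """Load the tokenizer and model on first use; later calls reuse the globals."""
    global tokenizer, model
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
        ).eval()
    return tokenizer, model

In a Gradio Space this makes the "load model" button idempotent: clicking it a second time returns immediately instead of re-instantiating the weights.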