diginoron commited on
Commit
8633440
·
verified ·
1 Parent(s): 81a41a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -18
app.py CHANGED
@@ -7,6 +7,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from deep_translator import GoogleTranslator
8
  import torch
9
 
 
 
 
10
  # کلید COMTRADE
11
  subscription_key = os.getenv("COMTRADE_API_KEY", "")
12
  # توکن Hugging Face
@@ -19,14 +22,22 @@ translator = GoogleTranslator(source='en', target='fa')
19
  quantization_config = BitsAndBytesConfig(load_in_4bit=True)
20
 
21
  # بارگذاری توکنایزر و مدل
22
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-27b-it", token=hf_token)
23
- model = AutoModelForCausalLM.from_pretrained(
24
- "google/gemma-3-27b-it",
25
- token=hf_token,
26
- quantization_config=quantization_config,
27
- device_map="auto",
28
- torch_dtype=torch.float16
29
- )
 
 
 
 
 
 
 
 
30
 
31
  # تابع دریافت اطلاعات واردکنندگان
32
  def get_importers(hs_code: str, year: str, month: str):
@@ -50,32 +61,37 @@ def get_importers(hs_code: str, year: str, month: str):
50
  return result
51
 
52
  # تابع ارائه مشاوره با استفاده از GPU
53
- @spaces.GPU(duration=180) # افزایش مدت زمان برای مدل سنگین
54
  def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str):
55
  if table_data is None or table_data.empty:
56
  return "ابتدا باید اطلاعات واردات را نمایش دهید."
57
 
58
  table_str = table_data.to_string(index=False)
59
  period = f"{year}/{int(month):02d}"
 
60
  prompt = (
61
- f"The following table shows countries that imported a product with HS code {hs_code} during the period {period}:\n"
62
- f"{table_str}\n\n"
63
- f"Please provide a detailed and comprehensive analysis in two paragraphs. The first paragraph should discuss market opportunities, potential demand, and specific cultural or economic factors influencing the demand for this product in these countries. The second paragraph should offer actionable strategic recommendations for exporters, including detailed trade strategies, risk management techniques, and steps to establish local partnerships."
64
  )
65
  print("پرامپت ساخته‌شده:")
66
  print(prompt)
67
 
68
  try:
69
  # آماده‌سازی ورودی برای مدل
70
- input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
 
 
 
71
  # تولید خروجی
72
  outputs = model.generate(
73
- **input_ids,
74
- max_new_tokens=1024,
 
75
  do_sample=True,
76
- temperature=0.6, # برای پاسخ‌های منسجم
77
- top_p=0.85, # برای کیفیت بهتر
78
- pad_token_id=tokenizer.eos_token_id # جلوگیری از خطای pad token
79
  )
80
  # دیکد کردن خروجی و حذف پرامپت
81
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
7
  from deep_translator import GoogleTranslator
8
  import torch
9
 
10
+ # تنظیم متغیر محیطی برای دیباگ CUDA
11
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
12
+
13
  # کلید COMTRADE
14
  subscription_key = os.getenv("COMTRADE_API_KEY", "")
15
  # توکن Hugging Face
 
22
  quantization_config = BitsAndBytesConfig(load_in_4bit=True)
23
 
24
  # بارگذاری توکنایزر و مدل
25
+ try:
26
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=hf_token)
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ "google/gemma-2b-it",
29
+ token=hf_token,
30
+ quantization_config=quantization_config,
31
+ device_map="auto",
32
+ torch_dtype=torch.float16
33
+ )
34
+ except Exception as e:
35
+ print(f"خطا در بارگذاری مدل: {str(e)}")
36
+ raise e
37
+
38
+ # تنظیم صریح pad_token_id
39
+ if tokenizer.pad_token_id is None:
40
+ tokenizer.pad_token_id = tokenizer.eos_token_id
41
 
42
  # تابع دریافت اطلاعات واردکنندگان
43
  def get_importers(hs_code: str, year: str, month: str):
 
61
  return result
62
 
63
  # تابع ارائه مشاوره با استفاده از GPU
64
+ @spaces.GPU(duration=120)
65
  def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str):
66
  if table_data is None or table_data.empty:
67
  return "ابتدا باید اطلاعات واردات را نمایش دهید."
68
 
69
  table_str = table_data.to_string(index=False)
70
  period = f"{year}/{int(month):02d}"
71
+ # پرامپت بهینه‌شده
72
  prompt = (
73
+ f"Table of countries importing HS code {hs_code} in {period}:\n{table_str}\n\n"
74
+ f"Analyze market opportunities and cultural/economic factors in one paragraph. "
75
+ f"Provide strategic recommendations for exporters in another paragraph."
76
  )
77
  print("پرامپت ساخته‌شده:")
78
  print(prompt)
79
 
80
  try:
81
  # آماده‌سازی ورودی برای مدل
82
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
83
+ input_ids = inputs.input_ids.to("cuda")
84
+ attention_mask = inputs.attention_mask.to("cuda")
85
+
86
  # تولید خروجی
87
  outputs = model.generate(
88
+ input_ids=input_ids,
89
+ attention_mask=attention_mask,
90
+ max_new_tokens=512,
91
  do_sample=True,
92
+ temperature=0.7,
93
+ top_p=0.9,
94
+ pad_token_id=tokenizer.eos_token_id
95
  )
96
  # دیکد کردن خروجی و حذف پرامپت
97
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)