diginoron commited on
Commit
6867162
·
verified ·
1 Parent(s): 0f7ce0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -21
app.py CHANGED
@@ -2,20 +2,30 @@ import os
2
  import gradio as gr
3
  import pandas as pd
4
  import comtradeapicall
5
- from huggingface_hub import InferenceClient
 
6
  from deep_translator import GoogleTranslator
7
-
8
 
9
  # کلید COMTRADE
10
  subscription_key = os.getenv("COMTRADE_API_KEY", "")
11
  # توکن Hugging Face
12
  hf_token = os.getenv("HF_API_TOKEN")
13
 
 
 
14
 
15
- client = InferenceClient(token=hf_token)
16
- translator = GoogleTranslator(source='en', target='fa')
17
-
 
 
 
 
 
 
18
 
 
19
  def get_importers(hs_code: str, year: str, month: str):
20
  period = f"{year}{int(month):02d}"
21
  df = comtradeapicall.previewFinalData(
@@ -36,10 +46,12 @@ def get_importers(hs_code: str, year: str, month: str):
36
  result.columns = ["کد کشور", "نام کشور", "ارزش CIF"]
37
  return result
38
 
39
-
 
40
  def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str):
41
  if table_data is None or table_data.empty:
42
  return "ابتدا باید اطلاعات واردات را نمایش دهید."
 
43
  table_str = table_data.to_string(index=False)
44
  period = f"{year}/{int(month):02d}"
45
  prompt = (
@@ -49,19 +61,21 @@ def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str
49
  )
50
  print("پرامپت ساخته‌شده:")
51
  print(prompt)
 
52
  try:
53
- print("در حال فراخوانی مدل mistralai/Mixtral-8x7B-Instruct-v0.1...")
54
- outputs = client.text_generation(
55
- prompt=prompt,
56
- model="mistralai/Mixtral-8x7B-Instruct-v0.1",
57
- max_new_tokens=1024 # افزایش برای تکمیل جملات
 
 
58
  )
 
59
  print("خروجی مدل دریافت شد (به انگلیسی):")
60
- print(outputs)
61
-
62
 
63
- # ترجمه خروجی به فارسی
64
- translated_outputs = translator.translate(outputs)
65
  print("خروجی ترجمه‌شده به فارسی:")
66
  print(translated_outputs)
67
  return translated_outputs
@@ -70,12 +84,11 @@ def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str
70
  print(error_msg)
71
  return error_msg
72
 
73
-
74
  current_year = pd.Timestamp.now().year
75
  years = [str(y) for y in range(2000, current_year+1)]
76
  months = [str(m) for m in range(1, 13)]
77
 
78
-
79
  with gr.Blocks() as demo:
80
  gr.Markdown("##تولید شده توسط DIGINORON نمایش کشورهایی که یک کالا را وارد کرده‌اند")
81
  with gr.Row():
@@ -90,17 +103,14 @@ with gr.Blocks() as demo:
90
  )
91
  btn_show.click(get_importers, [inp_hs, inp_year, inp_month], out_table)
92
 
93
-
94
  btn_advice = gr.Button("ارائه مشاوره تخصصی")
95
  out_advice = gr.Textbox(label="مشاوره تخصصی", lines=6)
96
 
97
-
98
  btn_advice.click(
99
  provide_advice,
100
  inputs=[out_table, inp_hs, inp_year, inp_month],
101
  outputs=out_advice
102
  )
103
 
104
-
105
  if __name__ == "__main__":
106
- demo.launch()
 
2
  import gradio as gr
3
  import pandas as pd
4
  import comtradeapicall
5
+ import spaces
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from deep_translator import GoogleTranslator
8
+ import torch
9
 
10
  # کلید COMTRADE
11
  subscription_key = os.getenv("COMTRADE_API_KEY", "")
12
  # توکن Hugging Face
13
  hf_token = os.getenv("HF_API_TOKEN")
14
 
15
+ # تنظیم کوانت‌سازی
16
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
17
 
18
+ # بارگذاری توکنایزر و مدل
19
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", token=hf_token)
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ "google/gemma-7b",
22
+ token=hf_token,
23
+ quantization_config=quantization_config,
24
+ device_map="auto",
25
+ torch_dtype=torch.float16
26
+ )
27
 
28
+ # تابع دریافت اطلاعات واردکنندگان
29
  def get_importers(hs_code: str, year: str, month: str):
30
  period = f"{year}{int(month):02d}"
31
  df = comtradeapicall.previewFinalData(
 
46
  result.columns = ["کد کشور", "نام کشور", "ارزش CIF"]
47
  return result
48
 
49
+ # تابع ارائه مشاوره با استفاده از GPU
50
+ @spaces.GPU(duration=120)
51
  def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str):
52
  if table_data is None or table_data.empty:
53
  return "ابتدا باید اطلاعات واردات را نمایش دهید."
54
+
55
  table_str = table_data.to_string(index=False)
56
  period = f"{year}/{int(month):02d}"
57
  prompt = (
 
61
  )
62
  print("پرامپت ساخته‌شده:")
63
  print(prompt)
64
+
65
  try:
66
+ input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
67
+ outputs = model.generate(
68
+ **input_ids,
69
+ max_new_tokens=1024,
70
+ do_sample=True,
71
+ temperature=0.7,
72
+ top_p=0.9
73
  )
74
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
75
  print("خروجی مدل دریافت شد (به انگلیسی):")
76
+ print(generated_text)
 
77
 
78
+ translated_outputs = translator.translate(generated_text)
 
79
  print("خروجی ترجمه‌شده به فارسی:")
80
  print(translated_outputs)
81
  return translated_outputs
 
84
  print(error_msg)
85
  return error_msg
86
 
87
+ # تنظیمات رابط Gradio
88
  current_year = pd.Timestamp.now().year
89
  years = [str(y) for y in range(2000, current_year+1)]
90
  months = [str(m) for m in range(1, 13)]
91
 
 
92
  with gr.Blocks() as demo:
93
  gr.Markdown("##تولید شده توسط DIGINORON نمایش کشورهایی که یک کالا را وارد کرده‌اند")
94
  with gr.Row():
 
103
  )
104
  btn_show.click(get_importers, [inp_hs, inp_year, inp_month], out_table)
105
 
 
106
  btn_advice = gr.Button("ارائه مشاوره تخصصی")
107
  out_advice = gr.Textbox(label="مشاوره تخصصی", lines=6)
108
 
 
109
  btn_advice.click(
110
  provide_advice,
111
  inputs=[out_table, inp_hs, inp_year, inp_month],
112
  outputs=out_advice
113
  )
114
 
 
115
  if __name__ == "__main__":
116
+ demo.launch()