Spaces:
Running
Running
import os | |
import gradio as gr | |
import pandas as pd | |
import comtradeapicall | |
import spaces | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
from deep_translator import GoogleTranslator | |
import torch | |
# تنظیم متغیر محیطی برای دیباگ CUDA | |
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" | |
# کلید COMTRADE | |
subscription_key = os.getenv("COMTRADE_API_KEY", "") | |
# توکن Hugging Face | |
hf_token = os.getenv("HF_API_TOKEN") | |
# تعریف مترجم | |
translator = GoogleTranslator(source='en', target='fa') | |
# تنظیم کوانتسازی برای کاهش مصرف حافظه | |
quantization_config = BitsAndBytesConfig(load_in_4bit=True) | |
# بارگذاری توکنایزر و مدل | |
try: | |
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=hf_token) | |
model = AutoModelForCausalLM.from_pretrained( | |
"google/gemma-2b-it", | |
token=hf_token, | |
quantization_config=quantization_config, | |
device_map="auto", | |
torch_dtype=torch.float16 | |
) | |
except Exception as e: | |
print(f"خطا در بارگذاری مدل: {str(e)}") | |
raise e | |
# تنظیم صریح pad_token_id | |
if tokenizer.pad_token_id is None: | |
tokenizer.pad_token_id = tokenizer.eos_token_id | |
# تابع دریافت اطلاعات واردکنندگان | |
def get_importers(hs_code: str, year: str, month: str): | |
period = f"{year}{int(month):02d}" | |
df = comtradeapicall.previewFinalData( | |
typeCode='C', freqCode='M', clCode='HS', period=period, | |
reporterCode=None, cmdCode=hs_code, flowCode='M', | |
partnerCode=None, partner2Code=None, | |
customsCode=None, motCode=None, | |
maxRecords=500, includeDesc=True | |
) | |
if df is None or df.empty: | |
return pd.DataFrame(columns=["کد کشور", "نام کشور", "ارزش CIF"]) | |
df = df[df['cifvalue'] > 0] | |
result = ( | |
df.groupby(["reporterCode", "reporterDesc"], as_index=False) | |
.agg({"cifvalue": "sum"}) | |
.sort_values("cifvalue", ascending=False) | |
) | |
result.columns = ["کد کشور", "نام کشور", "ارزش CIF"] | |
return result | |
# تابع ارائه مشاوره با استفاده از GPU | |
def provide_advice(table_data: pd.DataFrame, hs_code: str, year: str, month: str): | |
if table_data is None or table_data.empty: | |
return "ابتدا باید اطلاعات واردات را نمایش دهید." | |
table_str = table_data.to_string(index=False) | |
period = f"{year}/{int(month):02d}" | |
# پرامپت بهینهشده | |
prompt = ( | |
f"Table of countries importing HS code {hs_code} in {period}:\n{table_str}\n\n" | |
f"Analyze market opportunities and cultural/economic factors in one paragraph. " | |
f"Provide strategic recommendations for exporters in another paragraph." | |
) | |
print("پرامپت ساختهشده:") | |
print(prompt) | |
try: | |
# آمادهسازی ورودی برای مدل | |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
input_ids = inputs.input_ids.to("cuda") | |
attention_mask = inputs.attention_mask.to("cuda") | |
# تولید خروجی | |
outputs = model.generate( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
max_new_tokens=512, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
# دیکد کردن خروجی و حذف پرامپت | |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# حذف پرامپت از خروجی | |
if generated_text.startswith(prompt): | |
generated_text = generated_text[len(prompt):].strip() | |
# بررسی اینکه خروجی خالی نباشد | |
if not generated_text: | |
return "مدل نتوانست پاسخ مناسبی تولید کند. لطفاً دوباره امتحان کنید." | |
print("خروجی مدل دریافت شد (به انگلیسی):") | |
print(generated_text) | |
# ترجمه خروجی به فارسی | |
translated_outputs = translator.translate(generated_text) | |
print("خروجی ترجمهشده به فارسی:") | |
print(translated_outputs) | |
return translated_outputs | |
except Exception as e: | |
error_msg = f"خطا در تولید مشاوره: {str(e)}" | |
print(error_msg) | |
return error_msg | |
# تنظیمات رابط Gradio | |
current_year = pd.Timestamp.now().year | |
years = [str(y) for y in range(2000, current_year+1)] | |
months = [str(m) for m in range(1, 13)] | |
with gr.Blocks() as demo: | |
gr.Markdown("##تولید شده توسط DIGINORON نمایش کشورهایی که یک کالا را وارد کردهاند") | |
with gr.Row(): | |
inp_hs = gr.Textbox(label="HS Code") | |
inp_year = gr.Dropdown(choices=years, label="سال", value=str(current_year)) | |
inp_month = gr.Dropdown(choices=months, label="ماه", value=str(pd.Timestamp.now().month)) | |
btn_show = gr.Button("نمایش اطلاعات") | |
out_table = gr.Dataframe( | |
headers=["کد کشور", "نام کشور", "ارزش CIF"], | |
datatype=["number", "text", "number"], | |
interactive=True, | |
) | |
btn_show.click(get_importers, [inp_hs, inp_year, inp_month], out_table) | |
btn_advice = gr.Button("ارائه مشاوره تخصصی") | |
out_advice = gr.Textbox(label="مشاوره تخصصی", lines=6) | |
btn_advice.click( | |
provide_advice, | |
inputs=[out_table, inp_hs, inp_year, inp_month], | |
outputs=out_advice | |
) | |
if __name__ == "__main__": | |
demo.launch() |