RihemXX committed
Commit fe0fb7e · verified · 1 Parent(s): 488f79f

Update app.py

Files changed (1)
  1. app.py +54 -7
app.py CHANGED
@@ -640,6 +640,54 @@ def build_demo():
 
     return demo
 
+# --- API endpoint: /chat ---
+api_interface = gr.Interface(
+    fn=lambda image, question: http_api_infer(image, question),
+    inputs=[gr.Image(type="pil"), gr.Textbox()],
+    outputs="text",
+    allow_flagging="never",
+    api_name="/chat"
+)
+
+def http_api_infer(image, question):
+    """
+    Simple API endpoint that mimics InternVL logic with one image + text.
+    """
+    # Build a simplified version of `state` object here
+    dummy_state = init_state()
+    dummy_state.set_system_message("You are a vision-language assistant.")
+    dummy_state.append_message(Conversation.USER, question, [image])
+    dummy_state.skip_next = False
+
+    # Simulate inference (you can directly call your model function here instead)
+    worker_addr = os.environ.get("WORKER_ADDR", "")
+    api_token = os.environ.get("API_TOKEN", "")
+    headers = {
+        "Authorization": f"Bearer {api_token}",
+        "Content-Type": "application/json"
+    }
+
+    if not worker_addr:
+        return "⚠️ Model backend is not configured."
+
+    all_image_paths = [dummy_state.save_image(image)]
+
+    pload = {
+        "model": "InternVL2.5-78B",
+        "messages": dummy_state.get_prompt_v2(inlude_image=True, max_dynamic_patch=12),
+        "temperature": 0.2,
+        "top_p": 0.7,
+        "max_tokens": 1024,
+        "repetition_penalty": 1.1,
+        "stream": False
+    }
+
+    try:
+        response = requests.post(worker_addr, json=pload, headers=headers, timeout=120)
+        reply = response.json()["choices"][0]["message"]["content"]
+        return reply
+    except Exception as e:
+        return f"Error: {str(e)}"
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -652,10 +700,9 @@ if __name__ == "__main__":
     logger.info(f"args: {args}")
 
     logger.info(args)
-    demo = build_demo()
-    demo.queue(api_open=False).launch(
-        server_name=args.host,
-        server_port=args.port,
-        share=args.share,
-        max_threads=args.concurrency_count,
-    )
+    demo = gr.TabbedInterface(
+        interface_list=[build_demo(), api_interface],
+        tab_names=["UI", "API"]
+    )
+    demo.queue(api_open=True).launch(...)
+
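
For reference, because the new interface is registered with api_name="/chat" and the app is launched with queue(api_open=True), a client could reach the endpoint through the Gradio API. Below is a minimal sketch using gradio_client; the server URL and image path are placeholders (not values from this commit), and a recent gradio_client release with handle_file is assumed.

# Minimal client-side sketch for calling the new /chat endpoint.
# The URL and image path are hypothetical; adjust to wherever app.py is running.
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")  # placeholder host/port for the running app

# Positional args map to the Interface inputs: one image, one question string.
reply = client.predict(
    handle_file("example.jpg"),       # local image file to upload
    "What is shown in this image?",   # question text
    api_name="/chat",
)
print(reply)

The two positional arguments correspond to the gr.Image and gr.Textbox inputs of api_interface, and the returned value is the text produced by http_api_infer.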