tianzhechu committed (verified)
Commit 82b5e45 · 1 Parent(s): 84f58ea

Upload 3 files

Files changed (3):
  1. README.md +31 -1
  2. app.py +93 -9
  3. requirements.txt +5 -1
README.md CHANGED
@@ -9,4 +9,34 @@ app_file: app.py
 pinned: false
 ---
 
-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+
+## Local Transformers mode with alternate tokenizer
+
+If the target model repository does not include a tokenizer, you can instruct the app to run locally with `transformers` and use a tokenizer from another repository.
+
+Environment variables:
+
+- `MODEL_ID` (optional): model repo to load. Defaults to `tianzhechu/BookQA-7B-Instruct`.
+- `TOKENIZER_ID` (optional): tokenizer repo to use locally (e.g., a base model's tokenizer). When set, the app switches to a local `transformers` backend and streams tokens from your machine.
+- `USE_LOCAL_TRANSFORMERS` (optional): set to `1` to force local mode even without `TOKENIZER_ID`.
+
+Install extra dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+Run with an alternate tokenizer (example):
+
+```bash
+export MODEL_ID=tianzhechu/BookQA-7B-Instruct
+export TOKENIZER_ID=TheBaseModel/TokenizerRepo
+python app.py
+```
+
+Notes:
+
+- Local inference will download and load the model weights via `transformers` and may require significant memory.
+- If the tokenizer exposes a chat template, it is applied automatically. Otherwise a simple fallback template is used.
+- You'll need a compatible version of `torch` installed for your platform. If the default pip install fails, follow the official install instructions for your OS/GPU.
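
For reference, the pattern this README section describes (model weights from one repo, tokenizer from another, both chosen via environment variables) reduces to roughly the sketch below. It is an illustration only, assuming `transformers` and `torch` are installed; the repo IDs are the same placeholders used above.

```python
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholders mirroring the environment variables documented above.
model_id = os.getenv("MODEL_ID", "tianzhechu/BookQA-7B-Instruct")
tokenizer_id = os.getenv("TOKENIZER_ID", "TheBaseModel/TokenizerRepo")

# Tokenizer comes from a different repo than the model weights. This only makes
# sense when the vocabularies match, e.g. a base model's tokenizer for a fine-tune.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "Who wrote Pride and Prejudice?"
inputs = tokenizer(prompt, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)

# Decode only the newly generated tokens, not the prompt.
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```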
app.py CHANGED
@@ -1,10 +1,38 @@
+import os
+import threading
 import gradio as gr
 from huggingface_hub import InferenceClient
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+os.environ["HF_HOME"] = "/tmp/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
+MODEL_ID = os.getenv("MODEL_ID", "tianzhechu/BookQA-7B-Instruct")
+TOKENIZER_ID = os.getenv("TOKENIZER_ID", "Qwen/Qwen2.5-0.5B-Instruct")  # Optional: tokenizer repo to use locally
+USE_LOCAL_TRANSFORMERS = bool(TOKENIZER_ID) or os.getenv("USE_LOCAL_TRANSFORMERS") == "1"
+
+# Remote inference (default)
+client = None if USE_LOCAL_TRANSFORMERS else InferenceClient(MODEL_ID)
+
+# Lazy-loaded local model/tokenizer when TOKENIZER_ID is provided
+local_model = None
+local_tokenizer = None
+
+
+def _ensure_local_model_loaded():
+    global local_model, local_tokenizer
+    if local_model is not None and local_tokenizer is not None:
+        return
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    if not TOKENIZER_ID:
+        raise RuntimeError(
+            "Local transformers backend requires TOKENIZER_ID to be set to a tokenizer repo."
+        )
+    local_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
+    local_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
 
 
 def respond(
@@ -27,17 +55,73 @@ def respond(
 
     response = ""
 
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
+    if not USE_LOCAL_TRANSFORMERS:
+        for message in client.chat_completion(
+            messages,
+            max_tokens=max_tokens,
+            stream=True,
+            temperature=temperature,
+            top_p=top_p,
+        ):
+            token = message.choices[0].delta.content
+            if token:
+                response += token
+            yield response
+        return
+
+    # Local generation using transformers with an alternate tokenizer
+    _ensure_local_model_loaded()
+
+    try:
+        from transformers import TextIteratorStreamer
+    except Exception as e:
+        raise RuntimeError(
+            "transformers TextIteratorStreamer is required for local streaming; ensure transformers is installed."
+        ) from e
+
+    # Use chat template if available; otherwise fall back to a simple concatenation
+    try:
+        prompt_text = local_tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+    except Exception:
+        convo_parts = []
+        for m in messages:
+            role = m.get("role", "user")
+            content = m.get("content", "")
+            if role == "system":
+                convo_parts.append(f"<system>\n{content}\n</system>")
+            elif role == "assistant":
+                convo_parts.append(f"<assistant>\n{content}\n</assistant>")
+            else:
+                convo_parts.append(f"<user>\n{content}\n</user>")
+        prompt_text = "\n".join(convo_parts) + "\n<assistant>\n"
+
+    inputs = local_tokenizer(prompt_text, return_tensors="pt")
+
+    streamer = TextIteratorStreamer(
+        local_tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+
+    generate_kwargs = dict(
+        inputs=inputs.input_ids,
+        attention_mask=inputs.get("attention_mask"),
+        max_new_tokens=max_tokens,
+        do_sample=temperature > 0,
         temperature=temperature,
         top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
+        streamer=streamer,
+    )
+
+    thread = threading.Thread(target=local_model.generate, kwargs=generate_kwargs)
+    thread.start()
 
-        response += token
-        yield response
+    for new_text in streamer:
+        if new_text:
+            response += new_text
+            yield response
 
 
 """
requirements.txt CHANGED
@@ -1 +1,5 @@
-huggingface_hub==0.25.2
+huggingface_hub==0.25.2
+gradio==5.0.1
+transformers>=4.38.0
+torch>=2.1.0
+transformers>=4.38.0