Skip to content

Commit 73c4a49

Browse files
author
LittleMouse
committed
[update] ModuleLLM support ctx model, add HomeAssistant model, add model post process config.
1 parent a916ca0 commit 73c4a49

File tree

10 files changed

+1826
-489
lines changed

10 files changed

+1826
-489
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
  "mode": "qwen2.5-HA-0.5B-ctx-ax630c",
  "type": "llm",
  "homepage": "https://huggingface.co/yunyu1258/qwen2.5-0.5b-ha",
  "compile_flage": "pulsar2 llm_build --input_path Qwen/qwen2.5-0.5b-ha --output_path Qwen/qwen2.5-0.5B-p1024-ha-ax630c --hidden_state_type bf16 --prefill_len 128 --kv_cache_len 1280 --last_kv_cache_len 128 --last_kv_cache_len 512 --last_kv_cache_len 1024 --chip AX620E --parallel 24",
  "pulsar_version": "4.1-patch1-c37957c7",
  "capabilities": [
    "text_generation",
    "chat"
  ],
  "input_type": [
    "llm.utf-8",
    "llm.utf-8.stream",
    "llm.chat_completion",
    "llm.chat_completion.stream"
  ],
  "output_type": [
    "llm.utf-8",
    "llm.utf-8.stream"
  ],
  "mode_param": {
    "tokenizer_type": 2,
    "url_tokenizer_model": "http://localhost:8080",
    "filename_tokens_embed": "model.embed_tokens.weight.bfloat16.bin",
    "filename_post_axmodel": "qwen2_post.axmodel",
    "template_filename_axmodel": "qwen2_p128_l%d_together.axmodel",
    "b_use_topk": false,
    "b_bos": false,
    "b_eos": false,
    "axmodel_num": 24,
    "tokens_embed_num": 151936,
    "tokens_embed_size": 896,
    "b_use_mmap_load_embed": true,
    "b_dynamic_load_axmodel_layer": false,
    "precompute_len": 1202,
    "ext_scripts": ["tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py"]
  }
}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import argparse
import json
import uuid
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlparse

from transformers import AutoTokenizer, PreTrainedTokenizerFast
6+
7+
# Global registry mapping a session uid (string UUID) to its
# Tokenizer_Http instance; populated by the /get_uid endpoint.
tokenizers = {}
9+
10+
class Tokenizer_Http():
    """Per-session chat tokenizer built on a Hugging Face ``AutoTokenizer``.

    Maintains the running chat history (``messages``) together with the
    token ids of the last rendered conversation (``token_ids``) so that
    ``encode`` can return only the newly appended ids — the caller uses
    that diff for incremental prefill of the on-device model.
    """

    def __init__(self, model_id):
        # model_id: HF repo id or local path accepted by from_pretrained.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Conversation history, seeded with the fixed Qwen system prompt.
        self.messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        ]
        # Token ids of the last fully rendered conversation template.
        self.token_ids = []

        # Ids buffered across decode() calls while their bytes do not yet
        # form complete UTF-8 characters.
        self.token_ids_cache = []

    def encode(self, prompt, last_reply=None):
        """Append *prompt* (and the optional previous assistant
        *last_reply*) to the history; return ``(token_ids, diff)`` where
        *diff* contains only the ids added since the previous render."""
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # print("Rendered text:\n============\n", text, "============\n")
        # NOTE(review): [:-3] drops the trailing generation-prompt tokens so
        # the next render extends this prefix; assumes the template always
        # ends with exactly 3 such tokens — confirm for this tokenizer.
        self.token_ids = self.tokenizer.encode(text)[:-3]
        self.messages.append({"role": "user", "content": prompt})

        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        print("生成的文本:\n============\n", text, "============\n")
        token_ids = self.tokenizer.encode(text)
        # Only the ids appended since the previous render.
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        print(self.decode(diff))
        return token_ids, diff

    def decode(self, token_ids):
        """Decode *token_ids*, buffering across calls; returns "" while the
        buffered ids still decode to an incomplete UTF-8 sequence."""
        self.token_ids_cache += token_ids
        text = self.tokenizer.decode(self.token_ids_cache)
        if "\ufffd" in text:
            # U+FFFD => a partial multi-byte character; keep buffering.
            print("text 中包含非法字符")
            return ""
        else:
            self.token_ids_cache.clear()
            return text


    @property
    def bos_id(self):
        # May be None for tokenizers without a BOS token (caller maps to -1).
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        # May be None for tokenizers without an EOS token (caller maps to -1).
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        return self.tokenizer.eos_token

    def reset(self, system_prompt=None):
        """Discard the history and start a fresh conversation with
        *system_prompt*; returns the token ids of the new prefix."""
        if system_prompt is None:
            # NOTE(review): falls back to the module-global ``args`` parsed
            # under __main__ — breaks if the module is imported without it.
            system_prompt = args.content
        self.messages = [
            {"role": "system", "content": system_prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Same trailing-token trim as in encode().
        token_ids = self.tokenizer.encode(text)[:-3]
        self.token_ids = token_ids
        print(self.decode(token_ids))
        return token_ids
87+
88+
89+
class Request(BaseHTTPRequestHandler):
    """HTTP front-end exposing the per-session tokenizers.

    GET  /get_uid          -> allocate a session uid + Tokenizer_Http instance
    GET  /bos_id?uid=...   -> BOS token id (-1 when the tokenizer has none)
    GET  /eos_id?uid=...   -> EOS token id (-1 when the tokenizer has none)
    POST /encode /decode /reset -> JSON body carrying at least "uid"

    All responses are JSON. Unknown uids produce {"error": "Invalid uid"}.
    """

    timeout = 5
    server_version = 'Apache'

    def do_GET(self):
        """Serve the uid-allocation and BOS/EOS lookup endpoints."""
        print("GET 请求路径:", self.path)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        # Session allocation: create a fresh tokenizer bound to a new uid.
        if '/get_uid' in self.path:
            new_uid = str(uuid.uuid4())
            print("新 uid:", new_uid)
            tokenizers[new_uid] = Tokenizer_Http(args.model_id)
            msg = json.dumps({'uid': new_uid})
        elif '/bos_id' in self.path:
            # uid arrives as a query parameter (e.g. ?uid=xxx).
            uid = self.get_query_param("uid")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                bos_id = instance.bos_id
                # -1 signals "no BOS token" to the C++ client.
                msg = json.dumps({'bos_id': bos_id if bos_id is not None else -1})
        elif '/eos_id' in self.path:
            uid = self.get_query_param("uid")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                eos_id = instance.eos_id
                msg = json.dumps({'eos_id': eos_id if eos_id is not None else -1})
        else:
            msg = json.dumps({'error': 'Invalid GET endpoint'})

        print("响应消息:", msg)
        self.wfile.write(msg.encode())

    def do_POST(self):
        """Serve the encode/decode/reset endpoints (JSON request bodies)."""
        content_length = int(self.headers.get('content-length', 0))
        data = self.rfile.read(content_length).decode()
        print("POST 请求路径:", self.path)
        print("接收到的数据:", data)
        try:
            req = json.loads(data)
        except json.JSONDecodeError:
            # Fix: a malformed body previously raised before any response
            # was written, so the client saw a dropped connection instead
            # of an error. Answer 400 with a JSON error instead.
            self.send_response(400)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            msg = json.dumps({'error': 'Invalid JSON body'})
            print("响应消息:", msg)
            self.wfile.write(msg.encode())
            return

        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        if '/encode' in self.path:
            # Body must contain uid and text; last_reply is optional.
            uid = req.get('uid')
            prompt = req.get('text')
            last_reply = req.get('last_reply')
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                token_ids, diff = instance.encode(prompt, last_reply)
                msg = json.dumps({'token_ids': token_ids, 'diff': diff})
        elif '/decode' in self.path:
            uid = req.get('uid')
            token_ids = req.get('token_ids')
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                text = instance.decode(token_ids)
                msg = json.dumps({'text': text})
        elif '/reset' in self.path:
            uid = req.get("uid")
            system_prompt = req.get("system_prompt")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                if system_prompt is not None:
                    print("system_prompt:", system_prompt)
                    token_ids = instance.reset(system_prompt)
                    msg = json.dumps({'token_ids': token_ids})
                else:
                    token_ids = instance.reset()
                    msg = json.dumps({'token_ids': token_ids})
        else:
            msg = json.dumps({'error': 'Invalid POST endpoint'})

        print("响应消息:", msg)
        self.wfile.write(msg.encode())

    def get_query_param(self, key):
        """Return the value of query parameter *key* from the request URL
        (e.g. /bos_id?uid=xxx), or None when the parameter is absent.

        Uses the module-level urllib.parse import instead of re-importing
        on every request as before.
        """
        params = parse_qs(urlparse(self.path).query)
        values = params.get(key)
        return values[0] if values else None
190+
191+
if __name__ == "__main__":
    # Command-line entry point: parse the server options, then block in the
    # HTTP serve loop forever.
    cli = argparse.ArgumentParser()
    cli.add_argument('--host', type=str, default='0.0.0.0')
    cli.add_argument('--port', type=int, default=12345)
    cli.add_argument('--model_id', type=str, default='qwen3_1.7B_tokenizer')
    cli.add_argument(
        '--content', type=str,
        default='You are Qwen, created by Alibaba Cloud. You are a helpful assistant.')

    # NOTE: must stay a module-level name called ``args`` — the handler
    # classes above read it as a global.
    args = cli.parse_args()

    host = (args.host, args.port)
    print('Server running at http://%s:%s' % host)
    server = HTTPServer(host, Request)
    server.serve_forever()

0 commit comments

Comments
 (0)