Skip to content

Commit 92b10ac

Browse files
author
LittleMouse
committed
[update] add internvl3-1B-ax630c model; update main_vlm
1 parent 9d816fe commit 92b10ac

File tree

3 files changed

+269
-60
lines changed

3 files changed

+269
-60
lines changed
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from transformers import AutoTokenizer
from http.server import HTTPServer, BaseHTTPRequestHandler
import json
import argparse
import uuid

# Per-client tokenizer sessions, keyed by the uid handed out by GET /get_uid.
tokenizers = {}
8+
9+
class Tokenizer_Http():
    """Stateful chat-tokenizer wrapper for one client session.

    Keeps the running message history plus the token ids of the currently
    rendered prompt, so each ``encode*`` call can return both the full id
    list and only the newly appended suffix (``diff``) — the client uses
    the diff for incremental KV-cache prefill.
    """

    # Number of trailing tokens stripped when re-rendering history without a
    # pending user turn. NOTE(review): template-specific magic number from the
    # original code — assumed to be the generation-prompt tail of the InternVL
    # chat template; TODO confirm against the actual template.
    _GEN_PROMPT_TAIL = 3

    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.messages = [
            {"role": "system", "content": "你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。"},
        ]
        self.token_ids = []        # ids of the currently rendered prompt
        self.token_ids_cache = []  # buffer for streaming decode of partial UTF-8

    def _render_ids(self):
        """Apply the chat template to the current history and encode it."""
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return self.tokenizer.encode(text)

    def _append_user_and_diff(self, content):
        """Append a user turn, re-encode, and return (all ids, new suffix)."""
        self.messages.append({"role": "user", "content": content})
        token_ids = self._render_ids()
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        return token_ids, diff

    def encode(self, prompt, last_reply=None):
        """Encode a plain-text user turn.

        If ``last_reply`` is given, the previous assistant reply is folded
        into history first and the baseline ids are resynchronized (minus
        the generation-prompt tail) so ``diff`` only covers the new turn.
        Returns ``(token_ids, diff)``.
        """
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})
            self.token_ids = self._render_ids()[:-self._GEN_PROMPT_TAIL]
        return self._append_user_and_diff(prompt)

    def encode_with_image(self, question: str, num_of_images: int, imgsz: int, last_reply=None):
        """Encode a user turn that references ``num_of_images`` images.

        Each image contributes an ``<img>...</img>`` placeholder block whose
        ``<IMG_CONTEXT>`` repeat count depends on the input resolution.
        Returns ``(token_ids, diff)``, or ``(None, None)`` for an
        unsupported ``imgsz``.
        """
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})

        # Context length per image is fixed by the input resolution.
        if imgsz == 448:
            context_len = 256
        elif imgsz == 224:
            context_len = 64
        else:
            print(f"Unsupported imgsz: {imgsz}")
            return None, None

        # Append one image placeholder block per image to the user text.
        question_with_images = question
        if num_of_images > 0:
            for _ in range(num_of_images):
                question_with_images += "\n<img>" + "<IMG_CONTEXT>" * context_len + "</img>\n"

        return self._append_user_and_diff(question_with_images)

    def decode(self, token_ids):
        """Incrementally decode streamed token ids.

        Ids are buffered until they decode without a replacement character
        (U+FFFD), i.e. until a multi-byte sequence is complete; until then
        an empty string is returned.
        """
        self.token_ids_cache += token_ids
        text = self.tokenizer.decode(self.token_ids_cache)
        if "\ufffd" in text:
            print("Text 中包含非法字符")
            return ""
        else:
            self.token_ids_cache.clear()
            return text

    @property
    def bos_id(self):
        # May be None for tokenizers without a BOS token.
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        return self.tokenizer.eos_token

    @property
    def img_start_token(self):
        # NOTE(review): assumes "<img>" encodes to exactly one id with no
        # prepended special token — TODO confirm for this tokenizer.
        return self.tokenizer.encode("<img>")[0]

    @property
    def img_context_token(self):
        return self.tokenizer.encode("<IMG_CONTEXT>")[0]

    def reset(self, system_prompt=None):
        """Reset history to a single system message and return its token ids.

        Falls back to the CLI-provided ``args.content`` (module global) when
        no prompt is supplied.
        """
        if system_prompt is None:
            system_prompt = args.content
        self.messages = [
            {"role": "system", "content": system_prompt},
        ]
        token_ids = self._render_ids()[:-self._GEN_PROMPT_TAIL]
        self.token_ids = token_ids
        print(self.decode(token_ids))
        return token_ids
114+
115+
116+
class Request(BaseHTTPRequestHandler):
    """HTTP front-end exposing per-uid tokenizer sessions.

    GET:  /get_uid, /bos_id, /eos_id, /img_start_token, /img_context_token
          (the last four take a ``?uid=...`` query parameter).
    POST: /encode, /decode, /reset (JSON bodies carrying ``uid``).
    """
    timeout = 5
    server_version = 'Apache'

    def _lookup(self):
        """Return the Tokenizer_Http for the 'uid' query parameter, or None."""
        return tokenizers.get(self.get_query_param("uid"))

    def do_GET(self):
        """Serve session creation and tokenizer-metadata queries."""
        print(self.path)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        if '/get_uid' in self.path:
            new_uid = str(uuid.uuid4())
            print("新 uid:", new_uid)
            tokenizers[new_uid] = Tokenizer_Http(args.model_id)
            msg = json.dumps({'uid': new_uid})
        elif '/bos_id' in self.path:
            instance = self._lookup()
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                # -1 signals "tokenizer has no BOS token" to the client.
                msg = json.dumps({'bos_id': instance.bos_id if instance.bos_id is not None else -1})
        elif '/eos_id' in self.path:
            instance = self._lookup()
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                msg = json.dumps({'eos_id': instance.eos_id if instance.eos_id is not None else -1})
        elif '/img_start_token' in self.path:
            instance = self._lookup()
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                msg = json.dumps({'img_start_token': instance.img_start_token})
        elif '/img_context_token' in self.path:
            instance = self._lookup()
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                msg = json.dumps({'img_context_token': instance.img_context_token})
        else:
            msg = json.dumps({'error': 'Invalid GET endpoint'})
        print(msg)
        self.wfile.write(msg.encode())

    def do_POST(self):
        """Serve encode/decode/reset operations on an existing session."""
        content_length = int(self.headers.get('content-length', 0))
        data = self.rfile.read(content_length).decode()
        req = json.loads(data)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        if '/encode' in self.path:
            uid = req.get('uid')
            prompt = req.get('text')
            last_reply = req.get('last_reply')
            b_img_prompt = bool(req.get('img_prompt', False))
            instance = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            elif b_img_prompt and ('num_img' not in req or 'imgsz' not in req):
                # BUGFIX: a missing key previously raised KeyError after the
                # headers were already sent, truncating the response.
                msg = json.dumps({'error': 'img_prompt requires num_img and imgsz'})
            else:
                if b_img_prompt:
                    token_ids, diff = instance.encode_with_image(
                        prompt, req['num_img'], req['imgsz'], last_reply)
                else:
                    token_ids, diff = instance.encode(prompt, last_reply)
                if token_ids is None:
                    # encode_with_image returns (None, None) for an
                    # unsupported imgsz; report it instead of emitting nulls.
                    msg = json.dumps({'error': 'Unsupported imgsz'})
                else:
                    msg = json.dumps({'token_ids': token_ids, 'diff': diff})

        elif '/decode' in self.path:
            uid = req.get('uid')
            token_ids = req.get('token_ids')
            instance = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                text = instance.decode(token_ids)
                msg = json.dumps({'text': text})

        elif '/reset' in self.path:
            uid = req.get("uid")
            system_prompt = req.get("system_prompt")
            instance = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                if system_prompt is not None:
                    print("system_prompt:", system_prompt)
                    token_ids = instance.reset(system_prompt)
                else:
                    token_ids = instance.reset()
                msg = json.dumps({'token_ids': token_ids})

        else:
            msg = json.dumps({'error': 'Invalid POST endpoint'})

        self.wfile.write(msg.encode())

    def get_query_param(self, key):
        """Extract a single query-string value from self.path (None if absent)."""
        from urllib.parse import urlparse, parse_qs
        query = urlparse(self.path).query
        values = parse_qs(query).get(key)
        return values[0] if values else None
228+
229+
230+
if __name__ == "__main__":
    # CLI configuration for the tokenizer HTTP service.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--model_id", type=str, default="internvl3_tokenizer")
    parser.add_argument("--content", type=str, default="你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。")
    # Module-level binding: Tokenizer_Http.reset and Request.do_GET read it.
    args = parser.parse_args()

    host = (args.host, args.port)
    print("http://%s:%s" % host)
    # Blocks forever serving tokenizer requests.
    HTTPServer(host, Request).serve_forever()

projects/llm_framework/main_vlm/src/main.cpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class llm_task {
135135
CONFIG_AUTO_SET(file_body["mode_param"], filename_tokens_embed);
136136
CONFIG_AUTO_SET(file_body["mode_param"], filename_post_axmodel);
137137
CONFIG_AUTO_SET(file_body["mode_param"], filename_vpm_resampler_axmodedl);
138+
CONFIG_AUTO_SET(file_body["mode_param"], filename_image_encoder_axmodedl);
138139
CONFIG_AUTO_SET(file_body["mode_param"], template_filename_axmodel);
139140
CONFIG_AUTO_SET(file_body["mode_param"], b_use_topk);
140141
CONFIG_AUTO_SET(file_body["mode_param"], b_vpm_two_stage);
@@ -218,6 +219,7 @@ class llm_task {
218219
mode_config_.filename_post_axmodel = base_model + mode_config_.filename_post_axmodel;
219220
mode_config_.template_filename_axmodel = base_model + mode_config_.template_filename_axmodel;
220221
mode_config_.filename_vpm_resampler_axmodedl = base_model + mode_config_.filename_vpm_resampler_axmodedl;
222+
mode_config_.filename_image_encoder_axmodedl = base_model + mode_config_.filename_image_encoder_axmodedl;
221223
mode_config_.runing_callback = [this](int *p_token, int n_token, const char *p_str, float token_per_sec,
222224
void *reserve) {
223225
if (this->out_callback_) {
@@ -341,36 +343,39 @@ class llm_task {
341343

342344
if (lLaMa_ctx_) {
343345
if (image_data_.empty()) {
344-
lLaMa_ctx_->Encode(prompt_data_, prompt_complete(prompt_), last_reply, tokens_ids, tokens_diff);
346+
lLaMa_ctx_->Encode(prompt_data_, prompt_complete(msg), last_reply, tokens_ids, tokens_diff);
345347
if (auto ret = lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size());
346348
ret != 0) {
347349
ALOGE("SetKVCache failed: %d,the context may be full,input \"reset\" to reset context", ret);
348350
return;
349351
}
350352
last_reply = lLaMa_ctx_->Run(prompt_data_);
351353
lLaMa_ctx_->GetKVCache(k_caches, v_caches, precompute_len);
354+
if (out_callback_) out_callback_(last_reply, true);
355+
} else {
356+
cv::Mat src = cv::imdecode(image_data_, cv::IMREAD_COLOR);
357+
if (src.empty()) return;
358+
image_data_.clear();
359+
std::vector<unsigned short> img_embed;
360+
if (auto ret = lLaMa_ctx_->Encode(src, img_embed); ret != 0) {
361+
ALOGE("lLaMaCtx.Encode failed");
362+
return;
363+
}
364+
if (auto ret =
365+
lLaMa_ctx_->Encode(img_embed, prompt_data_, prompt_complete(msg), tokens_ids, tokens_diff);
366+
ret != 0) {
367+
ALOGE("lLaMaCtx.Encode failed");
368+
return;
369+
}
370+
if (auto ret = lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size());
371+
ret != 0) {
372+
ALOGE("SetKVCache failed: %d,the context may be full,input \"reset\" to reset context", ret);
373+
return;
374+
}
375+
last_reply = lLaMa_ctx_->Run(prompt_data_);
376+
lLaMa_ctx_->GetKVCache(k_caches, v_caches, precompute_len);
377+
if (out_callback_) out_callback_(last_reply, true);
352378
}
353-
cv::Mat src = cv::imdecode(image_data_, cv::IMREAD_COLOR);
354-
if (src.empty()) return;
355-
image_data_.clear();
356-
std::vector<unsigned short> img_embed;
357-
if (auto ret = lLaMa_ctx_->Encode(src, img_embed); ret != 0) {
358-
ALOGE("lLaMaCtx.Encode failed");
359-
return;
360-
}
361-
if (auto ret =
362-
lLaMa_ctx_->Encode(img_embed, prompt_data_, prompt_complete(prompt_), tokens_ids, tokens_diff);
363-
ret != 0) {
364-
ALOGE("lLaMaCtx.Encode failed");
365-
return;
366-
}
367-
if (auto ret = lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size());
368-
ret != 0) {
369-
ALOGE("SetKVCache failed: %d,the context may be full,input \"reset\" to reset context", ret);
370-
return;
371-
}
372-
last_reply = lLaMa_ctx_->Run(prompt_data_);
373-
lLaMa_ctx_->GetKVCache(k_caches, v_caches, precompute_len);
374379
}
375380
} catch (...) {
376381
SLOGW("lLaMa_->Run have error!");

projects/llm_framework/main_vlm/src/runner/LLM.hpp

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -469,20 +469,8 @@ class LLM {
469469
if (_attr.b_dynamic_load_axmodel_layer) {
470470
layer.layer.deinit();
471471
}
472-
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(),
473-
// bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
474472
}
475473

476-
// ALOGI("prefill time cost: %.2f s", t_cost.cost() / 1000);
477-
478-
// print token_ids
479-
// printf("%s\n", input_str.c_str());
480-
// for (size_t i = 0; i < token_ids.size(); i++)
481-
// {
482-
// printf("%d ", token_ids[i]);
483-
// }
484-
// printf("\n");
485-
486474
int next_token = -1;
487475
t_cqdm cqdm = create_cqdm(_attr.max_token_len, 32);
488476
std::vector<unsigned short> embed(_attr.tokens_embed_size, 0);
@@ -522,10 +510,7 @@ class LLM {
522510
break;
523511
}
524512

525-
// ALOGI("out %d %d", indices, next_token);
526513
embed_selector.getByIndex(next_token, embed);
527-
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(),
528-
// bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
529514

530515
for (int m = 0; m < _attr.axmodel_num; m++) {
531516
if (b_stop) {
@@ -580,10 +565,8 @@ class LLM {
580565
if (_attr.b_dynamic_load_axmodel_layer) {
581566
layer.layer.deinit();
582567
}
583-
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(),
584-
// bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
585568
}
586-
// ALOGI("");
569+
587570
mask[indices] = 0;
588571
{
589572
// post process
@@ -644,9 +627,6 @@ class LLM {
644627
float t_cost_ms = t_cost.cost();
645628
ALOGN("hit eos,avg %.2f token/s\n", token_ids.size() / (t_cost_ms / 1000));
646629

647-
// 去掉 len_of_input 那部分
648-
// token_ids.erase(token_ids.begin(), token_ids.begin() + len_of_input);
649-
650630
final_out = tokenizer->Decode(token_ids);
651631

652632
return final_out;
@@ -1358,23 +1338,6 @@ class LLM_CTX {
13581338
return 0;
13591339
}
13601340

1361-
int Encode(std::vector<unsigned short> &out_embed, std::string prompt = "What is in the image?")
1362-
{
1363-
ImageInfo img_info;
1364-
img_info.img_prompt = false;
1365-
std::vector<int> input_ids = tokenizer->Encode(prompt, img_info);
1366-
if (input_ids.size() > _attr.prefill_token_num) {
1367-
ALOGE("input_ids(%ld) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num);
1368-
return -1;
1369-
}
1370-
out_embed.resize(input_ids.size() * _attr.tokens_embed_size);
1371-
1372-
for (size_t i = 0; i < input_ids.size(); i++) {
1373-
embed_selector.getByIndex(input_ids[i], out_embed.data() + i * _attr.tokens_embed_size);
1374-
}
1375-
1376-
return 0;
1377-
}
13781341

13791342
int Encode(std::vector<std::vector<unsigned short>> &imgs_embed, std::vector<unsigned short> &out_embed,
13801343
std::string prompt, std::vector<int> &tokens_ids, std::vector<int> &tokens_diff)

0 commit comments

Comments
 (0)