1+ from transformers import AutoTokenizer
2+ from http .server import HTTPServer , BaseHTTPRequestHandler
3+ import json
4+ import argparse
5+ import uuid
6+
# Active tokenizer sessions, keyed by the uuid string handed out by /get_uid.
tokenizers = dict()
8+
class Tokenizer_Http():
    """A per-session chat tokenizer wrapping a HuggingFace AutoTokenizer.

    Tracks the running ``messages`` list and the token ids of the
    conversation rendered so far, so that each ``encode*`` call can return
    both the full id list and just the newly appended ids (``diff``) for
    incremental prefill.
    """

    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Default system prompt (InternVL identity, Chinese) — a runtime
        # string sent to the model, kept verbatim.
        self.messages = [
            {"role": "system", "content": "你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。"},
        ]
        self.token_ids = []        # ids of the conversation rendered so far
        self.token_ids_cache = []  # ids buffered until they decode cleanly

    def encode(self, prompt, last_reply=None):
        """Add ``prompt`` as a user turn and return ``(token_ids, diff)``.

        If ``last_reply`` is given, it is first recorded as the previous
        assistant turn and the baseline ids are recomputed with the trailing
        generation-prompt tokens dropped, so ``diff`` aligns with the new
        render.
        NOTE(review): the ``[:-3]`` assumes the generation prompt is exactly
        3 tokens for this tokenizer — confirm if the chat template changes.
        """
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})
            text = self.tokenizer.apply_chat_template(
                self.messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            self.token_ids = self.tokenizer.encode(text)[:-3]
        self.messages.append({"role": "user", "content": prompt})

        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        token_ids = self.tokenizer.encode(text)
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        return token_ids, diff

    def encode_with_image(self, question: str, num_of_images: int, imgsz: int, last_reply=None):
        """Like ``encode``, but expands ``question`` with <img> placeholder blocks.

        Each image contributes ``context_len`` <IMG_CONTEXT> tokens, where
        context_len is fixed by the image size (448 -> 256, 224 -> 64).
        Returns ``(None, None)`` for unsupported sizes.
        """
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})

        # The number of image-context tokens per image is tied to input size.
        if imgsz == 448:
            context_len = 256
        elif imgsz == 224:
            context_len = 64
        else:
            print(f"Unsupported imgsz: {imgsz}")
            return None, None

        # Append one <img>...</img> block per image after the question text
        # (range() already yields nothing for num_of_images <= 0).
        question_with_images = question
        for _ in range(num_of_images):
            question_with_images += "\n<img>" + "<IMG_CONTEXT>" * context_len + "</img>\n"

        self.messages.append({"role": "user", "content": question_with_images})

        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        token_ids = self.tokenizer.encode(text)
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        return token_ids, diff

    def decode(self, token_ids):
        """Decode buffered token ids, holding back incomplete sequences.

        Ids accumulate in ``token_ids_cache``; if the decoded text contains
        U+FFFD (the replacement character) a multi-byte character is still
        split across chunks, so return "" and keep the cache for the next
        call.  Otherwise flush the cache and return the text.
        """
        self.token_ids_cache += token_ids
        text = self.tokenizer.decode(self.token_ids_cache)
        # Fix: the original tested for "\ufffd " (trailing space), which
        # misses a replacement character at the end of the text.
        if "\ufffd" in text:
            print("Text 中包含非法字符")
            return ""
        self.token_ids_cache.clear()
        return text

    @property
    def bos_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        return self.tokenizer.eos_token

    @property
    def img_start_token(self):
        # NOTE(review): assumes encode("<img>") yields the special token id
        # first (no BOS prepended) — confirm for this tokenizer.
        return self.tokenizer.encode("<img>")[0]

    @property
    def img_context_token(self):
        return self.tokenizer.encode("<IMG_CONTEXT>")[0]

    def reset(self, system_prompt=None):
        """Start a new conversation and return its baseline token ids.

        Falls back to the CLI default (module-level ``args.content``) when
        no system prompt is supplied, so the no-argument form only works
        after argument parsing in ``__main__``.
        """
        if system_prompt is None:
            system_prompt = args.content
        self.messages = [
            {"role": "system", "content": system_prompt},
        ]
        # Fix: drop any ids left over from an interrupted stream, otherwise
        # they would be prepended to the first decode of the new session.
        self.token_ids_cache.clear()
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Drop the generation prompt (last 3 tokens) to form the baseline.
        token_ids = self.tokenizer.encode(text)[:-3]
        self.token_ids = token_ids
        print(self.decode(token_ids))
        return token_ids
114+
115+
class Request(BaseHTTPRequestHandler):
    """Minimal JSON-over-HTTP front end for per-session Tokenizer_Http objects.

    GET endpoints:  /get_uid (allocates a session), and /bos_id, /eos_id,
                    /img_start_token, /img_context_token (require ?uid=...).
    POST endpoints: /encode, /decode, /reset (uid carried in the JSON body).

    NOTE(review): sessions created via /get_uid are never removed from the
    module-level ``tokenizers`` dict, so a long-running server grows without
    bound — consider an eviction endpoint or TTL.
    """
    timeout = 5
    server_version = 'Apache'

    # GET route substring -> Tokenizer_Http property to report.
    _TOKEN_ROUTES = {
        '/img_start_token': 'img_start_token',
        '/img_context_token': 'img_context_token',
        '/bos_id': 'bos_id',
        '/eos_id': 'eos_id',
    }

    def do_GET(self):
        print(self.path)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        if '/get_uid' in self.path:
            # Allocate a fresh tokenizer session and hand back its uid.
            new_uid = str(uuid.uuid4())
            print("新 uid:", new_uid)
            tokenizers[new_uid] = Tokenizer_Http(args.model_id)
            msg = json.dumps({'uid': new_uid})
        else:
            # All remaining GET routes share one uid-validated lookup path
            # (the original repeated this block four times).
            for route, prop in self._TOKEN_ROUTES.items():
                if route in self.path:
                    instance = tokenizers.get(self.get_query_param("uid"))
                    if instance is None:
                        msg = json.dumps({'error': 'Invalid uid'})
                    else:
                        value = getattr(instance, prop)
                        # bos/eos ids are reported as -1 when undefined.
                        if prop in ('bos_id', 'eos_id') and value is None:
                            value = -1
                        msg = json.dumps({prop: value})
                    break
            else:
                msg = json.dumps({'error': 'Invalid GET endpoint'})
        print(msg)
        self.wfile.write(msg.encode())

    def do_POST(self):
        content_length = int(self.headers.get('content-length', 0))
        req = json.loads(self.rfile.read(content_length).decode())
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        # Every POST route addresses one session by uid.
        instance: Tokenizer_Http = tokenizers.get(req.get('uid'))

        if '/encode' in self.path:
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                prompt = req.get('text')
                last_reply = req.get('last_reply')
                # Fix: validate the uid BEFORE touching the image fields, so
                # a bad uid yields the JSON error instead of a KeyError.
                if req.get('img_prompt'):
                    token_ids, diff = instance.encode_with_image(
                        prompt, req['num_img'], req['imgsz'], last_reply)
                else:
                    token_ids, diff = instance.encode(prompt, last_reply)
                msg = json.dumps({'token_ids': token_ids, 'diff': diff})

        elif '/decode' in self.path:
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                text = instance.decode(req.get('token_ids'))
                msg = json.dumps({'text': text})

        elif '/reset' in self.path:
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                system_prompt = req.get("system_prompt")
                if system_prompt is not None:
                    print("system_prompt:", system_prompt)
                    token_ids = instance.reset(system_prompt)
                else:
                    token_ids = instance.reset()
                msg = json.dumps({'token_ids': token_ids})

        else:
            msg = json.dumps({'error': 'Invalid POST endpoint'})

        self.wfile.write(msg.encode())

    def get_query_param(self, key):
        """Return the first value of ``key`` in the URL query, or None."""
        from urllib.parse import urlparse, parse_qs
        query = urlparse(self.path).query
        values = parse_qs(query).get(key)
        return values[0] if values else None
228+
229+
if __name__ == "__main__":
    # Parse CLI options.  `args` must stay module-global: the handler
    # classes read args.model_id and args.content directly.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument('--model_id', type=str, default='internvl3_tokenizer')
    parser.add_argument('--content', type=str, default='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。')
    args = parser.parse_args()

    # Bind the (single-threaded) HTTP server and serve until interrupted.
    address = (args.host, args.port)
    httpd = HTTPServer(address, Request)
    print("http://%s:%s" % address)
    httpd.serve_forever()
0 commit comments