-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathget_react_data.py
526 lines (491 loc) · 21.6 KB
/
get_react_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
import re
from vllm import LLM, SamplingParams
from build_react_prompt import (
build_input_text,
TOOL_DESC,
PROMPT_REACT,
parse_latest_plugin_call,
)
import json
import json5
from transformers import AutoTokenizer
# from agent_demo import *
# --- Paths and generation settings -------------------------------------
# All paths are machine-specific absolute paths; adjust before running.
cpm3_path = "/root/ld/ld_model_pretrained/minicpm3" # if you want to get react data,cpm3_path can be none
model_path = "/root/ld/ld_model_pretrain/Qwen2.5-72B-Instruct-GPTQ-Int4" # teacher model path
save_question_json = "/root/ld/ld_project/MiniCPM-CookBook/agent_demo/question_react.json" # JSON path for the generated simple queries
save_complex_question_json = "/root/ld/ld_project/MiniCPM-CookBook/agent_demo/question_complex_react.json"
save_react_qa_json = "/root/ld/ld_project/MiniCPM-CookBook/agent_demo/react_qa_react.json" # JSON path for the generated ReAct QA pairs
inference_batch_size = 8 # batch size when the teacher model generates data
gen_datas_per_tool = 10 # how many ReAct samples to generate per tool
cpm3_data_save_path = "/root/ld/ld_project/pull_request/MiniCPM_Series_Tutorial/agent_demo/cpm3_fc_train_data.json" # JSON path for the MiniCPM3 function-call training data
# Tool registry in the Qwen-ReAct style: each entry gives the tool's
# human/model names, whether it can actually be executed, a description,
# an example query, and its parameter schema.
tools = [
    {
        "name_for_human": "图生文",
        "name_for_model": "image_gen_prompt",
        "excute_function": False,  # whether this tool can be executed as a real function call when generating data
        "description_for_model": "图生文是一个可以看图生成文字描述的服务,输入一张图片的地址,将返回图片详细逼真的表述",
        "example": "帮我看一下www.baidu.com/img/PCtm_d9c8750bed0b3c7d089fa7d55720d6cf.png这张图片上的今日股价是多少",
        "parameters": [
            {
                "name": "image_path",
                "description": "需要图片描述的URL或者本地地址",
                "scope": None,  # allowed values for this parameter; None means unrestricted
                "required": True,  # whether this parameter is mandatory
                "schema": {"type": "string"},
            }
        ],
    },
    {
        "name_for_human": "知识图谱",
        "name_for_model": "knowledge_graph",
        "excute_function": True,
        "description_for_model": "知识图谱是输入武器种类获取该武器的属性,也可以输入某种属性获得所有武器的该属性",
        "example": "帮我查一下敌方直升机的续航里程",
        "parameters": [
            {
                "name": "weapon_query",
                "description": "武器名称",
                "scope": ["直升机", "坦克", "反坦克导弹", "直升机", "火箭炮", "所有武器"],  # allowed values for this parameter
                "required": True,
                "schema": {"type": "string"},
            },
            {
                "name": "attribute",
                "description": "武器的属性",
                "scope": ["射程", "续航里程", "重量", "速度", "承载量", "适应场景", "克制武器"],
                "required": True,
                "schema": {"type": "string"},
            },
        ],
    },
]
# Sampling configuration for the teacher model (vLLM SamplingParams kwargs).
params_dict = {
    "n": 1,
    "best_of": 1,
    "presence_penalty": 1,
    "frequency_penalty": 1.0,
    "temperature": 0.8,
    "top_p": 0.8,
    "top_k": -1,  # -1 disables top-k filtering
    "stop": None,  # get_react_data() later overrides this with ["Observation:"]
    "stop_token_ids": None,
    "ignore_eos": False,
    "max_tokens": 1000,
    "logprobs": None,
    "prompt_logprobs": None,
    "skip_special_tokens": True,
}
sampling_params = SamplingParams(**params_dict)
def save_cpm3_data(cpm3_data_path, cpm3_data):
    """Write the MiniCPM3 training records to *cpm3_data_path* as
    pretty-printed, non-ASCII-escaped UTF-8 JSON."""
    serialized = json.dumps(cpm3_data, ensure_ascii=False, indent=4)
    with open(cpm3_data_path, "w", encoding="utf-8") as out_file:
        out_file.write(serialized)
def switch_cpm_tool(tools):
    """Convert this script's tool descriptions into the OpenAI-style
    function-calling schema consumed by MiniCPM3's chat template.

    Args:
        tools: list of tool dicts with ``name_for_model``,
            ``description_for_model`` and a ``parameters`` list of
            ``{"name", "description", "required", "schema": {"type"}}``.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts,
        one fresh dict per tool.

    Bug fixed: the original mutated a single template dict inside the
    loop and appended it repeatedly, so every element of the returned
    list aliased the same object (all describing the last tool), and
    the ``properties`` mapping accumulated the template's ``order_id``
    plus every previous tool's parameters.
    """
    cpm_tools = []
    for tool in tools:
        properties = {}
        required_list = []
        for param in tool["parameters"]:
            properties[param["name"]] = {
                "type": param["schema"]["type"],
                "description": param["description"],
            }
            if param["required"]:
                required_list.append(param["name"])
        cpm_tools.append(
            {
                "type": "function",
                "function": {
                    "name": tool["name_for_model"],
                    "description": tool["description_for_model"],
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required_list,
                        "additionalProperties": False,
                    },
                },
            }
        )
    return cpm_tools
def function_call(plugin_name, plugin_args):
    """Execute a tool call emitted by the model.

    Args:
        plugin_name: the tool's ``name_for_model``; only
            ``knowledge_graph`` is actually executable here.
        plugin_args: JSON/JSON5 argument string (json5 tolerates the
            lax JSON LLMs tend to emit, e.g. single quotes).

    Returns:
        A Chinese answer string for a single-weapon query, the whole
        knowledge-graph dict when ``weapon_query`` is "所有武器"
        (all weapons), or None for unrecognized plugin names
        (preserved from the original behavior).
    """
    args_dict = json5.loads(plugin_args)
    if plugin_name == "knowledge_graph":
        weapon_name = args_dict["weapon_query"]
        attribute = args_dict["attribute"]
        # Static weapon knowledge base: {weapon: {attribute: value}}.
        kg = {
            "直升机": {
                "飞行高度": "0.3km以内",
                "携带武器": "火箭弹",
                "克制武器": "对空导弹",
                "重量": "3000kg",
                "速度": "100km/h",
                "射程": "2km",
                "适应场景": "空战",
                "续航": "500km",
                "满载人数": "7人",
                "承载重量": "10000kg",
                "续航里程": "1000km",
            },
            "反坦克导弹": {
                "重量": "100kg",
                "射程": "0.5千米",
                "克制武器": "拦截导弹",
                "适应场景": "打击重装甲武器",
                "速度": "200km/h",
            },
            "步兵": {
                "射程": "0.3km",
                "克制武器": "无人机",
                "适应场景": "陆地",
                "速度": "40km/h",
                "重量": "60kg",
                "承载重量": "50kg",
            },
            "无人机": {
                "速度": "100km/h",
                "重量": "10kg",
                "适应场景": "侦察和暗杀",
                "飞行高度": "0.3km以下",
                "克制武器": "电磁攻击",
                "续航": "50km",
            },
            "豹2A7坦克": {
                "速度": "50km/h",
                "携带武器": "激光炮",
                "克制武器": "反坦克导弹",
                "射程": "5km",
                "重量": "10000kg",
                "续航": "1000km",
                "承载重量": "200000kg",
                "满载人数": "5人",
                "适应场景": "野战和掩护步兵",
            },
            "黑狐坦克": {
                "速度": "70km/h",
                "携带武器": "主炮",
                "克制武器": "反坦克导弹",
                "射程": "15km",
                "重量": "10000kg",
                "承载重量": "50000kg",
                "续航": "1000km",
                "满载人数": "5人",
                "适应场景": "野战和掩护步兵",
            },
            "火箭炮": {
                "速度": "4500km/h",
                "重量": "500kg",
                "射程": "1000km",
                "适应场景": "超远程打击",
                "飞行高度": "万米高空",
                "克制武器": "拦截导弹",
            },
            "雷达": {"重量": "5000kg", "探测范围": "2km以上20km以下", "适应场景": "探测敌军"},
            "装甲车": {
                "速度": "80km/h",
                "携带武器": "副炮",
                "克制武器": "穿甲弹",
                "射程": "0.5km",
                "重量": "10000kg",
                "承载重量": "10000kg",
                "续航": "600km",
                "满载人数": "10人",
            },
            "狙击枪": {"射程": "1.2km", "重量": "30kg", "适应场景": "暗杀"},
        }
        if weapon_name != "所有武器":
            # Explicit membership checks replace the original bare
            # ``except:``, which would also have swallowed unrelated
            # errors (e.g. a typo'd variable) silently.
            if weapon_name not in kg:
                return "该武器不存在"
            if attribute not in kg[weapon_name]:
                return "{}的{}属性不存在".format(weapon_name, attribute)
            return "{}的{}是:{}".format(
                weapon_name, attribute, kg[weapon_name][attribute]
            )
        return kg
def split_react_data(react_str):
    """Split a complete ReAct trajectory string into its six parts.

    Args:
        react_str: text containing "Thought: ... Action: ...
            Action Input: ... Observation: ... Thought: ...
            Final Answer: ..." sections.

    Returns:
        ``(thought1, action, action_input, observation, thought2,
        final_answer)`` from the first match, or a 6-tuple of ``None``
        when the string does not match.

    Bug fixed: when the regex found no match, the original ``for`` loop
    simply never ran and the function fell through returning a bare
    ``None`` (the ``except`` caught nothing, since no exception was
    raised), which made callers' 6-way unpacking raise TypeError.
    """
    pattern = re.compile(
        r"Thought:\s*(.*?)\nAction:\s*(.*?)\nAction Input:\s*(.*?)\nObservation:\s*(.*?)\nThought:\s*(.*?)\nFinal Answer:\s*(.*)",
        re.DOTALL,
    )
    match = pattern.search(react_str)
    if match is None:
        return None, None, None, None, None, None
    return match.groups()
def get_answer_from_output(output):
    """Extract every generated question wrapped in
    「问题开始」…「问题结束」 markers from the model output."""
    found = re.findall(r"「问题开始」(.*?)「问题结束」", output, re.DOTALL)
    return [item.strip() for item in found]
def get_complex_question_from_output(output):
    """Extract complex questions and their paired task plans from the
    model output.

    Questions are wrapped in 「复杂问题开始」…「复杂问题结束」 and plans
    in 「任务规划开始」…「任务规划结束」; they are paired by position.

    Returns:
        A list of "question\\nplan" strings.

    Raises:
        ValueError: when the number of questions and plans differ.
            (Was a bare ``assert``, which is stripped under ``-O`` and
            carries no diagnostic message.)
    """
    pattern = r"「复杂问题开始」(.*?)「复杂问题结束」"
    questions = re.findall(pattern, output, re.DOTALL)
    pattern2 = r"「任务规划开始」(.*?)「任务规划结束」"
    task_plan = re.findall(pattern2, output, re.DOTALL)
    if len(task_plan) != len(questions):
        raise ValueError(
            "mismatched output: {} questions vs {} task plans".format(
                len(questions), len(task_plan)
            )
        )
    return [q.strip() + '\n' + task_plan[i] for i, q in enumerate(questions)]
def get_tool_description(tool):
    """Render one tool entry as a Chinese natural-language description:
    tool name and purpose, followed by one sentence per parameter
    stating whether it is required and its allowed value range."""
    parts = [
        "工具名称是{},作用是{},".format(
            tool["name_for_model"], tool["description_for_model"]
        )
    ]
    for param in tool["parameters"]:
        is_required = param["required"]
        value_scope = param["scope"]
        if is_required and value_scope:
            parts.append(
                "参数”{}“是必须输入的,作用是{},该参数的取值范围是{}。".format(
                    param["name"], param["description"], value_scope
                )
            )
        elif is_required:
            parts.append(
                "参数“{}”是必须输入的,作用是{}。".format(param["name"], param["description"])
            )
        elif value_scope:
            parts.append(
                "参数“{}”是可选的,作用是{},该参数的取值范围是{}。".format(
                    param["name"], param["description"], value_scope
                )
            )
        else:
            parts.append(
                "参数“{}”是可选的,作用是{}。".format(param["name"], param["description"])
            )
    return "".join(parts)
def get_question():
    """Generate simple single-tool questions with the teacher model.

    For every entry in the module-level ``tools`` list, prompts the
    teacher LLM until at least ``gen_datas_per_tool`` questions have
    been parsed from its output, then writes
    ``{tool_name: [question, ...]}`` to ``save_question_json``.

    Side effects: loads the teacher model onto 8 GPUs via vLLM and
    writes one JSON file.
    """
    # The original guarded this with ``if 'llm' not in locals()``, which
    # is always true at function entry (and would leave ``llm`` unbound
    # if it were ever false) — the model is simply always constructed.
    llm = LLM(
        model=model_path,
        tensor_parallel_size=8,
        max_model_len=4096,
        dtype="bfloat16",
        trust_remote_code=True,
        enforce_eager=True,
        gpu_memory_utilization=0.8,
    )
    prompt_template = """你是一个智能助手,现在我请你为以下工具生成问题,要求生成的问题能够被这个工具解决。工具的详细介绍如下:\n{}\n我现在给你一个关于此工具问题的示例「问题开始」
{}「问题结束」,接下来请你根据此示例和工具描述再生成{}个能够使用该工具解决的问题,并且用「问题开始」和「问题结束」将其包裹。"""
    all_questions = []
    questions_dict = {}
    for tool in tools:
        questions = []
        # Keep sampling until enough well-formed questions were parsed.
        while True:
            tool_description = get_tool_description(tool)
            input_prompt = prompt_template.format(
                tool_description, tool["example"], gen_datas_per_tool
            )
            # Wrap in the Qwen chat template by hand.
            input_prompt = """<|im_start|> system\n you are a helpful assistant<|im_end|>\n<|im_start|> user\n {}<|im_end|>\n<|im_start|> assistant\n""".format(
                input_prompt
            )
            outputs = llm.generate(input_prompt, sampling_params)
            output = outputs[0].outputs[0].text
            questions.extend(get_answer_from_output(output))
            if len(questions) >= gen_datas_per_tool:
                all_questions.extend(questions)
                print(questions)
                questions_dict[tool["name_for_model"]] = questions
                break
    with open(save_question_json, "w", encoding="utf-8") as f:
        json.dump(questions_dict, f, ensure_ascii=False, indent=4)
    print("{}条输入指令已经保存到{}".format(len(all_questions), save_question_json))
def get_complex_question():
    """Generate multi-tool "complex" questions plus task plans with the
    teacher model.

    Tool descriptions accumulate across loop iterations, so each later
    round prompts with every tool seen so far. Results are written as
    ``{tool_name: [question+plan, ...]}`` to
    ``save_complex_question_json``.

    Side effects: loads the teacher model onto 8 GPUs via vLLM and
    writes one JSON file.
    """
    # The original guarded this with ``if 'llm' not in locals()``, which
    # is always true at function entry (and would leave ``llm`` unbound
    # if it were ever false) — the model is simply always constructed.
    llm = LLM(
        model=model_path,
        tensor_parallel_size=8,
        max_model_len=4096,
        dtype="bfloat16",
        trust_remote_code=True,
        enforce_eager=True,
        gpu_memory_utilization=0.8,
    )
    example = """「复杂问题开始」请你帮我画出地方坦克的克制武器的图片。「复杂问题结束」
「任务规划开始」1. 使用知识图谱获取敌方坦克的克制武器A.\n2.使用知识图谱获取武器A的所有属性B\n3.根据武器A及其所有属性B使用文生图工具画出武器A的图片C。「任务规划结束」
"""
    # NOTE(review): the "/n" inside this prompt looks like a typo for
    # "\n" — left byte-identical to preserve generation behavior.
    prompt_template = """\n所有工具的详细介绍如上所示,你是一个智能助手,现在我请你为以下工具生成3个复合问题,要求生成的问题是具体的问题,有具体的目标,实体,任务,且能够被这几个工具所解决,并且最少这个复杂问题最少要使用两个以上的工具才能完成。接下来请你根据每个工具的简单问题示例和工具描述再生成{}个能够使用以上最少两个工具解决的复杂问题及其解决方案,接下来我会给你一个示例:/n{},\n.并且严格按照示例的格式,用「复杂问题开始」和「复杂问题结束」以及「任务规划开始」和「任务规划结束」进行包裹."""
    all_questions = []
    tool_prompt = ''
    questions_dict = {}
    for index, tool in enumerate(tools):
        questions = []
        tool_description = get_tool_description(tool)
        # Accumulate: round N prompts with tools 1..N.
        tool_prompt += '\n第{}个工具:'.format(index + 1) + tool_description
        input_prompt = tool_prompt + prompt_template.format(gen_datas_per_tool, example)
        input_prompt = """<|im_start|> system\n You are a helpful assistant<|im_end|>\n<|im_start|> user\n {}<|im_end|>\n<|im_start|> assistant\n""".format(
            input_prompt
        )
        # Keep sampling until enough well-formed pairs were parsed.
        while True:
            outputs = llm.generate(input_prompt, sampling_params)
            output = outputs[0].outputs[0].text
            print(output)
            questions.extend(get_complex_question_from_output(output))
            if len(questions) >= gen_datas_per_tool:
                all_questions.extend(questions)
                print(questions)
                questions_dict[tool["name_for_model"]] = questions
                break
    with open(save_complex_question_json, "w", encoding="utf-8") as f:
        json.dump(questions_dict, f, ensure_ascii=False, indent=4)
    print("{}条输入指令已经保存到{}".format(len(all_questions), save_complex_question_json))
def get_react_data():
    """Run the teacher model on the saved questions to produce complete
    ReAct trajectories and save them as instruction-tuning records.

    Reads ``save_question_json`` (``{tool_name: [question, ...]}`` as
    written by ``get_question``), generates up to the "Observation:"
    stop token, fills the observation in (either by really executing
    the tool via ``function_call`` or by letting the model continue
    when ``excute_function`` is False), then resumes generation for the
    final answer. Results go to ``save_react_qa_json``.

    Side effects: loads the teacher model onto 8 GPUs via vLLM,
    mutates the module-level ``params_dict``, and writes one JSON file.
    """
    # The original guarded this with ``if 'llm' not in locals()``, which
    # is always true at function entry — the model is always constructed.
    llm = LLM(
        model=model_path,
        tensor_parallel_size=8,
        max_model_len=4096,
        dtype="bfloat16",
        trust_remote_code=True,
        enforce_eager=True,
        gpu_memory_utilization=0.8,
    )
    with open(save_question_json, "r", encoding="utf-8") as file:
        # {tool_name: [question, ...]} as written by get_question().
        questions_by_tool = json.load(file)
    # Bug fixed: the original iterated the dict directly, which yields
    # only the tool NAMES — flatten the values to get the questions.
    all_questions = [q for qs in questions_by_tool.values() for q in qs]
    react_question = [build_input_text([(q, "")], tools) for q in all_questions]
    params_dict["top_k"] = 1  # greedy decoding for trajectories
    params_dict["stop"] = ["Observation:"]  # pause before the tool result
    react_qa = []
    sampling_params = SamplingParams(**params_dict)
    for index in range(0, len(react_question), inference_batch_size):
        outputs = llm.generate(
            react_question[index : index + inference_batch_size], sampling_params
        )
        for i in range(len(outputs)):
            output = outputs[i].outputs[0].text
            try:
                plugin_name, plugin_args, _ = parse_latest_plugin_call(output)
                excute_flag = True
                for tool in tools:
                    if (
                        tool["name_for_model"] == plugin_name
                        and tool["excute_function"] == False
                    ):
                        # Tool cannot really be executed: let the model
                        # continue and imagine the observation itself.
                        excute_flag = False
                        second_input = (
                            react_question[index + i] + output + "Observation: "
                        )
                        output2 = (
                            llm.generate(second_input, sampling_params)[0]
                            .outputs[0]
                            .text
                        )
                if excute_flag:
                    observation = function_call(plugin_name, plugin_args)
                    second_input = (
                        react_question[index + i]
                        + output
                        + "Observation: {}".format(observation)
                    )
                    output2 = llm.generate(second_input, sampling_params)[0].outputs[0].text
                print(output2)
                react_qa.append(
                    {
                        "instruction": "You are a helpful assistant.",
                        # Strip the fixed chat-template prefix/suffix that
                        # build_input_text adds (hard-coded offsets —
                        # presumably tied to that exact template; verify
                        # if the template ever changes).
                        "input": react_question[index + i][75:-33],
                        "output": second_input[len(react_question[index + i]) :]
                        + output2,
                    }
                )
            except Exception:
                # Best-effort: skip samples whose ReAct output cannot be
                # parsed instead of aborting the whole run.
                pass
    with open(save_react_qa_json, "w", encoding="utf-8") as f:
        json.dump(react_qa, f, ensure_ascii=False, indent=4)
    print("{}条react qa数据已经保存到{}".format(len(react_qa), save_react_qa_json))
def get_cpm_function_call():
    """Convert the saved ReAct QA pairs into MiniCPM3 function-call
    training data and save it via save_cpm3_data().

    Each parsed trajectory produces TWO training records:
      1. user query -> assistant thought + tool call, and
      2. user query + tool call + tool result -> assistant thought +
         final answer.
    Requires ``cpm3_path`` to point at a MiniCPM3 checkpoint so its
    tokenizer's chat template can render the system prompt with tools.
    """
    with open(save_react_qa_json, "r", encoding="utf-8") as file:
        # Parse the JSON file content into a Python object
        react_qa = json.load(file)
    cpm_tool = switch_cpm_tool(tools)
    tokenizer = AutoTokenizer.from_pretrained(cpm3_path, trust_remote_code=True)
    cpm_fc_train_data = []
    for react in react_qa:
        messages = [
            {
                "role": "system",
                "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
            }
        ]
        # The raw user query is the text after the last "Question: " marker.
        query = react["input"].split("Question: ")[-1]
        print(query)
        # The trajectory text is the last value of the record
        # (the "output" field of the dicts written by get_react_data).
        react_str = list(react.values())[-1]
        Thought1, Action, Action_Input, Observation, Thought2, Final_Answer = split_react_data(
            react_str
        )
        # Keep only trajectories where every ReAct component was parsed.
        if (
            Thought1
            and Action
            and Action_Input
            and Observation
            and Thought2
            and Final_Answer
        ):
            messages.append({"role": "user", "content": query})
            prompt = tokenizer.apply_chat_template(
                messages, tools=cpm_tool, tokenize=False, add_generation_prompt=True
            )
            # MiniCPM3 wraps reasoning in <|thought_start|>/<|thought_end|>.
            cpm_thought1 = "<|thought_start|>\n{}\n<|thought_end|>".format(Thought1)
            # Rewrite '"key": value' JSON pairs into key=value call syntax.
            cpm_function_and_param = "\n<|tool_call_start|>\n```python\n{}({})\n```\n<|tool_call_end|>".format(
                Action, re.sub(": ", "=", Action_Input)
            )
            # NOTE(review): prompt.split("<|im_end|>")[0][19:] strips the
            # system-role header by a hard-coded offset — presumably tied
            # to this exact chat template; verify if the template changes.
            cpm_fc_train_data.append(
                [
                    {"role": "system", "content": prompt.split("<|im_end|>")[0][19:]},
                    {"role": "user", "content": query},
                    {
                        "role": "assistant",
                        "content": cpm_thought1 + cpm_function_and_param,
                    },
                ]
            )
            # Second record: feed the tool result back as a "tool" turn.
            cpm_response = "<|im_end|>\n<|im_start|>tool\n{}<|im_end|>\n<|im_start|>assistant\n".format(
                Observation
            )
            cpm_thought2 = "<|thought_start|>\n{}\n<|thought_end|>\n".format(Thought2)
            cpm_answer = Final_Answer
            cpm_fc_train_data.append(
                [
                    {"role": "system", "content": prompt.split("<|im_end|>")[0][19:]},
                    {
                        "role": "user",
                        "content": query
                        + "<|im_start|>assistant\n"
                        + cpm_function_and_param
                        + cpm_response,
                    },
                    {"role": "assistant", "content": cpm_thought2 + cpm_answer},
                ]
            )
        else:
            # Debug marker printed for unparsable trajectories.
            print(1)
            continue
        # cpm_fc_train_data.append({"role":"system",'content':prompt+cpm_function_and_param+cpm_response,'role':'assistant','content':cpm_thought2+cpm_answer})
    save_cpm3_data(cpm3_data_save_path, cpm_fc_train_data)
    print(
        "{}条cpm3 function call数据已经保存到{}".format(
            len(cpm_fc_train_data), cpm3_data_save_path
        )
    )
if __name__ == "__main__":
    # Pipeline stages — run one at a time:
    #   get_question()/get_complex_question() -> generate queries,
    #   get_react_data() -> generate ReAct trajectories,
    #   get_cpm_function_call() -> convert to MiniCPM3 training data.
    get_complex_question()
    # get_react_data()
    # get_cpm_function_call()