RAFT adding promptflow flows
cedricvidal committed May 29, 2024
1 parent 504da8e commit a3564be
Showing 5 changed files with 153 additions and 0 deletions.
11 changes: 11 additions & 0 deletions raft/azure-ai-studio-ft/pf/raft/extract_final_answer.py
@@ -0,0 +1,11 @@
from promptflow import tool

# The inputs section will change based on the arguments of the tool function, after you save the code
# Adding type to arguments and return value will help the system show the types properly
# Please update the function name/signature per need
@tool
def extract_final_answer(cot_answer: str) -> str:
"""
Extracts the final answer from the cot_answer field
"""
return cot_answer.split("<ANSWER>: ")[-1]
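For reference, a minimal illustration of what this split does; the chain-of-thought string below is made up, not taken from the RAFT dataset:

cot = "The context explains that SDK V2 is built around MLClient. <ANSWER>: Use MLClient from azure.ai.ml."
extract_final_answer(cot)  # -> 'Use MLClient from azure.ai.ml.'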
46 changes: 46 additions & 0 deletions raft/azure-ai-studio-ft/pf/raft/flow.dag.yaml
@@ -0,0 +1,46 @@
id: bring_your_own_data_qna
name: Bring Your Own Data QnA
inputs:
  question:
    type: string
    default: How to use SDK V2?
    is_chat_input: false
  context:
    type: string
    is_chat_input: false
outputs:
  answer:
    type: string
    reference: ${extract_final_answer.output}
  explanation:
    type: string
    reference: ${llama2_7b_finetuned_model.output}
nodes:
- name: construct_prompt_with_context
  type: prompt
  source:
    type: code
    path: prompt.jinja2
  inputs:
    context: ${inputs.context}
    question: ${inputs.question}
  use_variants: false
- name: llama2_7b_finetuned_model
  type: python
  source:
    type: code
    path: test_llm.py
  inputs:
    prompt: ${construct_prompt_with_context.output}
  use_variants: false
- name: extract_final_answer
  type: python
  source:
    type: code
    path: extract_final_answer.py
  inputs:
    cot_answer: ${llama2_7b_finetuned_model.output}
  use_variants: false
node_variants: {}
environment:
  python_requirements_txt: requirements.txt
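The construct_prompt_with_context node renders prompt.jinja2, which is not part of this commit. A hypothetical sketch of how such a template might combine the two flow inputs, rendered here with the jinja2 package from plain Python (the template body is an assumption, not the repository's file):

from jinja2 import Template

# Hypothetical RAFT-style prompt body; the real prompt.jinja2 is not in this commit
template = Template("<DOCUMENT>{{ context }}</DOCUMENT>\n\nQuestion: {{ question }}\nAnswer:")
print(template.render(context="SDK V2 is built around MLClient.", question="How to use SDK V2?"))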
38 changes: 38 additions & 0 deletions raft/azure-ai-studio-ft/pf/raft/test_llm.py
@@ -0,0 +1,38 @@
from promptflow import tool
from openai import OpenAI


# The inputs section will change based on the arguments of the tool function, after you save the code
# Adding type to arguments and return value will help the system show the types properly
# Please update the function name/signature per need
@tool
def call_finetuned_model(prompt: str) -> str:
    # Placeholder values; replace with the serverless endpoint URL and key of the deployment
    endpoint_url = "xxx"
    api_key = "xxx"

    if not api_key:
        raise ValueError("A key should be provided to invoke the endpoint")

    # The serverless endpoint exposes an OpenAI-compatible API under /v1
    base_url = endpoint_url + '/v1'
    client = OpenAI(
        base_url=base_url,
        api_key=api_key,
    )

    deployment_name = "Llama-2-7b-raft-bats-18k-unrrr"

    # COMPLETION API
    response = client.completions.create(
        model=deployment_name,
        prompt=prompt,
        stop="<STOP>",
        temperature=0.5,
        max_tokens=512,
        top_p=0.1,
        best_of=1,
        presence_penalty=0,
    )

    return response.choices[0].text.strip()
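The endpoint URL and key are committed as "xxx" placeholders. One way to keep them out of source, sketched here with hypothetical environment variable names, would be:

import os

# Hypothetical variable names; set these in the runtime environment instead of hardcoding
endpoint_url = os.environ["SERVERLESS_ENDPOINT_URL"]
api_key = os.environ["SERVERLESS_ENDPOINT_KEY"]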
12 changes: 12 additions & 0 deletions raft/azure-ai-studio-ft/pf/rag/extract_final_answer.py
@@ -0,0 +1,12 @@
from promptflow import tool
import re

# The inputs section will change based on the arguments of the tool function, after you save the code
# Adding type to arguments and return value will help the system show the types properly
# Please update the function name/signature per need
@tool
def extract_final_answer(cot_answer: str) -> str:
"""
Extracts the final answer from the cot_answer field
"""
return re.sub(r'<STOP>\s*$', '', cot_answer.split("<ANSWER>: ")[-1])
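Unlike the raft variant above, this version also strips a trailing <STOP> token. An illustrative call (the input string is made up):

extract_final_answer("reasoning... <ANSWER>: Use MLClient. <STOP>")  # -> 'Use MLClient. '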
46 changes: 46 additions & 0 deletions raft/azure-ai-studio-ft/pf/rag/flow.dag.yaml
@@ -0,0 +1,46 @@
id: bring_your_own_data_qna
name: Bring Your Own Data QnA
inputs:
  question:
    type: string
    default: How to use SDK V2?
    is_chat_input: false
  context:
    type: string
    is_chat_input: false
outputs:
  answer:
    type: string
    reference: ${extract_final_answer.output}
  explanation:
    type: string
    reference: ${llama2_7b_chat_base_model.output}
nodes:
- name: llama2_7b_chat_base_model
  type: custom_llm
  source:
    type: package_with_prompt
    path: llama2_7b_finetuned.jinja2
    tool: promptflow.tools.open_model_llm.OpenModelLLM.call
  inputs:
    api: completion
    endpoint_name: serverlessEndpoint/Llama-2-7b-chat-gmqyf
    deployment_name: default
    temperature: 1
    max_new_tokens: 500
    top_p: 1
    model_kwargs: {}
    context: ${inputs.context}
    question: ${inputs.question}
  use_variants: false
- name: extract_final_answer
  type: python
  source:
    type: code
    path: extract_final_answer.py
  inputs:
    cot_answer: ${llama2_7b_chat_base_model.output}
  use_variants: false
node_variants: {}
environment:
  python_requirements_txt: requirements.txt
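Either flow directory can be smoke-tested locally before deployment. A sketch using the promptflow SDK's test client, assuming the package pinned in requirements.txt exposes PFClient (the paths and inputs are illustrative):

from promptflow import PFClient

pf = PFClient()
result = pf.test(
    flow="raft/azure-ai-studio-ft/pf/rag",  # directory containing flow.dag.yaml
    inputs={
        "question": "How to use SDK V2?",
        "context": "SDK V2 is built around MLClient.",
    },
)
print(result)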
