Skip to content

Commit 135671a

Browse files
[feature] add client demo for sending request to easyrec processor (alibaba#363)
* add client demo for sending request to easyrec processor * fix code style
1 parent adc5f25 commit 135671a

File tree

17 files changed

+361
-42
lines changed

17 files changed

+361
-42
lines changed

.git_bin_path

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{"leaf_name": "data/test", "leaf_file": ["data/test/batch_criteo_sample.tfrecord", "data/test/criteo_sample.tfrecord", "data/test/dwd_avazu_ctr_deepmodel_10w.csv", "data/test/embed_data.csv", "data/test/lookup_data.csv", "data/test/tag_kv_data.csv", "data/test/test.csv", "data/test/test_sample_weight.txt", "data/test/test_with_quote.csv"]}
2+
{"leaf_name": "data/test/client", "leaf_file": ["data/test/client/item_lst", "data/test/client/user_table_data", "data/test/client/user_table_schema"]}
23
{"leaf_name": "data/test/criteo_data", "leaf_file": ["data/test/criteo_data/category.bin", "data/test/criteo_data/dense.bin", "data/test/criteo_data/label.bin", "data/test/criteo_data/readme"]}
34
{"leaf_name": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "leaf_file": ["data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/ESTIMATOR_TRAIN_DONE", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/atexit_sync_1661483067", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/checkpoint", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/eval_result.txt", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.index", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.meta", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/pipeline.config", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/version"]}
45
{"leaf_name": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "leaf_file": ["data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/checkpoint", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/eval_result.txt", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.index", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.meta", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/pipeline.config"]}

.git_bin_url

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{"leaf_path": "data/test", "sig": "656d73b4e78d0d71e98120050bc51387", "remote_path": "data/git_oss_sample_data/data_test_656d73b4e78d0d71e98120050bc51387"}
2+
{"leaf_path": "data/test/client", "sig": "d2e000187cebd884ee10e3cf804717fc", "remote_path": "data/git_oss_sample_data/data_test_client_d2e000187cebd884ee10e3cf804717fc"}
23
{"leaf_path": "data/test/criteo_data", "sig": "f224ba0b1a4f66eeda096c88703d3afc", "remote_path": "data/git_oss_sample_data/data_test_criteo_data_f224ba0b1a4f66eeda096c88703d3afc"}
34
{"leaf_path": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "sig": "2bc0c12a09e1f4c39f839972cf09674b", "remote_path": "data/git_oss_sample_data/data_test_distribute_eval_test_deepfm_distribute_eval_dwd_avazu_out_multi_cls_2bc0c12a09e1f4c39f839972cf09674b"}
45
{"leaf_path": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "sig": "9fde5d2987654f268a231a1c69db5799", "remote_path": "data/git_oss_sample_data/data_test_distribute_eval_test_dropoutnet_distribute_eval_taobao_ckpt_9fde5d2987654f268a231a1c69db5799"}

easy_rec/__init__.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# -*- encoding:utf-8 -*-
22
# Copyright (c) Alibaba, Inc. and its affiliates.
3+
34
import logging
45
import os
56
import platform
67
import sys
78

8-
import tensorflow as tf
9-
109
from easy_rec.version import __version__
1110

1211
curr_dir, _ = os.path.split(__file__)
@@ -16,33 +15,36 @@
1615
logging.basicConfig(
1716
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
1817

19-
if platform.system() == 'Linux':
20-
ops_dir = os.path.join(curr_dir, 'python/ops')
21-
if 'PAI' in tf.__version__:
22-
ops_dir = os.path.join(ops_dir, '1.12_pai')
23-
elif tf.__version__.startswith('1.12'):
24-
ops_dir = os.path.join(ops_dir, '1.12')
25-
elif tf.__version__.startswith('1.15'):
26-
ops_dir = os.path.join(ops_dir, '1.15')
18+
# Avoid import tensorflow which conflicts with the version used in EasyRecProcessor
19+
if 'PROCESSOR_TEST' not in os.environ:
20+
if platform.system() == 'Linux':
21+
ops_dir = os.path.join(curr_dir, 'python/ops')
22+
import tensorflow as tf
23+
if 'PAI' in tf.__version__:
24+
ops_dir = os.path.join(ops_dir, '1.12_pai')
25+
elif tf.__version__.startswith('1.12'):
26+
ops_dir = os.path.join(ops_dir, '1.12')
27+
elif tf.__version__.startswith('1.15'):
28+
ops_dir = os.path.join(ops_dir, '1.15')
29+
else:
30+
ops_dir = None
2731
else:
2832
ops_dir = None
29-
else:
30-
ops_dir = None
3133

32-
from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
33-
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
34-
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
35-
from easy_rec.python.main import export # isort:skip # noqa: E402
36-
from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
37-
from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402
34+
from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
35+
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
36+
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
37+
from easy_rec.python.main import export # isort:skip # noqa: E402
38+
from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
39+
from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402
3840

39-
try:
40-
import tensorflow_io.oss
41-
except Exception:
42-
pass
41+
try:
42+
import tensorflow_io.oss
43+
except Exception:
44+
pass
4345

44-
print('easy_rec version: %s' % __version__)
45-
print('Usage: easy_rec.help()')
46+
print('easy_rec version: %s' % __version__)
47+
print('Usage: easy_rec.help()')
4648

4749
_global_config = {}
4850

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# EasyRecProcessor Client
2+
3+
Demo: send a ranking request to a deployed EasyRec processor service.
4+
5+
```bash
6+
python -m easy_rec.python.client.client_demo \
7+
--endpoint 1301055xxxxxxxxx.cn-hangzhou.pai-eas.aliyuncs.com \
8+
--service_name ali_rec_rnk_sample_rt_v3 \
9+
--token MmQ3Yxxxxxxxxxxx \
10+
--table_schema data/test/client/user_table_schema \
11+
--table_data data/test/client/user_table_data \
12+
--item_lst data/test/client/item_lst
13+
14+
# output:
15+
# results {
16+
# key: "item_0"
17+
# value {
18+
# scores: 0.0
19+
# scores: 0.0
20+
# }
21+
# }
22+
# results {
23+
# key: "item_1"
24+
# value {
25+
# scores: 0.0
26+
# scores: 0.0
27+
# }
28+
# }
29+
# results {
30+
# key: "item_2"
31+
# value {
32+
# scores: 0.0
33+
# scores: 0.0
34+
# }
35+
# }
36+
# outputs: "probs_is_click"
37+
# outputs: "probs_is_go"
38+
```
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# -*- encoding:utf-8 -*-
2+
# Copyright (c) Alibaba, Inc. and its affiliates.
3+
import argparse
4+
import logging
5+
import sys
6+
import traceback
7+
8+
from easyrec_request import EasyrecRequest
9+
10+
from easy_rec.python.protos.predict_pb2 import PBFeature
11+
from easy_rec.python.protos.predict_pb2 import PBRequest
12+
13+
logging.basicConfig(
14+
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
15+
16+
try:
17+
from eas_prediction import PredictClient # TFRequest
18+
except Exception:
19+
logging.error('eas_prediction is not installed: pip install eas-prediction')
20+
sys.exit(1)
21+
22+
23+
def build_request(table_cols, table_data, item_ids=None):
  """Build a PBRequest from user-feature values and candidate item ids.

  Args:
    table_cols: list of (column_name, dtype) pairs, dtype one of
      STRING / FLOAT / DOUBLE / BIGINT / INT (as produced by
      parse_table_schema).
    table_data: list of feature values aligned with table_cols.
    item_ids: optional list of candidate item ids to score.

  Returns:
    a populated PBRequest protobuf.
  """
  request_pb = PBRequest()
  assert isinstance(table_data, list)
  try:
    for col_id in range(len(table_cols)):
      cname, dtype = table_cols[col_id]
      value = table_data[col_id]
      feat = PBFeature()
      # missing value: the column is simply absent from the request
      if value is None:
        continue
      if dtype == 'STRING':
        feat.string_feature = value
      elif dtype in ('FLOAT', 'DOUBLE'):
        feat.float_feature = value
      elif dtype == 'BIGINT':
        feat.long_feature = value
      elif dtype == 'INT':
        feat.int_feature = value
      request_pb.user_features[cname].CopyFrom(feat)
  except Exception:
    traceback.print_exc()
    # exit with a non-zero status so callers notice the failure;
    # the original bare sys.exit() exited with status 0 on error
    sys.exit(1)
  # guard the default item_ids=None, which would crash extend()
  if item_ids:
    request_pb.item_ids.extend(item_ids)
  return request_pb
48+
49+
50+
def parse_table_schema(create_table_sql):
  """Parse column (name, type) pairs out of a `create table` statement.

  Args:
    create_table_sql: sql text such as
      'create table user(user_id string, age bigint)'.

  Returns:
    list of [col_name, COL_TYPE] pairs; COL_TYPE is upper-cased.
  """
  create_table_sql = create_table_sql.lower()
  spos = create_table_sql.index('(')
  # find the closing paren in the full string; the previous code used the
  # index relative to the slice [spos+1:], which truncated the column list
  epos = create_table_sql.index(')', spos + 1)
  cols = create_table_sql[(spos + 1):epos]
  cols = [x.strip() for x in cols.split(',')]
  col_info_arr = []
  for col in cols:
    col = [k for k in col.split() if k != '']
    assert len(col) == 2
    col[1] = col[1].upper()
    col_info_arr.append(col)
  return col_info_arr
63+
64+
65+
def send_request(req_pb, client, debug_level=0):
  """Wrap req_pb into an EasyrecRequest and run it through the client.

  Args:
    req_pb: a populated PBRequest protobuf.
    client: an initialized eas PredictClient.
    debug_level: debug level forwarded on the request, default 0.

  Returns:
    whatever client.predict returns for the wrapped request.
  """
  wrapped = EasyrecRequest()
  wrapped.add_feed(req_pb, debug_level)
  return client.predict(wrapped)
70+
71+
72+
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--endpoint',
      type=str,
      default=None,
      help='eas endpoint, such as 12345.cn-beijing.pai-eas.aliyuncs.com')
  parser.add_argument(
      '--service_name', type=str, default=None, help='eas service name')
  parser.add_argument(
      '--token', type=str, default=None, help='eas service token')
  parser.add_argument(
      '--table_schema',
      type=str,
      default=None,
      help='user feature table schema path')
  parser.add_argument(
      '--table_data',
      type=str,
      default=None,
      help='user feature table data path')
  parser.add_argument('--item_lst', type=str, default=None, help='item list')

  args, _ = parser.parse_known_args()

  # every argument is mandatory; stop at the first missing one
  for arg_name in ('endpoint', 'service_name', 'token', 'table_schema',
                   'table_data', 'item_lst'):
    if getattr(args, arg_name) is None:
      logging.error('--%s is not set' % arg_name)
      sys.exit(1)

  client = PredictClient(args.endpoint, args.service_name)
  client.set_token(args.token)
  client.init()

  with open(args.table_schema, 'r') as fin:
    create_table_sql = fin.read().strip()

  with open(args.table_data, 'r') as fin:
    table_data = fin.read().strip()

  table_cols = parse_table_schema(create_table_sql)
  # user feature rows are ';'-separated in the data file
  table_data = table_data.split(';')

  with open(args.item_lst, 'r') as fin:
    items = fin.read().strip()
  # item ids are ','-separated
  items = items.split(',')

  req = build_request(table_cols, table_data, item_ids=items)
  resp = send_request(req, client)
  logging.info(resp)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright (c) Alibaba, Inc. and its affiliates.
3+
from eas_prediction.request import Request
4+
5+
from easy_rec.python.protos.predict_pb2 import PBRequest
6+
from easy_rec.python.protos.predict_pb2 import PBResponse
7+
8+
# from eas_prediction.request import Response
9+
10+
11+
class EasyrecRequest(Request):
  """Request for tensorflow services whose input data is in format of protobuf.

  This class provides methods to build a PBRequest and parse a PBResponse.
  """

  def __init__(self, signature_name=None):
    # the request payload sent to the service
    self.request_data = PBRequest()
    self.signature_name = signature_name

  def __str__(self):
    # must return a str: returning the protobuf message itself (as the
    # original code did) raises "TypeError: __str__ returned non-string"
    return str(self.request_data)

  def set_signature_name(self, singature_name):
    """Set the signature name of the model.

    Args:
      singature_name: signature name of the model
    """
    self.signature_name = singature_name

  def add_feed(self, data, dbg_lvl=0):
    """Set the request payload.

    Args:
      data: either a PBRequest instance or its serialized string form.
      dbg_lvl: debug level stored on the request, default 0.
    """
    if not isinstance(data, PBRequest):
      self.request_data.ParseFromString(data)
    else:
      self.request_data = data
    self.request_data.debug_level = dbg_lvl

  def add_user_fea_flt(self, k, v):
    # add a float-valued user feature
    self.request_data.user_features[k].float_feature = float(v)

  def add_user_fea_s(self, k, v):
    # add a string-valued user feature
    self.request_data.user_features[k].string_feature = str(v)

  def set_faiss_neigh_num(self, neigh_num):
    # number of neighbors to retrieve when the service performs faiss recall
    self.request_data.faiss_neigh_num = neigh_num

  def keep_one_item_ids(self):
    # truncate item_ids to its first element only
    item_id = self.request_data.item_ids[0]
    self.request_data.ClearField('item_ids')
    self.request_data.item_ids.extend([item_id])

  def to_string(self):
    """Serialize the request to string for transmission.

    Returns:
      the request data in format of string
    """
    return self.request_data.SerializeToString()

  def parse_response(self, response_data):
    """Parse the given response data in string format to a PBResponse object.

    Args:
      response_data: the service response data in string format

    Returns:
      the PBResponse object related to the request
    """
    self.response = PBResponse()
    self.response.ParseFromString(response_data)
    return self.response

processor/test.py renamed to easy_rec/python/inference/processor/test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def build_array_proto(array_proto, data, dtype):
6565
'--test_dir', type=str, default=None, help='test directory')
6666
args = parser.parse_args()
6767

68+
if not os.path.exists('processor'):
69+
os.mkdir('processor')
6870
if not os.path.exists(PROCESSOR_ENTRY_LIB):
6971
if not os.path.exists('processor/' + PROCESSOR_FILE):
7072
subprocess.check_output(

easy_rec/python/layers/multihead_cross_attention.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -708,11 +708,12 @@ def embedding_postprocessor(input_tensor,
708708
if use_position_embeddings:
709709
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
710710
with tf.control_dependencies([assert_op]):
711-
with tf.variable_scope("position_embedding", reuse=reuse_position_embedding):
711+
with tf.variable_scope(
712+
'position_embedding', reuse=reuse_position_embedding):
712713
full_position_embeddings = tf.get_variable(
713-
name=position_embedding_name,
714-
shape=[max_position_embeddings, width],
715-
initializer=create_initializer(initializer_range))
714+
name=position_embedding_name,
715+
shape=[max_position_embeddings, width],
716+
initializer=create_initializer(initializer_range))
716717
# Since the position embedding table is a learned variable, we create it
717718
# using a (long) sequence length `max_position_embeddings`. The actual
718719
# sequence length might be shorter than this, for faster training of

0 commit comments

Comments
 (0)