[dts] [PATCH v2 03/16] framework/utils: support locks for parallel model

Marvin Liu yong.liu at intel.com
Wed Jan 10 01:11:01 CET 2018


1. Add parallel lock support to protect critical resources and actions.
Parallel locks are per function and kept separately for each DUT.
2. Add support for a user-defined serializer function in pprint.
3. Perform the RSA key removal only once for all virtual machines.

Signed-off-by: Marvin Liu <yong.liu at intel.com>

diff --git a/framework/utils.py b/framework/utils.py
index 1ecef09..762c927 100644
--- a/framework/utils.py
+++ b/framework/utils.py
@@ -35,9 +35,95 @@ import os
 import inspect
 import socket
 import struct
+import threading
+import types
+from functools import wraps
 
+# NOTE(review): consumer not visible here. As a regex "DTS_*" means "DTS"
+# plus zero or more "_"; as a glob it means names starting "DTS_" - confirm.
 DTS_ENV_PAT = r"DTS_*"
 
+def create_parallel_locks(num_duts):
+    """
+    (Re)build the module-level locks_info list, one dictionary per DUT.
+
+    Each dictionary initially holds only a re-entrant 'update_lock';
+    parallel_lock() adds per-function entries to it on first use.
+    """
+    global locks_info
+    locks_info = [{'update_lock': threading.RLock()}
+                  for _ in range(num_duts)]
+
+
+def parallel_lock(num=1):
+    """
+    Decorator that serializes calls of a function among parallel threads.
+
+    Bookkeeping is kept per function name and per DUT; the DUT is selected
+    by the keyword argument 'dut_id' (default 0, and any id without an
+    entry in locks_info falls back to 0). Every call increments the
+    function's arrival counter. A call that brings the counter to `num` or
+    above takes the function lock and keeps the DUT's update lock until it
+    returns, which blocks other decorated calls for that DUT meanwhile;
+    when it finishes the counter is reset to 0. Calls below the threshold
+    run without holding either lock. With the default num=1 every call is
+    serialized per DUT.
+
+    create_parallel_locks() must have been called (with at least one DUT)
+    before a decorated function runs.
+
+    Parameter:
+        num: arrival count at which callers start taking the function lock
+    """
+    def decorate(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            dut_id = kwargs.get('dut_id', 0)
+
+            # in case function arguments are not correct: ids without a
+            # lock table (including negative ones) use DUT 0
+            if not 0 <= dut_id < len(locks_info):
+                dut_id = 0
+
+            lock_info = locks_info[dut_id]
+            uplock = lock_info['update_lock']
+            name = func.__name__
+
+            uplock.acquire()
+            if name not in lock_info:
+                lock_info[name] = {'lock': threading.RLock(),
+                                   'current_thread': 0}
+            func_info = lock_info[name]
+            func_info['current_thread'] += 1
+            lock = func_info['lock']
+
+            # the caller reaching the threshold takes the function lock and
+            # keeps holding the update lock until it has finished
+            if func_info['current_thread'] >= num:
+                # report a wait only when the lock is really held elsewhere
+                # (an RLock already owned by this thread is re-acquired)
+                if not lock.acquire(False):
+                    print(RED("DUT%d %s waiting for func lock %s" % (
+                        dut_id, threading.current_thread().name, name)))
+                    lock.acquire()
+            else:
+                uplock.release()
+
+            # finally (not except Exception) so BaseException subclasses
+            # also release the locks, and the original traceback survives
+            try:
+                return func(*args, **kwargs)
+            finally:
+                # the function-lock owner still holds the update lock;
+                # other callers re-take it before touching the bookkeeping
+                if not uplock._is_owned():
+                    uplock.acquire()
+                if lock._is_owned():
+                    lock.release()
+                    func_info['current_thread'] = 0
+                uplock.release()
+        return wrapper
+    return decorate
+
 
 def RED(text):
+    """Return str(text) wrapped in ANSI bold-red codes plus a reset code."""
     return "\x1B[" + "31;1m" + str(text) + "\x1B[" + "0m"
@@ -51,11 +137,11 @@ def GREEN(text):
     return "\x1B[" + "32;1m" + str(text) + "\x1B[" + "0m"
 
 
-def pprint(some_dict):
-    """
-    Print JSON format dictionary object.
-    """
-    return json.dumps(some_dict, sort_keys=True, indent=4)
+def pprint(some_dict, serialzer=None):
+    """
+    Return some_dict formatted as JSON text with sorted keys and a
+    four-space indent; nothing is printed.
+
+    serialzer (sic): optional callable passed to json.dumps as ``default``.
+    json calls it for objects it cannot encode natively; it must return an
+    encodable value or raise TypeError. None keeps json's standard
+    behavior of raising TypeError for such objects.
+    """
+    return json.dumps(some_dict, sort_keys=True, indent=4, default=serialzer)
 
 
 def regexp(s, to_match, allString=False):
@@ -83,26 +169,13 @@ def get_obj_funcs(obj, func_name_regex):
             yield func
 
 
+@parallel_lock()
 def remove_old_rsa_key(crb, ip):
-    """
-    Remove the old RSA key of specified IP on crb.
-    """
-    if ':' not in ip:
-        ip = ip.strip()
-        port = ''
-    else:
-        addr = ip.split(':')
-        ip = addr[0].strip()
-        port = addr[1].strip()
-
+    """
+    Remove old RSA host keys for ip from ~/.ssh/known_hosts on crb.
+
+    ip is "address" or "address:port". A bare address removes every entry
+    for that host, whether stored as "address" or as "[address]:port" for
+    any port. With a port, only the "[address]:port" entry is removed,
+    which is how ssh records hosts on non-default ports. The address is
+    matched literally and as a whole host entry, so "10.0.0.1" does not
+    also remove "10.0.0.10" or "110.0.0.1".
+
+    Raises ValueError if the port part is not an integer.
+    """
     rsa_key_path = "~/.ssh/known_hosts"
-    if port:
-        remove_rsa_key_cmd = "sed -i '/^\[%s\]:%d/d' %s" % \
-            (ip.strip(), int(
-             port), rsa_key_path)
-    else:
-        remove_rsa_key_cmd = "sed -i '/^%s/d' %s" % \
-            (ip.strip(), rsa_key_path)
+    host, _, port = ip.strip().partition(':')
+    # backslash-escape sed regex metacharacters (notably '.') and the '/'
+    # delimiter so the address is matched literally
+    host = ''.join('\\' + c if c in '\\/.[]*^$' else c
+                   for c in host.strip())
+    port = port.strip()
+    if port:
+        # int() keeps only a numeric port in the shell command
+        pattern = r'\[%s\]:%d[ ,]' % (host, int(port))
+    else:
+        # "address" or "[address]" at line start or after a comma, followed
+        # by a separator (GNU sed BRE; sed -i is GNU-specific already)
+        pattern = r'\(^\|,\)\[\{0,1\}%s[],: ]' % host
+    remove_rsa_key_cmd = "sed -i '/%s/d' %s" % (pattern, rsa_key_path)
     crb.send_expect(remove_rsa_key_cmd, "# ")
 
 
-- 
1.9.3



More information about the dts mailing list