performance: calling python from wasm is too slow #140

Open
muayyad-alsadi opened this issue Apr 5, 2023 · 1 comment

muayyad-alsadi (Contributor) commented Apr 5, 2023

As mentioned in the parent performance ticket
#96 (comment),
we have a bottleneck when Python is called from inside wasm.

Here is a minimal reproducible example.

Unlike #96, which is to be solved by #139, the loop here lives inside wasm, and the functions it calls are implemented in Python, in C (Python's math module), and in wasm.
This issue is therefore not solved by #139.

experiment setup

(module
  (func $log_i (import "env" "log_i") (param i32))
  (func $math_gcd (import "env" "math_gcd") (param i32 i32) (result i32))
  (func $python_gcd (import "env" "python_gcd") (param i32 i32) (result i32))
  (func $gcd (export "gcd") (param i32 i32) (result i32)
  ;; ...
  )
  (func $wasm_gcd_loop (export "wasm_gcd_loop") (param $n i32) (param $a i32) (param $b i32) (result i32)
  ;; call wasm_gcd(a,b) n times
  )
  (func $math_gcd_loop (export "math_gcd_loop") (param $n i32) (param $a i32) (param $b i32) (result i32)
  ;; call math_gcd(a,b) n times
  )
  (func $python_gcd_loop (export "python_gcd_loop") (param $n i32) (param $a i32) (param $b i32) (result i32)
  ;; call python_gcd(a,b) n times
  )
)

I used log_i() to log i and the gcd result to make sure the loop works, then removed the logging.
I imported the C gcd from Python's math module and a pure-Python gcd as host functions,
and ran a loop inside wasm calling each of the 3 versions:

print("*** wasm loop benchmark: ")
for name in "wasm_gcd_loop", "math_gcd_loop", "python_gcd_loop":
    gcdf = by_name[name]
    start_time = time.perf_counter()
    g = gcdf(N, a, b)
    total_time = time.perf_counter() - start_time
    print(total_time, "\t\t", name)

results (total seconds for N = 1,000 calls)

9.899199358187616e-05            wasm_gcd_loop
0.01052837198949419              math_gcd_loop
0.012672090000705793             python_gcd_loop

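conclusion

When the wasm loop calls the wasm gcd, it is very fast: about 100x faster than when it calls C's gcd (9.9e-05 s vs 0.0105 s for 1,000 calls).
Since C's gcd is faster than the wasm gcd when both are called from Python (as shown in the other benchmark), the gcd work itself is not the problem; the bottleneck is the Python wrapping around host calls. This also fits the small gap between math_gcd_loop and python_gcd_loop (0.0105 s vs 0.0127 s): most of each call's cost is the call overhead, not the gcd.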
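The numbers behind this conclusion can be reproduced from the three timings above. A minimal sketch; the timings are copied from the results and N = 1,000 matches the benchmark code:

```python
# Timings from the results above, in seconds, for N = 1,000 calls made inside wasm.
N = 1_000
timings = {
    "wasm_gcd_loop": 9.899199358187616e-05,
    "math_gcd_loop": 0.01052837198949419,
    "python_gcd_loop": 0.012672090000705793,
}

# Average cost of one loop iteration (one gcd call) in microseconds.
per_call_us = {name: t / N * 1e6 for name, t in timings.items()}

# How many times slower each loop is than the pure-wasm loop.
slowdown = {name: t / timings["wasm_gcd_loop"] for name, t in timings.items()}

for name in timings:
    print(f"{name:16s} {per_call_us[name]:8.3f} us/call {slowdown[name]:8.1f}x")
```

That works out to roughly 10.5 µs per call into math.gcd versus about 0.1 µs per iteration of the wasm-only loop, while the pure-Python gcd adds only about 2 µs on top of math.gcd, which is consistent with host-call overhead dominating.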

profiling

here is how we profile calling C from wasm

from cProfile import Profile

pr = Profile()
pr.enable()
print(math_gcd_loop(N, a, b))
pr.disable()
pr.print_stats()
pr.print_stats('cumulative')
pr.print_stats('calls')
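Printing the full report three times gets long. A sketch of an alternative that uses the standard `pstats` module to sort by cumulative time and keep only the top entries; `hypothetical_gcd_loop` is a stand-in for `math_gcd_loop` so the snippet runs without wasmtime (with the real module you would profile `math_gcd_loop(N, a, b)` instead):

```python
import cProfile
import io
import pstats
from math import gcd


def hypothetical_gcd_loop(n: int, a: int, b: int) -> int:
    # Stand-in for math_gcd_loop: calls math.gcd n times from a Python loop.
    r = 0
    for _ in range(n):
        r = gcd(a, b)
    return r


pr = cProfile.Profile()
pr.enable()
result = hypothetical_gcd_loop(1_000, 16516842, 154654684)
pr.disable()

# Write the report into a string, sorted by cumulative time, top 10 rows only.
out = io.StringIO()
stats = pstats.Stats(pr, stream=out)
stats.strip_dirs().sort_stats(pstats.SortKey.CUMULATIVE).print_stats(10)
report = out.getvalue()
print(report)
```

Passing `pstats.SortKey.CALLS` instead gives the call-count ordering used in the first table below.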

using the production release 7.0.0, the obvious bottlenecks are:

  • _func.py:172(trampoline)
  • _value.py:129(_convert)

by call count

   Ordered by: call count

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     7024    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
     3019    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
     2006    0.001    0.000    0.001    0.000 _bindings.py:178(wasm_valtype_kind)
     2005    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
     2001    0.001    0.000    0.001    0.000 _value.py:167(_value)
     1009    0.000    0.000    0.000    0.000 {built-in method _ctypes.POINTER}
     1008    0.000    0.000    0.000    0.000 {built-in method __new__ of type object at 0x7fc0d01430a0}
     1007    0.001    0.000    0.001    0.000 _types.py:45(_from_ptr)
     1007    0.001    0.000    0.002    0.000 _types.py:86(__del__)
     1004    0.000    0.000    0.000    0.000 _value.py:109(__init__)
     1004    0.000    0.000    0.000    0.000 _value.py:117(__del__)
     1004    0.000    0.000    0.000    0.000 _value.py:161(_unwrap_raw)
     1003    0.001    0.000    0.001    0.000 _bindings.py:172(wasm_valtype_new)
     1003    0.002    0.000    0.002    0.000 _value.py:39(i32)
     1003    0.002    0.000    0.012    0.000 _value.py:129(_convert)
     1003    0.001    0.000    0.003    0.000 _types.py:12(i32)
     1003    0.001    0.000    0.002    0.000 _types.py:54(__eq__)
     1003    0.001    0.000    0.001    0.000 _bindings.py:120(wasm_valtype_delete)
     1000    0.000    0.000    0.000    0.000 _func.py:121(__init__)
     1000    0.005    0.000    0.020    0.000 _func.py:172(trampoline)
     1000    0.000    0.000    0.001    0.000 _func.py:245(get)
     1000    0.000    0.000    0.001    0.000 _value.py:156(_into_raw)
     1000    0.000    0.000    0.000    0.000 typing.py:1737(cast)
     1000    0.000    0.000    0.000    0.000 {built-in method math.gcd}
     1000    0.000    0.000    0.000    0.000 {built-in method builtins.delattr}

by cumulative time

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.022    0.022 _func.py:59(__call__)
        1    0.002    0.002    0.022    0.022 _bindings.py:2491(wasmtime_func_call)
     1000    0.005    0.000    0.020    0.000 _func.py:172(trampoline)
     1003    0.002    0.000    0.012    0.000 _value.py:129(_convert)
     1003    0.001    0.000    0.003    0.000 _types.py:12(i32)
     1003    0.001    0.000    0.002    0.000 _types.py:54(__eq__)
     1003    0.002    0.000    0.002    0.000 _value.py:39(i32)
     1007    0.001    0.000    0.002    0.000 _types.py:86(__del__)
     2006    0.001    0.000    0.001    0.000 _bindings.py:178(wasm_valtype_kind)
     1007    0.001    0.000    0.001    0.000 _types.py:45(_from_ptr)
     2001    0.001    0.000    0.001    0.000 _value.py:167(_value)
     1000    0.000    0.000    0.001    0.000 _value.py:156(_into_raw)
     1003    0.001    0.000    0.001    0.000 _bindings.py:120(wasm_valtype_delete)
     1003    0.001    0.000    0.001    0.000 _bindings.py:172(wasm_valtype_new)
     7024    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
     1000    0.000    0.000    0.001    0.000 _func.py:245(get)
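Reading the two tables together, the call counts suggest that each of the 1,000 trampoline calls goes through `_convert`, and that nearly every `_convert` builds a fresh `i32` valtype (`_types.py:12(i32)` → `wasm_valtype_new`), compares it via `__eq__`, and frees it again (`wasm_valtype_delete`). The sketch below is a pure-Python toy model of that pattern, not wasmtime-py's actual code; `ValTypeStub`, `convert_per_call` and `make_converter` are hypothetical names. It contrasts re-deriving the type on every call with resolving the conversion once, when the host function is defined:

```python
from typing import Callable, List, Sequence


class ValTypeStub:
    """Hypothetical stand-in for a valtype object that is costly to create."""

    def __init__(self, kind: str) -> None:
        self.kind = kind

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ValTypeStub) and self.kind == other.kind


def convert_per_call(param_kinds: Sequence[str], raw: Sequence[int]) -> List[int]:
    # Pattern suggested by the profile: build fresh type objects for every
    # value on every call, compare them, then throw them away.
    out = []
    for kind, value in zip(param_kinds, raw):
        if ValTypeStub(kind) == ValTypeStub("i32"):
            out.append(int(value))
        else:
            raise TypeError(f"unsupported kind {kind!r}")
    return out


def make_converter(param_kinds: Sequence[str]) -> Callable[[Sequence[int]], List[int]]:
    # Alternative: resolve one converter per parameter once, up front,
    # so each call only applies the precomputed list.
    table = {"i32": int}
    converters = [table[kind] for kind in param_kinds]  # KeyError if unsupported

    def convert(raw: Sequence[int]) -> List[int]:
        return [conv(value) for conv, value in zip(converters, raw)]

    return convert


kinds = ["i32", "i32"]
fast = make_converter(kinds)
print(convert_per_call(kinds, [16516842, 154654684]), fast([16516842, 154654684]))
```

Whether wasmtime-py can cache this way depends on its internals; the toy only illustrates that the per-call work visible in the profile is type bookkeeping rather than the gcd itself.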
muayyad-alsadi (Contributor, Author) commented

here is the code

import time
import wasmtime as wa

from math import gcd as math_gcd
from functools import partial
from wasmtime import Store, Module, Instance

def python_gcd(x, y):
    while y:
        x, y = y, x % y
    return abs(x)


store = Store()
module = Module(
    store.engine,
    """
(module
  (func $log_i (import "env" "log_i") (param i32))
  (func $math_gcd (import "env" "math_gcd") (param i32 i32) (result i32))
  (func $python_gcd (import "env" "python_gcd") (param i32 i32) (result i32))
  (func $gcd (export "gcd") (param i32 i32) (result i32)
    (local i32)
    block  ;; label = @1
      block  ;; label = @2
        local.get 0
        br_if 0 (;@2;)
        local.get 1
        local.set 2
        br 1 (;@1;)
      end
      loop  ;; label = @2
        local.get 1
        local.get 0
        local.tee 2
        i32.rem_u
        local.set 0
        local.get 2
        local.set 1
        local.get 0
        br_if 0 (;@2;)
      end
    end
    local.get 2
  )
  (func $wasm_gcd_loop (export "wasm_gcd_loop") (param $n i32) (param $a i32) (param $b i32) (result i32)
    (local $i i32)
    (local $r i32)
    ;; i=0
    i32.const 0
    local.set $i
    block $y_block
      loop $y_loop
        ;; if (i>=n) break
        local.get $i
        local.get $n
        i32.ge_s
        br_if $y_block
        ;; loop body
        local.get $a
        local.get $b
        call $gcd
        local.set $r
        ;;;; log_i(i)
        ;;local.get $i
        ;;call $log_i
        ;;;; log_i(r)
        ;;local.get $r
        ;;call $log_i
        ;; ++i
        local.get $i
        i32.const 1
        i32.add
        local.set $i
        br $y_loop
      end
    end
    local.get $r
  )

  (func $math_gcd_loop (export "math_gcd_loop") (param $n i32) (param $a i32) (param $b i32) (result i32)
    (local $i i32)
    (local $r i32)
    ;; i=0
    i32.const 0
    local.set $i
    block $y_block
      loop $y_loop
        ;; if (i>=n) break
        local.get $i
        local.get $n
        i32.ge_s
        br_if $y_block
        ;; loop body
        local.get $a
        local.get $b
        call $math_gcd
        local.set $r
        ;;;; log_i(i)
        ;;local.get $i
        ;;call $log_i
        ;;;; log_i(r)
        ;;local.get $r
        ;;call $log_i
        ;; ++i
        local.get $i
        i32.const 1
        i32.add
        local.set $i
        br $y_loop
      end
    end
    local.get $r
  )

  (func $python_gcd_loop (export "python_gcd_loop") (param $n i32) (param $a i32) (param $b i32) (result i32)
    (local $i i32)
    (local $r i32)
    ;; i=0
    i32.const 0
    local.set $i
    block $y_block
      loop $y_loop
        ;; if (i>=n) break
        local.get $i
        local.get $n
        i32.ge_s
        br_if $y_block
        ;; loop body
        local.get $a
        local.get $b
        call $python_gcd
        local.set $r
        ;;;; log_i(i)
        ;;local.get $i
        ;;call $log_i
        ;;;; log_i(r)
        ;;local.get $r
        ;;call $log_i
        ;; ++i
        local.get $i
        i32.const 1
        i32.add
        local.set $i
        br $y_loop
      end
    end
    local.get $r
  )
)
""",
)

def log_i(i):
    print("** log_i: ", i)

linker = wa.Linker(store.engine)
linker.define_func("env", "log_i", wa.FuncType([wa.ValType.i32()], []), log_i)
linker.define_func("env", "math_gcd", wa.FuncType([wa.ValType.i32(), wa.ValType.i32()], [wa.ValType.i32()]), math_gcd)
linker.define_func("env", "python_gcd", wa.FuncType([wa.ValType.i32(), wa.ValType.i32()], [wa.ValType.i32()]), python_gcd)
instance = linker.instantiate(store, module)

gcd_store = instance.exports(store)["gcd"]
wasm_gcd = partial(gcd_store, store)

wasm_gcd_loop = partial(instance.exports(store)["wasm_gcd_loop"], store)
math_gcd_loop = partial(instance.exports(store)["math_gcd_loop"], store)
python_gcd_loop = partial(instance.exports(store)["python_gcd_loop"], store)

a = 16516842
b = 154654684
N = 1_000

print(math_gcd(a, b), python_gcd(a, b), wasm_gcd(a, b))
print(wasm_gcd_loop(N, a, b))
print(math_gcd_loop(N, a, b))
print(python_gcd_loop(N, a, b))


by_name = locals()
print("*** python loop benchmark: ")
for name in "math_gcd", "python_gcd", "wasm_gcd":
    gcdf = by_name[name]
    start_time = time.perf_counter()
    for _ in range(N):
        g = gcdf(a, b)
    total_time = time.perf_counter() - start_time
    print(total_time, "\t\t", name)

print("*** wasm loop benchmark: ")
for name in "wasm_gcd_loop", "math_gcd_loop", "python_gcd_loop":
    gcdf = by_name[name]
    start_time = time.perf_counter()
    g = gcdf(N, a, b)
    total_time = time.perf_counter() - start_time
    print(total_time, "\t\t", name)

from cProfile import Profile

pr=Profile()
pr.enable()
print(math_gcd_loop(N, a, b))
pr.disable()
pr.print_stats()
pr.print_stats('cumulative')
pr.print_stats('calls')
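Each benchmark above times a single run, so warm-up effects and noise land in the numbers. A sketch of a slightly more robust timing helper that repeats a run and keeps the fastest time; `best_of` is a hypothetical helper, and the example applies it only to `math_gcd` and `python_gcd` so it runs without wasmtime. With the module from this issue, the same helper could wrap e.g. `lambda: math_gcd_loop(N, a, b)`:

```python
import time
from math import gcd as math_gcd
from typing import Callable, Dict


def python_gcd(x: int, y: int) -> int:
    # Same pure-Python gcd as in the issue.
    while y:
        x, y = y, x % y
    return abs(x)


def best_of(run: Callable[[], object], repeat: int = 5) -> float:
    # Time `run` several times and keep the fastest wall-clock result,
    # which filters out warm-up effects and scheduler noise.
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        run()
        best = min(best, time.perf_counter() - start)
    return best


a, b, N = 16516842, 154654684, 1_000

# Each entry performs N gcd calls, mirroring the "python loop benchmark".
runs: Dict[str, Callable[[], object]] = {
    "math_gcd": lambda: [math_gcd(a, b) for _ in range(N)],
    "python_gcd": lambda: [python_gcd(a, b) for _ in range(N)],
}

results = {name: best_of(run) for name, run in runs.items()}
for name, seconds in results.items():
    print(f"{seconds:.6f}\t\t{name}")
```

The standard `timeit.repeat` function does the same kind of repeated measurement if a library helper is preferred.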
